Prechádzať zdrojové kódy

support for tracing

Christoph Hack 12 rokov pred
rodič
commit
23a4654e20
5 zmenil súbory, kde vykonal 187 pridanie a 62 odobranie
  1. 93 13
      conn.go
  2. 1 0
      frame.go
  3. 8 2
      gocql_test/main.go
  4. 59 1
      marshal.go
  5. 26 46
      session.go

+ 93 - 13
conn.go

@@ -183,15 +183,17 @@ func (c *Conn) execSimple(op operation) (interface{}, error) {
 	if f, err = c.recv(); err != nil {
 		return nil, err
 	}
-	return c.decodeFrame(f)
+	return c.decodeFrame(f, nil)
 }
 
-func (c *Conn) exec(op operation) (interface{}, error) {
-	//fmt.Printf("exec: %#v\n", op)
+func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
 	req, err := op.encodeFrame(c.version, nil)
 	if err != nil {
 		return nil, err
 	}
+	if trace != nil {
+		req[1] |= flagTrace
+	}
 	if len(req) > headerSize && c.compressor != nil {
 		body, err := c.compressor.Encode([]byte(req[headerSize:]))
 		if err != nil {
@@ -224,7 +226,7 @@ func (c *Conn) exec(op operation) (interface{}, error) {
 	if reply.err != nil {
 		return nil, reply.err
 	}
-	return c.decodeFrame(reply.buf)
+	return c.decodeFrame(reply.buf, trace)
 }
 
 func (c *Conn) dispatch(resp frame) {
@@ -241,11 +243,11 @@ func (c *Conn) dispatch(resp frame) {
 }
 
 func (c *Conn) ping() error {
-	_, err := c.exec(&optionsFrame{})
+	_, err := c.exec(&optionsFrame{}, nil)
 	return err
 }
 
-func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
+func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 	c.prepMu.Lock()
 	info := c.prep[stmt]
 	if info != nil {
@@ -258,7 +260,7 @@ func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
 	c.prep[stmt] = info
 	c.prepMu.Unlock()
 
-	resp, err := c.exec(&prepareFrame{Stmt: stmt})
+	resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
 	if err != nil {
 		return nil, err
 	}
@@ -275,6 +277,48 @@ func (c *Conn) prepareStatement(stmt string) (*queryInfo, error) {
 	return info, nil
 }
 
+func (c *Conn) executeQuery(qry *Query, pageState []byte) *Iter {
+	op := &queryFrame{
+		Stmt:      qry.Stmt,
+		Cons:      qry.Cons,
+		PageSize:  qry.PageSize,
+		PageState: pageState,
+	}
+	if len(qry.Args) > 0 {
+		info, err := c.prepareStatement(qry.Stmt, qry.Trace)
+		if err != nil {
+			return &Iter{err: err}
+		}
+		op.Prepared = info.id
+		op.Values = make([][]byte, len(qry.Args))
+		for i := 0; i < len(qry.Args); i++ {
+			val, err := Marshal(info.args[i].TypeInfo, qry.Args[i])
+			if err != nil {
+				return &Iter{err: err}
+			}
+			op.Values[i] = val
+		}
+	}
+	resp, err := c.exec(op, qry.Trace)
+	if err != nil {
+		return &Iter{err: err}
+	}
+	switch x := resp.(type) {
+	case resultVoidFrame:
+		return &Iter{}
+	case resultRowsFrame:
+		iter := &Iter{columns: x.Columns, rows: x.Rows, pageState: x.PagingState}
+		return iter
+	case resultKeyspaceFrame:
+		c.cluster.HandleKeyspace(c, x.Keyspace)
+		return &Iter{}
+	case error:
+		return &Iter{err: x}
+	default:
+		return &Iter{err: ErrProtocol}
+	}
+}
+
 func (c *Conn) Pick(qry *Query) *Conn {
 	return c
 }
@@ -288,7 +332,7 @@ func (c *Conn) Address() string {
 }
 
 func (c *Conn) UseKeyspace(keyspace string) error {
-	resp, err := c.exec(&queryFrame{Stmt: "USE " + keyspace, Cons: Any})
+	resp, err := c.exec(&queryFrame{Stmt: "USE " + keyspace, Cons: Any}, nil)
 	if err != nil {
 		return err
 	}
@@ -315,7 +359,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		var info *queryInfo
 		if len(entry.Args) > 0 {
 			var err error
-			info, err = c.prepareStatement(entry.Stmt)
+			info, err = c.prepareStatement(entry.Stmt, nil)
 			if err != nil {
 				return err
 			}
@@ -336,7 +380,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	}
 	f.writeConsistency(batch.Cons)
 
-	resp, err := c.exec(f)
+	resp, err := c.exec(f, nil)
 	if err != nil {
 		return err
 	}
@@ -350,7 +394,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	}
 }
 
-func (c *Conn) decodeFrame(f frame) (rval interface{}, err error) {
+func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error) {
 	defer func() {
 		if r := recover(); r != nil {
 			if e, ok := r.(error); ok && e == ErrProtocol {
@@ -371,6 +415,16 @@ func (c *Conn) decodeFrame(f frame) (rval interface{}, err error) {
 			f = frame(buf)
 		}
 	}
+	if flags&flagTrace != 0 {
+		if len(f) < 16 {
+			return nil, ErrProtocol
+		}
+		var traceId []byte
+		traceId, f = f[:16], f[16:]
+		if err := c.gatherTrace(traceId, trace); err != nil {
+			return nil, err
+		}
+	}
 
 	switch op {
 	case opReady:
@@ -387,8 +441,8 @@ func (c *Conn) decodeFrame(f frame) (rval interface{}, err error) {
 				values[i] = f.readBytes()
 			}
 			rows := make([][][]byte, numRows)
-			for i := 0; i < len(values); i += len(columns) {
-				rows[i] = values[i : i+len(columns)]
+			for i := 0; i < numRows; i++ {
+				rows[i], values = values[:len(columns)], values[len(columns):]
 			}
 			return resultRowsFrame{columns, rows, pageState}, nil
 		case resultKindKeyspace:
@@ -412,6 +466,32 @@ func (c *Conn) decodeFrame(f frame) (rval interface{}, err error) {
 	}
 }
 
+func (c *Conn) gatherTrace(traceId []byte, trace Tracer) error {
+	if trace == nil {
+		return nil
+	}
+	iter := c.executeQuery(&Query{
+		Stmt: `SELECT event_id, activity, source, source_elapsed
+			FROM system_traces.events
+			WHERE session_id = ?`,
+		Args: []interface{}{traceId},
+		Cons: One,
+	}, nil)
+	var (
+		time     time.Time
+		activity string
+		source   string
+		elapsed  int
+	)
+	for iter.Scan(&time, &activity, &source, &elapsed) {
+		trace.Trace(time, activity, source, elapsed)
+	}
+	if err := iter.Close(); err != nil {
+		return err
+	}
+	return nil
+}
+
 type queryInfo struct {
 	id   []byte
 	args []ColumnInfo

+ 1 - 0
frame.go

@@ -37,6 +37,7 @@ const (
 
 	flagQueryValues uint8 = 1
 	flagCompress    uint8 = 1
+	flagTrace       uint8 = 2
 	flagPageSize    uint8 = 4
 	flagPageState   uint8 = 8
 	flagHasMore     uint8 = 2

+ 8 - 2
gocql_test/main.go

@@ -6,6 +6,7 @@ package main
 
 import (
 	"log"
+	"os"
 	"reflect"
 	"sort"
 	"time"
@@ -165,12 +166,17 @@ func main() {
 		}
 	}
 
+	trace := gocql.NewTraceWriter(os.Stdout)
+	if err := session.Query("SELECT COUNT(*) FROM page").Trace(trace).Scan(&count); err != nil {
+		log.Fatal("trace: ", err)
+	}
+
 	if err := session.Query("CREATE TABLE large (id int primary key)").Exec(); err != nil {
-		log.Fatal("create table", err)
+		log.Fatal("create table: ", err)
 	}
 	for i := 0; i < 100; i++ {
 		if err := session.Query("INSERT INTO large (id) VALUES (?)", i).Exec(); err != nil {
-			log.Fatal("insert", err)
+			log.Fatal("insert: ", err)
 		}
 	}
 	iter := session.Query("SELECT id FROM large").PageSize(10).Iter()

+ 59 - 1
marshal.go

@@ -49,6 +49,8 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 		return marshalList(info, value)
 	case TypeMap:
 		return marshalMap(info, value)
+	case TypeUUID:
+		return marshalUUID(info, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
@@ -80,6 +82,10 @@ func Unmarshal(info *TypeInfo, data []byte, value interface{}) error {
 		return unmarshalList(info, data, value)
 	case TypeMap:
 		return unmarshalMap(info, data, value)
+	case TypeTimeUUID:
+		return unmarshalTimeUUID(info, data, value)
+	case TypeInet:
+		return unmarshalInet(info, data, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return fmt.Errorf("can not unmarshal %s into %T", info, value)
@@ -803,7 +809,6 @@ func unmarshalList(info *TypeInfo, data []byte, value interface{}) error {
 		}
 		return nil
 	}
-
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
 
@@ -895,6 +900,57 @@ func unmarshalMap(info *TypeInfo, data []byte, value interface{}) error {
 	return nil
 }
 
+func marshalUUID(info *TypeInfo, value interface{}) ([]byte, error) {
+	if val, ok := value.([]byte); ok && len(val) == 16 {
+		return val, nil
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func unmarshalTimeUUID(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *time.Time:
+		if len(data) != 16 {
+			return unmarshalErrorf("invalid timeuuid")
+		}
+		if version := int(data[6] & 0xF0 >> 4); version != 1 {
+			return unmarshalErrorf("invalid timeuuid")
+		}
+		timestamp := uint64(data[0])<<24 + uint64(data[1])<<16 +
+			uint64(data[2])<<8 + uint64(data[3]) + uint64(data[4])<<40 +
+			uint64(data[5])<<32 + uint64(data[7])<<48 + uint64(data[6]&0x0F)<<56
+		if timestamp == 0 {
+			*v = time.Time{}
+			return nil
+		}
+		sec := timestamp / 10000000
+		nsec := timestamp - sec
+		*v = time.Unix(int64(sec)+timeBase, int64(nsec))
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func unmarshalInet(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *string:
+		if len(data) == 0 {
+			*v = ""
+			return nil
+		}
+		if len(data) == 4 {
+			*v = fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3])
+			return nil
+		}
+		// TODO: support IPv6
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
 // TypeInfo describes a Cassandra specific data type.
 type TypeInfo struct {
 	Type   Type
@@ -1005,3 +1061,5 @@ func (m UnmarshalError) Error() string {
 func unmarshalErrorf(format string, args ...interface{}) UnmarshalError {
 	return UnmarshalError(fmt.Sprintf(format, args...))
 }
+
+var timeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix()

+ 26 - 46
session.go

@@ -6,6 +6,9 @@ package gocql
 
 import (
 	"errors"
+	"fmt"
+	"io"
+	"time"
 )
 
 // Session is the interface used by users to interact with the database.
@@ -51,55 +54,16 @@ func (s *Session) executeQuery(qry *Query, pageState []byte) *Iter {
 	if qry.Cons == 0 {
 		qry.Cons = s.Cons
 	}
-
 	conn := s.Node.Pick(qry)
 	if conn == nil {
 		return &Iter{err: ErrUnavailable}
 	}
-	op := &queryFrame{
-		Stmt:      qry.Stmt,
-		Cons:      qry.Cons,
-		PageSize:  qry.PageSize,
-		PageState: pageState,
-	}
-	if len(qry.Args) > 0 {
-		info, err := conn.prepareStatement(qry.Stmt)
-		if err != nil {
-			return &Iter{err: err}
-		}
-		op.Prepared = info.id
-		op.Values = make([][]byte, len(qry.Args))
-		for i := 0; i < len(qry.Args); i++ {
-			val, err := Marshal(info.args[i].TypeInfo, qry.Args[i])
-			if err != nil {
-				return &Iter{err: err}
-			}
-			op.Values[i] = val
-		}
-	}
-	resp, err := conn.exec(op)
-	if err != nil {
-		return &Iter{err: err}
-	}
-	switch x := resp.(type) {
-	case resultVoidFrame:
-		return &Iter{}
-	case resultRowsFrame:
-		iter := &Iter{columns: x.Columns, rows: x.Rows}
-		if len(x.PagingState) > 0 {
-			iter.session = s
-			iter.qry = qry
-			iter.pageState = x.PagingState
-		}
-		return iter
-	case resultKeyspaceFrame:
-		conn.cluster.HandleKeyspace(conn, x.Keyspace)
-		return &Iter{}
-	case error:
-		return &Iter{err: x}
-	default:
-		return &Iter{err: ErrProtocol}
+	iter := conn.executeQuery(qry, pageState)
+	if len(iter.pageState) > 0 {
+		iter.qry = qry
+		iter.session = s
 	}
+	return iter
 }
 
 func (s *Session) ExecuteBatch(batch *Batch) error {
@@ -116,7 +80,7 @@ type Query struct {
 	Cons     Consistency
 	Token    string
 	PageSize int
-	Trace    bool
+	Trace    Tracer
 }
 
 func NewQuery(stmt string, args ...interface{}) *Query {
@@ -146,7 +110,7 @@ func (b QueryBuilder) Token(token string) QueryBuilder {
 	return b
 }
 
-func (b QueryBuilder) Trace(trace bool) QueryBuilder {
+func (b QueryBuilder) Trace(trace Tracer) QueryBuilder {
 	b.qry.Trace = trace
 	return b
 }
@@ -282,6 +246,22 @@ type ColumnInfo struct {
 	TypeInfo *TypeInfo
 }
 
+type Tracer interface {
+	Trace(time time.Time, activity string, source string, elapsed int)
+}
+
+type traceWriter struct {
+	w io.Writer
+}
+
+func NewTraceWriter(w io.Writer) Tracer {
+	return traceWriter{w}
+}
+
+func (t traceWriter) Trace(time time.Time, activity string, source string, elapsed int) {
+	fmt.Fprintf(t.w, "%s: %s (source: %s, elapsed: %d)\n", time, activity, source, elapsed)
+}
+
 type Error struct {
 	Code    int
 	Message string