Browse Source

adding observer to query (#1023)

* introduced query observer

* PR Feedback

* TestObserve_Pagination

* TestObserve_Pagination fixed

* query.attempt removing unnecessary iter.pos

* Batch observation

* tidy up

* PR Feedback
Javier Zunzunegui 8 years ago
parent
commit
2f405b2923
8 changed files with 354 additions and 15 deletions
  1. 2 1
      AUTHORS
  2. 212 0
      cassandra_test.go
  3. 8 0
      cluster.go
  4. 1 1
      cluster_test.go
  5. 8 3
      common_test.go
  6. 3 2
      query_executor.go
  7. 105 6
      session.go
  8. 15 2
      session_test.go

+ 2 - 1
AUTHORS

@@ -99,4 +99,5 @@ Ben Krebsbach <ben.krebsbach@gmail.com>
 Vivian Mathews <vivian.mathews.3@gmail.com>
 Sascha Steinbiss <satta@debian.org>
 Seth Rosenblum <seth.t.rosenblum@gmail.com>
-Luke Hines <lukehines@protonmail.com>
+Javier Zunzunegui <javier.zunzunegui.b@gmail.com>
+Luke Hines <lukehines@protonmail.com>

+ 212 - 0
cassandra_test.go

@@ -5,6 +5,8 @@ package gocql
 import (
 	"bytes"
 	"context"
+	"errors"
+	"fmt"
 	"io"
 	"math"
 	"math/big"
@@ -185,6 +187,151 @@ func TestTracing(t *testing.T) {
 	}
 }
 
+func TestObserve(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, `CREATE TABLE gocql_test.observe (id int primary key)`); err != nil {
+		t.Fatal("create:", err)
+	}
+
+	var (
+		observedErr      error
+		observedKeyspace string
+		observedStmt     string
+	)
+
+	const keyspace = "gocql_test"
+
+	resetObserved := func() {
+		observedErr = errors.New("placeholder only") // used to distinguish err=nil cases
+		observedKeyspace = ""
+		observedStmt = ""
+	}
+
+	observer := funcQueryObserver(func(ctx context.Context, o ObservedQuery) {
+		observedKeyspace = o.Keyspace
+		observedStmt = o.Statement
+		observedErr = o.Err
+	})
+
+	// select before any row is inserted: Scan returns an error, but the observed error is nil because the query itself is valid
+	resetObserved()
+	var value int
+	if err := session.Query(`SELECT id FROM observe WHERE id = ?`, 43).Observer(observer).Scan(&value); err == nil {
+		t.Fatal("select: expected error")
+	} else if observedErr != nil {
+		t.Fatalf("select: observed error expected nil, got %q", observedErr)
+	} else if observedKeyspace != keyspace {
+		t.Fatal("select: unexpected observed keyspace", observedKeyspace)
+	} else if observedStmt != `SELECT id FROM observe WHERE id = ?` {
+		t.Fatal("select: unexpected observed stmt", observedStmt)
+	}
+
+	resetObserved()
+	if err := session.Query(`INSERT INTO observe (id) VALUES (?)`, 42).Observer(observer).Exec(); err != nil {
+		t.Fatal("insert:", err)
+	} else if observedErr != nil {
+		t.Fatal("insert:", observedErr)
+	} else if observedKeyspace != keyspace {
+		t.Fatal("insert: unexpected observed keyspace", observedKeyspace)
+	} else if observedStmt != `INSERT INTO observe (id) VALUES (?)` {
+		t.Fatal("insert: unexpected observed stmt", observedStmt)
+	}
+
+	resetObserved()
+	value = 0
+	if err := session.Query(`SELECT id FROM observe WHERE id = ?`, 42).Observer(observer).Scan(&value); err != nil {
+		t.Fatal("select:", err)
+	} else if value != 42 {
+		t.Fatalf("value: expected %d, got %d", 42, value)
+	} else if observedErr != nil {
+		t.Fatal("select:", observedErr)
+	} else if observedKeyspace != keyspace {
+		t.Fatal("select: unexpected observed keyspace", observedKeyspace)
+	} else if observedStmt != `SELECT id FROM observe WHERE id = ?` {
+		t.Fatal("select: unexpected observed stmt", observedStmt)
+	}
+
+	// also works from session observer
+	resetObserved()
+	oSession := createSession(t, func(config *ClusterConfig) { config.QueryObserver = observer })
+	if err := oSession.Query(`SELECT id FROM observe WHERE id = ?`, 42).Scan(&value); err != nil {
+		t.Fatal("select:", err)
+	} else if observedErr != nil {
+		t.Fatal("select:", err)
+	} else if observedKeyspace != keyspace {
+		t.Fatal("select: unexpected observed keyspace", observedKeyspace)
+	} else if observedStmt != `SELECT id FROM observe WHERE id = ?` {
+		t.Fatal("select: unexpected observed stmt", observedStmt)
+	}
+
+	// reports errors when the query is poorly formed
+	resetObserved()
+	value = 0
+	if err := session.Query(`SELECT id FROM unknown_table WHERE id = ?`, 42).Observer(observer).Scan(&value); err == nil {
+		t.Fatal("select: expecting error")
+	} else if observedErr == nil {
+		t.Fatal("select: expecting observed error")
+	} else if observedKeyspace != keyspace {
+		t.Fatal("select: unexpected observed keyspace", observedKeyspace)
+	} else if observedStmt != `SELECT id FROM unknown_table WHERE id = ?` {
+		t.Fatal("select: unexpected observed stmt", observedStmt)
+	}
+}
+
+func TestObserve_Pagination(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, `CREATE TABLE gocql_test.observe2 (id int, PRIMARY KEY (id))`); err != nil {
+		t.Fatal("create:", err)
+	}
+
+	var observedRows int
+
+	resetObserved := func() {
+		observedRows = -1
+	}
+
+	observer := funcQueryObserver(func(ctx context.Context, o ObservedQuery) {
+		observedRows = o.Rows
+	})
+
+	// insert 50 entries, relevant for pagination
+	for i := 0; i < 50; i++ {
+		if err := session.Query(`INSERT INTO observe2 (id) VALUES (?)`, i).Exec(); err != nil {
+			t.Fatal("insert:", err)
+		}
+	}
+
+	resetObserved()
+
+	// read the 50 entries in pages of size 10 (the LIMIT of 100 is never reached). Expecting 5 observations, each with 10 rows
+	scanner := session.Query(`SELECT id FROM observe2 LIMIT 100`).
+		Observer(observer).
+		PageSize(10).
+		Iter().Scanner()
+	for i := 0; i < 50; i++ {
+		if !scanner.Next() {
+			t.Fatalf("next: should still be true: %d", i)
+		}
+		if i%10 == 0 {
+			if observedRows != 10 {
+				t.Fatalf("next: expecting a paginated query with 10 entries, got: %d (%d)", observedRows, i)
+			}
+		} else if observedRows != -1 {
+			t.Fatalf("next: not expecting paginated query (-1 entries), got: %d", observedRows)
+		}
+
+		resetObserved()
+	}
+
+	if scanner.Next() {
+		t.Fatal("next: no more entries where expected")
+	}
+}
+
 func TestPaging(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
@@ -1644,6 +1791,71 @@ func TestBatchStats(t *testing.T) {
 	}
 }
 
+type funcBatchObserver func(context.Context, ObservedBatch)
+
+func (f funcBatchObserver) ObserveBatch(ctx context.Context, o ObservedBatch) {
+	f(ctx, o)
+}
+
+func TestBatchObserve(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if session.cfg.ProtoVersion == 1 {
+		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
+	}
+
+	if err := createTable(session, `CREATE TABLE gocql_test.batch_observe_table (id int, other int, PRIMARY KEY (id))`); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	type observation struct {
+		observedErr      error
+		observedKeyspace string
+		observedStmts    []string
+	}
+
+	var observedBatch *observation
+
+	batch := NewBatch(LoggedBatch)
+	batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
+		if observedBatch != nil {
+			t.Fatal("batch observe called more than once")
+		}
+
+		observedBatch = &observation{
+			observedKeyspace: o.Keyspace,
+			observedStmts:    o.Statements,
+			observedErr:      o.Err,
+		}
+	}))
+	for i := 0; i < 100; i++ {
+		// hard coding 'i' into one of the values for better  testing of observation
+		batch.Query(fmt.Sprintf(`INSERT INTO batch_observe_table (id,other) VALUES (?,%d)`, i), i)
+	}
+
+	if err := session.ExecuteBatch(batch); err != nil {
+		t.Fatal("execute batch:", err)
+	}
+	if observedBatch == nil {
+		t.Fatal("batch observation has not been called")
+	}
+	if len(observedBatch.observedStmts) != 100 {
+		t.Fatal("expecting 100 observed statements, got", len(observedBatch.observedStmts))
+	}
+	if observedBatch.observedErr != nil {
+		t.Fatal("not expecting to observe an error", observedBatch.observedErr)
+	}
+	if observedBatch.observedKeyspace != "gocql_test" {
+		t.Fatalf("expecting keyspace 'gocql_test', got %q", observedBatch.observedKeyspace)
+	}
+	for i, stmt := range observedBatch.observedStmts {
+		if stmt != fmt.Sprintf(`INSERT INTO batch_observe_table (id,other) VALUES (?,%d)`, i) {
+			t.Fatal("unexpected query", stmt)
+		}
+	}
+}
+
 //TestNilInQuery tests to see that a nil value passed to a query is handled by Cassandra
 //TODO validate the nil value by reading back the nil. Need to fix Unmarshalling.
 func TestNilInQuery(t *testing.T) {

+ 8 - 0
cluster.go

@@ -115,6 +115,14 @@ type ClusterConfig struct {
 	// See https://issues.apache.org/jira/browse/CASSANDRA-10786
 	DisableSkipMetadata bool
 
+	// QueryObserver will set the provided query observer on all queries created from this session.
+	// Use it to collect metrics / stats from queries by providing an implementation of QueryObserver.
+	QueryObserver QueryObserver
+
+	// BatchObserver will set the provided batch observer on all batches created from this session.
+	// Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver.
+	BatchObserver BatchObserver
+
 	// internal config for testing
 	disableControlConn bool
 }

+ 1 - 1
cluster_test.go

@@ -1,9 +1,9 @@
 package gocql
 
 import (
+	"net"
 	"testing"
 	"time"
-	"net"
 )
 
 func TestNewCluster_Defaults(t *testing.T) {

+ 8 - 3
common_test.go

@@ -70,7 +70,7 @@ func createTable(s *Session, table string) error {
 	return nil
 }
 
-func createCluster() *ClusterConfig {
+func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig {
 	cluster := NewCluster(clusterHosts...)
 	cluster.ProtoVersion = *flagProto
 	cluster.CQLVersion = *flagCQL
@@ -90,6 +90,11 @@ func createCluster() *ClusterConfig {
 	}
 
 	cluster = addSslOptions(cluster)
+
+	for _, opt := range opts {
+		opt(cluster)
+	}
+
 	return cluster
 }
 
@@ -140,8 +145,8 @@ func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session {
 	return session
 }
 
-func createSession(tb testing.TB) *Session {
-	cluster := createCluster()
+func createSession(tb testing.TB, opts ...func(config *ClusterConfig)) *Session {
+	cluster := createCluster(opts...)
 	return createSessionFromCluster(cluster, tb)
 }
 

+ 3 - 2
query_executor.go

@@ -6,7 +6,7 @@ import (
 
 type ExecutableQuery interface {
 	execute(conn *Conn) *Iter
-	attempt(time.Duration)
+	attempt(keyspace string, end, start time.Time, iter *Iter)
 	retryPolicy() RetryPolicy
 	GetRoutingKey() ([]byte, error)
 	Keyspace() string
@@ -21,8 +21,9 @@ type queryExecutor struct {
 func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
 	start := time.Now()
 	iter := qry.execute(conn)
+	end := time.Now()
 
-	qry.attempt(time.Since(start))
+	qry.attempt(q.pool.keyspace, end, start, iter)
 
 	return iter
 }

+ 105 - 6
session.go

@@ -37,6 +37,8 @@ type Session struct {
 	routingKeyInfoCache routingKeyInfoLRU
 	schemaDescriber     *schemaDescriber
 	trace               Tracer
+	queryObserver       QueryObserver
+	batchObserver       BatchObserver
 	hostSource          *ringDescriber
 	stmtsLRU            *preparedLRU
 
@@ -134,6 +136,9 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		policy: cfg.PoolConfig.HostSelectionPolicy,
 	}
 
+	s.queryObserver = cfg.QueryObserver
+	s.batchObserver = cfg.BatchObserver
+
 	//Check the TLS Config before trying to connect to anything external
 	connCfg, err := connConfig(&s.cfg)
 	if err != nil {
@@ -314,6 +319,7 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query {
 	qry.session = s
 	qry.pageSize = s.pageSize
 	qry.trace = s.trace
+	qry.observer = s.queryObserver
 	qry.prefetch = s.prefetch
 	qry.rt = s.cfg.RetryPolicy
 	qry.serialCons = s.cfg.SerialConsistency
@@ -338,7 +344,7 @@ type QueryInfo struct {
 func (s *Session) Bind(stmt string, b func(q *QueryInfo) ([]interface{}, error)) *Query {
 	s.mu.RLock()
 	qry := &Query{stmt: stmt, binding: b, cons: s.cons,
-		session: s, pageSize: s.pageSize, trace: s.trace,
+		session: s, pageSize: s.pageSize, trace: s.trace, observer: s.queryObserver,
 		prefetch: s.prefetch, rt: s.cfg.RetryPolicy}
 	s.mu.RUnlock()
 	return qry
@@ -383,7 +389,7 @@ func (s *Session) Closed() bool {
 	return closed
 }
 
-func (s *Session) executeQuery(qry *Query) *Iter {
+func (s *Session) executeQuery(qry *Query) (it *Iter) {
 	// fail fast
 	if s.Closed() {
 		return &Iter{err: ErrSessionClosed}
@@ -656,6 +662,7 @@ type Query struct {
 	pageState             []byte
 	prefetch              float64
 	trace                 Tracer
+	observer              QueryObserver
 	session               *Session
 	rt                    RetryPolicy
 	binding               func(q *QueryInfo) ([]interface{}, error)
@@ -709,6 +716,13 @@ func (q *Query) Trace(trace Tracer) *Query {
 	return q
 }
 
+// Observer enables query-level observer on this query.
+// The provided observer will be called every time this query is executed.
+func (q *Query) Observer(observer QueryObserver) *Query {
+	q.observer = observer
+	return q
+}
+
 // PageSize will tell the iterator to fetch the result in pages of size n.
 // This is useful for iterating over large result sets, but setting the
 // page size too low might decrease the performance. This feature is only
@@ -759,10 +773,21 @@ func (q *Query) execute(conn *Conn) *Iter {
 	return conn.executeQuery(q)
 }
 
-func (q *Query) attempt(d time.Duration) {
+func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter) {
 	q.attempts++
-	q.totalLatency += d.Nanoseconds()
+	q.totalLatency += end.Sub(start).Nanoseconds()
 	// TODO: track latencies per host and things as well instead of just total
+
+	if q.observer != nil {
+		q.observer.ObserveQuery(q.context, ObservedQuery{
+			Keyspace:  keyspace,
+			Statement: q.stmt,
+			Start:     start,
+			End:       end,
+			Rows:      iter.numRows,
+			Err:       iter.err,
+		})
+	}
 }
 
 func (q *Query) retryPolicy() RetryPolicy {
@@ -1334,6 +1359,7 @@ type Batch struct {
 	Entries               []BatchEntry
 	Cons                  Consistency
 	rt                    RetryPolicy
+	observer              BatchObserver
 	attempts              int
 	totalLatency          int64
 	serialCons            SerialConsistency
@@ -1357,6 +1383,7 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 		Type:             typ,
 		rt:               s.cfg.RetryPolicy,
 		serialCons:       s.cfg.SerialConsistency,
+		observer: s.batchObserver,
 		Cons:             s.cons,
 		defaultTimestamp: s.cfg.DefaultTimestamp,
 		keyspace:         s.cfg.Keyspace,
@@ -1365,6 +1392,13 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 	return batch
 }
 
+// Observer enables batch-level observer on this batch.
+// The provided observer will be called every time this batched query is executed.
+func (b *Batch) Observer(observer BatchObserver) *Batch {
+	b.observer = observer
+	return b
+}
+
 func (b *Batch) Keyspace() string {
 	return b.keyspace
 }
@@ -1457,10 +1491,28 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch {
 	return b
 }
 
-func (b *Batch) attempt(d time.Duration) {
+func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter) {
 	b.attempts++
-	b.totalLatency += d.Nanoseconds()
+	b.totalLatency += end.Sub(start).Nanoseconds()
 	// TODO: track latencies per host and things as well instead of just total
+
+	if b.observer == nil {
+		return
+	}
+
+	statements := make([]string, len(b.Entries))
+	for i, entry := range b.Entries {
+		statements[i] = entry.Stmt
+	}
+
+	b.observer.ObserveBatch(b.context, ObservedBatch{
+		Keyspace:   keyspace,
+		Statements: statements,
+		Start:      start,
+		End:        end,
+		// Rows not used in batch observations // TODO - might be able to support it when using BatchCAS
+		Err: iter.err,
+	})
 }
 
 func (b *Batch) GetRoutingKey() ([]byte, error) {
@@ -1596,6 +1648,53 @@ func (t *traceWriter) Trace(traceId []byte) {
 	}
 }
 
+type ObservedQuery struct {
+	Keyspace  string
+	Statement string
+
+	Start time.Time // time immediately before the query was called
+	End   time.Time // time immediately after the query returned
+
+	// Rows is the number of rows in the current iter.
+	// In paginated queries, rows from previous scans are not counted.
+	// Rows is not used in batch queries and remains at the default value
+	Rows int
+
+	// Err is the error in the query.
+	// It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error
+	Err error
+}
+
+// QueryObserver is the interface implemented by query observers / stat collectors.
+type QueryObserver interface {
+	// ObserveQuery gets called on every query to cassandra, including all queries in an iterator when paging is enabled.
+	// It doesn't get called if there is no query because the session is closed or there are no connections available.
+	// The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil.
+	ObserveQuery(context.Context, ObservedQuery)
+}
+
+type ObservedBatch struct {
+	Keyspace   string
+	Statements []string
+
+	Start time.Time // time immediately before the batch query was called
+	End   time.Time // time immediately after the batch query returned
+
+	// Err is the error in the batch query.
+	// It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error
+	Err error
+}
+
+// BatchObserver is the interface implemented by batch observers / stat collectors.
+type BatchObserver interface {
+	// ObserveBatch gets called on every batch query to cassandra.
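+	// It gets called once per batch execution attempt, not once per statement; the statements of all queries in the batch are reported together in ObservedBatch.Statements.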
+	// It doesn't get called if there is no query because the session is closed or there are no connections available.
+	// The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil.
+	// Unlike QueryObserver.ObserveQuery it does no reporting on rows read.
+	ObserveBatch(context.Context, ObservedBatch)
+}
+
 type Error struct {
 	Code    int
 	Message string

+ 15 - 2
session_test.go

@@ -3,6 +3,7 @@
 package gocql
 
 import (
+	"context"
 	"fmt"
 	"testing"
 )
@@ -89,6 +90,12 @@ func TestSessionAPI(t *testing.T) {
 	}
 }
 
+type funcQueryObserver func(context.Context, ObservedQuery)
+
+func (f funcQueryObserver) ObserveQuery(ctx context.Context, o ObservedQuery) {
+	f(ctx, o)
+}
+
 func TestQueryBasicAPI(t *testing.T) {
 	qry := &Query{}
 
@@ -116,6 +123,12 @@ func TestQueryBasicAPI(t *testing.T) {
 		t.Fatalf("expected Query.Trace to be '%v', got '%v'", trace, qry.trace)
 	}
 
+	observer := funcQueryObserver(func(context.Context, ObservedQuery) {})
+	qry.Observer(observer)
+	if qry.observer == nil { // can't compare func to func, checking not nil instead
+		t.Fatal("expected Query.QueryObserver to be set, got nil")
+	}
+
 	qry.PageSize(10)
 	if qry.pageSize != 10 {
 		t.Fatalf("expected Query.PageSize to be 10, got %v", qry.pageSize)
@@ -202,7 +215,7 @@ func TestBatchBasicAPI(t *testing.T) {
 
 	b.Query("test", 1)
 	if b.Entries[0].Stmt != "test" {
-		t.Fatalf("expected batch.Entries[0].Stmt to be 'test', got '%v'", b.Entries[0].Stmt)
+		t.Fatalf("expected batch.Entries[0].Statement to be 'test', got '%v'", b.Entries[0].Stmt)
 	} else if b.Entries[0].Args[0].(int) != 1 {
 		t.Fatalf("expected batch.Entries[0].Args[0] to be 1, got %v", b.Entries[0].Args[0])
 	}
@@ -212,7 +225,7 @@ func TestBatchBasicAPI(t *testing.T) {
 	})
 
 	if b.Entries[1].Stmt != "test2" {
-		t.Fatalf("expected batch.Entries[1].Stmt to be 'test2', got '%v'", b.Entries[1].Stmt)
+		t.Fatalf("expected batch.Entries[1].Statement to be 'test2', got '%v'", b.Entries[1].Stmt)
 	} else if b.Entries[1].binding == nil {
 		t.Fatal("expected batch.Entries[1].binding to be defined, got nil")
 	}