浏览代码

Merge pull request #202 from relops/expose_query_info

Expose query info declaratively
Ben Hood 11 年之前
父节点
当前提交
29c14280b6
共有 3 个文件被更改,包括 241 次插入32 次删除
  1. 169 7
      cassandra_test.go
  2. 46 23
      conn.go
  3. 26 2
      session.go

+ 169 - 7
cassandra_test.go

@@ -14,6 +14,7 @@ import (
 	"sync"
 	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
+	"unicode"
 )
 )
 
 
 var (
 var (
@@ -551,6 +552,167 @@ func TestScanCASWithNilArguments(t *testing.T) {
 	}
 	}
 }
 }
 
 
+//TestStaticQueryInfo makes sure that the application can manually bind query parameters using the simplest possible static binding strategy
+func TestStaticQueryInfo(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("CREATE TABLE static_query_info (id int, value text, PRIMARY KEY (id))").Exec(); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	if err := session.Query("INSERT INTO static_query_info (id, value) VALUES (?, ?)", 113, "foo").Exec(); err != nil {
+		t.Fatalf("insert into static_query_info failed, err '%v'", err)
+	}
+
+	autobinder := func(q *QueryInfo) ([]interface{}, error) {
+		values := make([]interface{}, 1)
+		values[0] = 113
+		return values, nil
+	}
+
+	qry := session.Bind("SELECT id, value FROM static_query_info WHERE id = ?", autobinder)
+
+	if err := qry.Exec(); err != nil {
+		t.Fatalf("expose query info failed, error '%v'", err)
+	}
+
+	iter := qry.Iter()
+
+	var id int
+	var value string
+
+	iter.Scan(&id, &value)
+
+	if err := iter.Close(); err != nil {
+		t.Fatalf("query with exposed info failed, err '%v'", err)
+	}
+
+	if value != "foo" {
+		t.Fatalf("Expected value %s, but got %s", "foo", value)
+	}
+
+}
+
+type ClusteredKeyValue struct {
+	Id      int
+	Cluster int
+	Value   string
+}
+
+func (kv *ClusteredKeyValue) Bind(q *QueryInfo) ([]interface{}, error) {
+	values := make([]interface{}, len(q.Args))
+
+	for i, info := range q.Args {
+		fieldName := upcaseInitial(info.Name)
+		value := reflect.ValueOf(kv)
+		field := reflect.Indirect(value).FieldByName(fieldName)
+		values[i] = field.Addr().Interface()
+	}
+
+	return values, nil
+}
+
+func upcaseInitial(str string) string {
+	for i, v := range str {
+		return string(unicode.ToUpper(v)) + str[i+1:]
+	}
+	return ""
+}
+
+//TestBoundQueryInfo makes sure that the application can manually bind query parameters using the query meta data supplied at runtime
+func TestBoundQueryInfo(t *testing.T) {
+
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("CREATE TABLE clustered_query_info (id int, cluster int, value text, PRIMARY KEY (id, cluster))").Exec(); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	write := &ClusteredKeyValue{Id: 200, Cluster: 300, Value: "baz"}
+
+	insert := session.Bind("INSERT INTO clustered_query_info (id, cluster, value) VALUES (?, ?,?)", write.Bind)
+
+	if err := insert.Exec(); err != nil {
+		t.Fatalf("insert into clustered_query_info failed, err '%v'", err)
+	}
+
+	read := &ClusteredKeyValue{Id: 200, Cluster: 300}
+
+	qry := session.Bind("SELECT id, cluster, value FROM clustered_query_info WHERE id = ? and cluster = ?", read.Bind)
+
+	iter := qry.Iter()
+
+	var id, cluster int
+	var value string
+
+	iter.Scan(&id, &cluster, &value)
+
+	if err := iter.Close(); err != nil {
+		t.Fatalf("query with clustered_query_info info failed, err '%v'", err)
+	}
+
+	if value != "baz" {
+		t.Fatalf("Expected value %s, but got %s", "baz", value)
+	}
+
+}
+
+//TestBatchQueryInfo makes sure that the application can manually bind query parameters when executing in a batch
+func TestBatchQueryInfo(t *testing.T) {
+
+	if *flagProto == 1 {
+		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("CREATE TABLE batch_query_info (id int, cluster int, value text, PRIMARY KEY (id, cluster))").Exec(); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	write := func(q *QueryInfo) ([]interface{}, error) {
+		values := make([]interface{}, 3)
+		values[0] = 4000
+		values[1] = 5000
+		values[2] = "bar"
+		return values, nil
+	}
+
+	batch := session.NewBatch(LoggedBatch)
+	batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)
+
+	if err := session.ExecuteBatch(batch); err != nil {
+		t.Fatalf("batch insert into batch_query_info failed, err '%v'", err)
+	}
+
+	read := func(q *QueryInfo) ([]interface{}, error) {
+		values := make([]interface{}, 2)
+		values[0] = 4000
+		values[1] = 5000
+		return values, nil
+	}
+
+	qry := session.Bind("SELECT id, cluster, value FROM batch_query_info WHERE id = ? and cluster = ?", read)
+
+	iter := qry.Iter()
+
+	var id, cluster int
+	var value string
+
+	iter.Scan(&id, &cluster, &value)
+
+	if err := iter.Close(); err != nil {
+		t.Fatalf("query with batch_query_info info failed, err '%v'", err)
+	}
+
+	if value != "bar" {
+		t.Fatalf("Expected value %s, but got %s", "bar", value)
+	}
+}
+
 func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
 func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
 	if err := session.Query(`CREATE TABLE ` + table + ` (
 	if err := session.Query(`CREATE TABLE ` + table + ` (
 			foo   varchar,
 			foo   varchar,
@@ -565,9 +727,9 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	stmtsLRU.mu.Lock()
 	stmtsLRU.mu.Lock()
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
 	stmtsLRU.mu.Unlock()
 	stmtsLRU.mu.Unlock()
-	flight.info = &queryInfo{
-		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
-		args: []ColumnInfo{ColumnInfo{
+	flight.info = &QueryInfo{
+		Id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
+		Args: []ColumnInfo{ColumnInfo{
 			Keyspace: "gocql_test",
 			Keyspace: "gocql_test",
 			Table:    table,
 			Table:    table,
 			Name:     "foo",
 			Name:     "foo",
@@ -615,13 +777,13 @@ func TestQueryInfo(t *testing.T) {
 		t.Fatalf("Failed to execute query for preparing statement: %v", err)
 		t.Fatalf("Failed to execute query for preparing statement: %v", err)
 	}
 	}
 
 
-	if len(info.args) != 1 {
-		t.Fatalf("Was not expecting meta data for %d query arguments, but got %d\n", 1, len(info.args))
+	if len(info.Args) != 1 {
+		t.Fatalf("Was not expecting meta data for %d query arguments, but got %d\n", 1, len(info.Args))
 	}
 	}
 
 
 	if *flagProto > 1 {
 	if *flagProto > 1 {
-		if len(info.rval) != 2 {
-			t.Fatalf("Was not expecting meta data for %d result columns, but got %d\n", 2, len(info.rval))
+		if len(info.Rval) != 2 {
+			t.Fatalf("Was not expecting meta data for %d result columns, but got %d\n", 2, len(info.Rval))
 		}
 		}
 	}
 	}
 }
 }

+ 46 - 23
conn.go

@@ -308,7 +308,7 @@ func (c *Conn) ping() error {
 	return err
 	return err
 }
 }
 
 
-func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
+func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	stmtsLRU.mu.Lock()
 	stmtsLRU.mu.Lock()
 	if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok {
 	if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok {
 		flight := val.(*inflightPrepare)
 		flight := val.(*inflightPrepare)
@@ -328,10 +328,10 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 	} else {
 	} else {
 		switch x := resp.(type) {
 		switch x := resp.(type) {
 		case resultPreparedFrame:
 		case resultPreparedFrame:
-			flight.info = &queryInfo{
-				id:   x.PreparedId,
-				args: x.Arguments,
-				rval: x.ReturnValues,
+			flight.info = &QueryInfo{
+				Id:   x.PreparedId,
+				Args: x.Arguments,
+				Rval: x.ReturnValues,
 			}
 			}
 		case error:
 		case error:
 			flight.err = x
 			flight.err = x
@@ -364,13 +364,25 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		if err != nil {
 		if err != nil {
 			return &Iter{err: err}
 			return &Iter{err: err}
 		}
 		}
-		if len(qry.values) != len(info.args) {
+
+		var values []interface{}
+
+		if qry.binding == nil {
+			values = qry.values
+		} else {
+			values, err = qry.binding(info)
+			if err != nil {
+				return &Iter{err: err}
+			}
+		}
+
+		if len(values) != len(info.Args) {
 			return &Iter{err: ErrQueryArgLength}
 			return &Iter{err: ErrQueryArgLength}
 		}
 		}
-		op.Prepared = info.id
-		op.Values = make([][]byte, len(qry.values))
-		for i := 0; i < len(qry.values); i++ {
-			val, err := Marshal(info.args[i].TypeInfo, qry.values[i])
+		op.Prepared = info.Id
+		op.Values = make([][]byte, len(values))
+		for i := 0; i < len(values); i++ {
+			val, err := Marshal(info.Args[i].TypeInfo, values[i])
 			if err != nil {
 			if err != nil {
 				return &Iter{err: err}
 				return &Iter{err: err}
 			}
 			}
@@ -473,28 +485,38 @@ func (c *Conn) executeBatch(batch *Batch) error {
 
 
 	for i := 0; i < len(batch.Entries); i++ {
 	for i := 0; i < len(batch.Entries); i++ {
 		entry := &batch.Entries[i]
 		entry := &batch.Entries[i]
-		var info *queryInfo
-		if len(entry.Args) > 0 {
+		var info *QueryInfo
+		var args []interface{}
+		if len(entry.Args) > 0 || entry.binding != nil {
 			var err error
 			var err error
 			info, err = c.prepareStatement(entry.Stmt, nil)
 			info, err = c.prepareStatement(entry.Stmt, nil)
 
 
-			if len(entry.Args) != len(info.args) {
+			if entry.binding == nil {
+				args = entry.Args
+			} else {
+				args, err = entry.binding(info)
+				if err != nil {
+					return err
+				}
+			}
+
+			if len(args) != len(info.Args) {
 				return ErrQueryArgLength
 				return ErrQueryArgLength
 			}
 			}
 
 
-			stmts[string(info.id)] = entry.Stmt
+			stmts[string(info.Id)] = entry.Stmt
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
 			f.writeByte(1)
 			f.writeByte(1)
-			f.writeShortBytes(info.id)
+			f.writeShortBytes(info.Id)
 		} else {
 		} else {
 			f.writeByte(0)
 			f.writeByte(0)
 			f.writeLongString(entry.Stmt)
 			f.writeLongString(entry.Stmt)
 		}
 		}
-		f.writeShort(uint16(len(entry.Args)))
-		for j := 0; j < len(entry.Args); j++ {
-			val, err := Marshal(info.args[j].TypeInfo, entry.Args[j])
+		f.writeShort(uint16(len(args)))
+		for j := 0; j < len(args); j++ {
+			val, err := Marshal(info.Args[j].TypeInfo, args[j])
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -624,10 +646,11 @@ func (c *Conn) setKeepalive(d time.Duration) error {
 	return nil
 	return nil
 }
 }
 
 
-type queryInfo struct {
-	id   []byte
-	args []ColumnInfo
-	rval []ColumnInfo
+// QueryInfo represents the meta data associated with a prepared CQL statement.
+type QueryInfo struct {
+	Id   []byte
+	Args []ColumnInfo
+	Rval []ColumnInfo
 }
 }
 
 
 type callReq struct {
 type callReq struct {
@@ -641,7 +664,7 @@ type callResp struct {
 }
 }
 
 
 type inflightPrepare struct {
 type inflightPrepare struct {
-	info *queryInfo
+	info *QueryInfo
 	err  error
 	err  error
 	wg   sync.WaitGroup
 	wg   sync.WaitGroup
 }
 }

+ 26 - 2
session.go

@@ -90,6 +90,21 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query {
 	return qry
 	return qry
 }
 }
 
 
+// Bind generates a new query object based on the query statement passed in.
+// The query is automatically prepared if it has not previously been executed.
+// The binding callback allows the application to define which query argument
+// values will be marshalled as part of the query execution.
+// During execution, the meta data of the prepared query will be routed to the
+// binding callback, which is responsible for producing the query argument values.
+func (s *Session) Bind(stmt string, b func(q *QueryInfo) ([]interface{}, error)) *Query {
+	s.mu.RLock()
+	qry := &Query{stmt: stmt, binding: b, cons: s.cons,
+		session: s, pageSize: s.pageSize, trace: s.trace,
+		prefetch: s.prefetch, rt: s.cfg.RetryPolicy}
+	s.mu.RUnlock()
+	return qry
+}
+
 // Close closes all connections. The session is unusable after this
 // Close closes all connections. The session is unusable after this
 // operation.
 // operation.
 func (s *Session) Close() {
 func (s *Session) Close() {
@@ -184,6 +199,7 @@ type Query struct {
 	trace     Tracer
 	trace     Tracer
 	session   *Session
 	session   *Session
 	rt        RetryPolicy
 	rt        RetryPolicy
+	binding   func(q *QueryInfo) ([]interface{}, error)
 }
 }
 
 
 // Consistency sets the consistency level for this query. If no consistency
 // Consistency sets the consistency level for this query. If no consistency
@@ -395,6 +411,13 @@ func (b *Batch) Query(stmt string, args ...interface{}) {
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
 }
 }
 
 
+// Bind adds the query to the batch operation and correlates it with a binding callback
+// that will be invoked when the batch is executed. The binding callback allows the application
+// to define which query argument values will be marshalled as part of the batch execution.
+func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) ([]interface{}, error)) {
+	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind})
+}
+
 // RetryPolicy sets the retry policy to use when executing the batch operation
 // RetryPolicy sets the retry policy to use when executing the batch operation
 func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	b.rt = r
 	b.rt = r
@@ -415,8 +438,9 @@ const (
 )
 )
 
 
 type BatchEntry struct {
 type BatchEntry struct {
-	Stmt string
-	Args []interface{}
+	Stmt    string
+	Args    []interface{}
+	binding func(q *QueryInfo) ([]interface{}, error)
 }
 }
 
 
 type Consistency int
 type Consistency int