Browse Source

Alternative attempt to expose query binding meta data to the application

Ben Hood 11 years ago
parent
commit
5be3fcc634
3 changed files with 70 additions and 10 deletions
  1. 42 1
      cassandra_test.go
  2. 18 9
      conn.go
  3. 10 0
      session.go

+ 42 - 1
cassandra_test.go

@@ -551,6 +551,47 @@ func TestScanCASWithNilArguments(t *testing.T) {
 	}
 }
 
+func TestQueryInfo(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("CREATE TABLE expose_query_info (id int, value text, PRIMARY KEY (id))").Exec(); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	if err := session.Query("INSERT INTO expose_query_info (id, value) VALUES (?, ?)", 113, "foo").Exec(); err != nil {
+		t.Fatalf("insert into expose_query_info failed, err '%v'", err)
+	}
+
+	autobinder := func(q *QueryInfo) []interface{} {
+		values := make([]interface{}, 1)
+		values[0] = 113
+		return values
+	}
+
+	qry := session.Bind("SELECT id, value FROM expose_query_info WHERE id = ?", autobinder)
+
+	if err := qry.Exec(); err != nil {
+		t.Fatalf("expose query info failed, error '%v'", err)
+	}
+
+	iter := qry.Iter()
+
+	var id int
+	var value string
+
+	qry.Iter().Scan(&id, &value)
+
+	if err := iter.Close(); err != nil {
+		t.Fatalf("query with exposed info failed, err '%v'", err)
+	}
+
+	if value != "foo" {
+		t.Fatalf("Expected value %s, but got %s", "foo", value)
+	}
+
+}
+
 func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
 	if err := session.Query(`CREATE TABLE ` + table + ` (
 			foo   varchar,
@@ -565,7 +606,7 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 	stmtsLRU.mu.Lock()
 	stmtsLRU.lru.Add(conn.addr+stmt, flight)
 	stmtsLRU.mu.Unlock()
-	flight.info = &queryInfo{
+	flight.info = &QueryInfo{
 		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
 		args: []ColumnInfo{ColumnInfo{
 			Keyspace: "gocql_test",

+ 18 - 9
conn.go

@@ -308,7 +308,7 @@ func (c *Conn) ping() error {
 	return err
 }
 
-func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
+func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
 	stmtsLRU.mu.Lock()
 	if val, ok := stmtsLRU.lru.Get(c.addr + stmt); ok {
 		flight := val.(*inflightPrepare)
@@ -328,7 +328,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 	} else {
 		switch x := resp.(type) {
 		case resultPreparedFrame:
-			flight.info = &queryInfo{
+			flight.info = &QueryInfo{
 				id:   x.PreparedId,
 				args: x.Values,
 			}
@@ -363,13 +363,22 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		if err != nil {
 			return &Iter{err: err}
 		}
-		if len(qry.values) != len(info.args) {
+
+		var values []interface{}
+
+		if qry.binding == nil {
+			values = qry.values
+		} else {
+			values = qry.binding(info)
+		}
+
+		if len(values) != len(info.args) {
 			return &Iter{err: ErrQueryArgLength}
 		}
 		op.Prepared = info.id
-		op.Values = make([][]byte, len(qry.values))
-		for i := 0; i < len(qry.values); i++ {
-			val, err := Marshal(info.args[i].TypeInfo, qry.values[i])
+		op.Values = make([][]byte, len(values))
+		for i := 0; i < len(values); i++ {
+			val, err := Marshal(info.args[i].TypeInfo, values[i])
 			if err != nil {
 				return &Iter{err: err}
 			}
@@ -472,7 +481,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
 
 	for i := 0; i < len(batch.Entries); i++ {
 		entry := &batch.Entries[i]
-		var info *queryInfo
+		var info *QueryInfo
 		if len(entry.Args) > 0 {
 			var err error
 			info, err = c.prepareStatement(entry.Stmt, nil)
@@ -619,7 +628,7 @@ func (c *Conn) setKeepalive(d time.Duration) error {
 	return nil
 }
 
-type queryInfo struct {
+type QueryInfo struct {
 	id   []byte
 	args []ColumnInfo
 	rval []ColumnInfo
@@ -636,7 +645,7 @@ type callResp struct {
 }
 
 type inflightPrepare struct {
-	info *queryInfo
+	info *QueryInfo
 	err  error
 	wg   sync.WaitGroup
 }

+ 10 - 0
session.go

@@ -90,6 +90,15 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query {
 	return qry
 }
 
+func (s *Session) Bind(stmt string, b func(q *QueryInfo) []interface{}) *Query {
+	s.mu.RLock()
+	qry := &Query{stmt: stmt, binding: b, cons: s.cons,
+		session: s, pageSize: s.pageSize, trace: s.trace,
+		prefetch: s.prefetch, rt: s.cfg.RetryPolicy}
+	s.mu.RUnlock()
+	return qry
+}
+
 // Close closes all connections. The session is unusable after this
 // operation.
 func (s *Session) Close() {
@@ -184,6 +193,7 @@ type Query struct {
 	trace     Tracer
 	session   *Session
 	rt        RetryPolicy
+	binding   func(q *QueryInfo) []interface{}
 }
 
 // Consistency sets the consistency level for this query. If no consistency