Browse Source

Added Bind() to Batch interface

Ben Hood 11 years ago
parent
commit
85fd5630d7
3 changed files with 67 additions and 8 deletions
  1. 48 1
      cassandra_test.go
  2. 12 5
      conn.go
  3. 7 2
      session.go

+ 48 - 1
cassandra_test.go

@@ -644,7 +644,6 @@ func TestBoundQueryInfo(t *testing.T) {
 	var value string
 
 	iter.Scan(&id, &cluster, &value)
-	qry.Iter().Scan(&id, &cluster, &value)
 
 	if err := iter.Close(); err != nil {
 		t.Fatalf("query with clustered_query_info info failed, err '%v'", err)
@@ -656,6 +655,54 @@ func TestBoundQueryInfo(t *testing.T) {
 
 }
 
+func TestBatchQueryInfo(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("CREATE TABLE batch_query_info (id int, cluster int, value text, PRIMARY KEY (id, cluster))").Exec(); err != nil {
+		t.Fatalf("failed to create table with error '%v'", err)
+	}
+
+	write := func(q *QueryInfo) []interface{} {
+		values := make([]interface{}, 3)
+		values[0] = 4000
+		values[1] = 5000
+		values[2] = "bar"
+		return values
+	}
+
+	batch := session.NewBatch(LoggedBatch)
+	batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)
+
+	if err := session.ExecuteBatch(batch); err != nil {
+		t.Fatalf("batch insert into batch_query_info failed, err '%v'", err)
+	}
+
+	read := func(q *QueryInfo) []interface{} {
+		values := make([]interface{}, 2)
+		values[0] = 4000
+		values[1] = 5000
+		return values
+	}
+
+	qry := session.Bind("SELECT id, cluster, value FROM batch_query_info WHERE id = ? and cluster = ?", read)
+
+	iter := qry.Iter()
+
+	var id, cluster int
+	var value string
+
+	iter.Scan(&id, &cluster, &value)
+
+	if err := iter.Close(); err != nil {
+		t.Fatalf("query with batch_query_info info failed, err '%v'", err)
+	}
+
+	if value != "bar" {
+		t.Fatalf("Expected value %s, but got %s", "bar", value)
+	}
+}
+
 func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
 	if err := session.Query(`CREATE TABLE ` + table + ` (
 			foo   varchar,

+ 12 - 5
conn.go

@@ -482,11 +482,18 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	for i := 0; i < len(batch.Entries); i++ {
 		entry := &batch.Entries[i]
 		var info *QueryInfo
-		if len(entry.Args) > 0 {
+		var args []interface{}
+		if len(entry.Args) > 0 || entry.binding != nil {
 			var err error
 			info, err = c.prepareStatement(entry.Stmt, nil)
 
-			if len(entry.Args) != len(info.args) {
+			if entry.binding == nil {
+				args = entry.Args
+			} else {
+				args = entry.binding(info)
+			}
+
+			if len(args) != len(info.args) {
 				return ErrQueryArgLength
 			}
 
@@ -500,9 +507,9 @@ func (c *Conn) executeBatch(batch *Batch) error {
 			f.writeByte(0)
 			f.writeLongString(entry.Stmt)
 		}
-		f.writeShort(uint16(len(entry.Args)))
-		for j := 0; j < len(entry.Args); j++ {
-			val, err := Marshal(info.args[j].TypeInfo, entry.Args[j])
+		f.writeShort(uint16(len(args)))
+		for j := 0; j < len(args); j++ {
+			val, err := Marshal(info.args[j].TypeInfo, args[j])
 			if err != nil {
 				return err
 			}

+ 7 - 2
session.go

@@ -405,6 +405,10 @@ func (b *Batch) Query(stmt string, args ...interface{}) {
 	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
 }
 
+func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) []interface{}) {
+	b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind})
+}
+
 // RetryPolicy sets the retry policy to use when executing the batch operation
 func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	b.rt = r
@@ -425,8 +429,9 @@ const (
 )
 
 type BatchEntry struct {
-	Stmt string
-	Args []interface{}
+	Stmt    string
+	Args    []interface{}
+	binding func(q *QueryInfo) []interface{}
 }
 
 type Consistency int