Browse Source

Merge pull request #124 from nemothekid/cass1batchbug

Fix #117; Make sure to prepare begin batch statements.
Christoph Hack 11 years ago
parent
commit
8c74b3cf90
2 changed files with 31 additions and 1 deletions
  1. 23 0
      cass1batch_test.go
  2. 8 1
      conn.go

+ 23 - 0
cass1batch_test.go

@@ -0,0 +1,23 @@
+package gocql
+
+import (
+	"strings"
+	"testing"
+)
+
+func TestProto1BatchInsert(t *testing.T) {
+	session := createSession(t)
+	if err := session.Query("CREATE TABLE large (id int primary key)").Exec(); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	begin := "BEGIN BATCH"
+	end := "APPLY BATCH"
+	query := "INSERT INTO large (id) VALUES (?)"
+	fullQuery := strings.Join([]string{begin, query, end}, "\n")
+	args := []interface{}{5}
+	if err := session.Query(fullQuery, args...).Consistency(Quorum).Exec(); err != nil {
+		t.Fatal(err)
+	}
+
+}

+ 8 - 1
conn.go

@@ -371,9 +371,16 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	stmtType := op.Stmt
 	stmtType := op.Stmt
 	if n := strings.IndexFunc(stmtType, unicode.IsSpace); n >= 0 {
 	if n := strings.IndexFunc(stmtType, unicode.IsSpace); n >= 0 {
 		stmtType = strings.ToLower(stmtType[:n])
 		stmtType = strings.ToLower(stmtType[:n])
+		switch stmtType {
+		case "begin":
+			stmtTail := strings.TrimSpace(op.Stmt[n:])
+			if n := strings.IndexFunc(stmtTail, unicode.IsSpace); n >= 0 {
+				stmtType = stmtType + " " + strings.ToLower(stmtTail[:n])
+			}
+		}
 	}
 	}
 	switch stmtType {
 	switch stmtType {
-	case "select", "insert", "update", "delete":
+	case "select", "insert", "update", "delete", "begin batch":
 		// Prepare all DML queries. Other queries can not be prepared.
 		// Prepare all DML queries. Other queries can not be prepared.
 		info, err := c.prepareStatement(qry.stmt, qry.trace)
 		info, err := c.prepareStatement(qry.stmt, qry.trace)
 		if err != nil {
 		if err != nil {