Jelajahi Sumber

Merge pull request #126 from skoikovs/prepare_all_batch_types

Prepare all batch types
Chris Bannister 11 tahun lalu
induk
melakukan
e618d79fc0
2 mengubah file dengan 56 tambahan dan 14 penghapusan
  1. 33 0
      cass1batch_test.go
  2. 23 14
      conn.go

+ 33 - 0
cass1batch_test.go

@@ -21,3 +21,36 @@ func TestProto1BatchInsert(t *testing.T) {
 	}
 
 }
+
+func TestShouldPrepareFunction(t *testing.T) {
+	var shouldPrepareTests = []struct {
+		Stmt   string
+		Result bool
+	}{
+		{`
+      BEGIN BATCH
+        INSERT INTO users (userID, password)
+        VALUES ('smith', 'secret')
+      APPLY BATCH
+    ;
+      `, true},
+		{`INSERT INTO users (userID, password, name) VALUES ('user2', 'ch@ngem3b', 'second user')`, true},
+		{`BEGIN COUNTER BATCH UPDATE stats SET views = views + 1 WHERE pageid = 1 APPLY BATCH`, true},
+		{`delete name from users where userID = 'smith';`, true},
+		{`  UPDATE users SET password = 'secret' WHERE userID = 'smith'   `, true},
+		{`CREATE TABLE users (
+        user_name varchar PRIMARY KEY,
+        password varchar,
+        gender varchar,
+        session_token varchar,
+        state varchar,
+        birth_year bigint
+      );`, false},
+	}
+
+	for _, test := range shouldPrepareTests {
+		if got := shouldPrepare(test.Stmt); got != test.Result {
+			t.Fatalf("%q: got %v, expected %v\n", test.Stmt, got, test.Result)
+		}
+	}
+}

+ 23 - 14
conn.go

@@ -361,26 +361,35 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*queryInfo, error) {
 	return flight.info, flight.err
 }
 
+func shouldPrepare(stmt string) bool {
+	stmt = strings.TrimLeftFunc(strings.TrimRightFunc(stmt, func(r rune) bool {
+		return unicode.IsSpace(r) || r == ';'
+	}), unicode.IsSpace)
+
+	var stmtType string
+	if n := strings.IndexFunc(stmt, unicode.IsSpace); n >= 0 {
+		stmtType = strings.ToLower(stmt[:n])
+	}
+	if stmtType == "begin" {
+		if n := strings.LastIndexFunc(stmt, unicode.IsSpace); n >= 0 {
+			stmtType = strings.ToLower(stmt[n+1:])
+		}
+	}
+	switch stmtType {
+	case "select", "insert", "update", "delete", "batch":
+		return true
+	}
+	return false
+}
+
 func (c *Conn) executeQuery(qry *Query) *Iter {
 	op := &queryFrame{
-		Stmt:      strings.TrimSpace(qry.stmt),
+		Stmt:      qry.stmt,
 		Cons:      qry.cons,
 		PageSize:  qry.pageSize,
 		PageState: qry.pageState,
 	}
-	stmtType := op.Stmt
-	if n := strings.IndexFunc(stmtType, unicode.IsSpace); n >= 0 {
-		stmtType = strings.ToLower(stmtType[:n])
-		switch stmtType {
-		case "begin":
-			stmtTail := strings.TrimSpace(op.Stmt[n:])
-			if n := strings.IndexFunc(stmtTail, unicode.IsSpace); n >= 0 {
-				stmtType = stmtType + " " + strings.ToLower(stmtTail[:n])
-			}
-		}
-	}
-	switch stmtType {
-	case "select", "insert", "update", "delete", "begin batch":
+	if shouldPrepare(op.Stmt) {
 		// Prepare all DML queries. Other queries can not be prepared.
 		info, err := c.prepareStatement(qry.stmt, qry.trace)
 		if err != nil {