Browse Source

Added query args length check to batch statements

Ben Hood 11 năm trước cách đây
mục cha
commit
5e5d698a1b
2 tập tin đã thay đổi với 30 bổ sung0 xóa
  1. 25 0
      cassandra_test.go
  2. 5 0
      conn.go

+ 25 - 0
cassandra_test.go

@@ -327,6 +327,19 @@ func TestTooManyQueryArgs(t *testing.T) {
 	if err != ErrQueryArgLength {
 		t.Fatalf("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an ErrQueryArgLength, but returned: %s", err)
 	}
+
+	batch := session.NewBatch(UnloggedBatch)
+	batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
+	err = session.ExecuteBatch(batch)
+
+	if err == nil {
+		t.Fatal("'`INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an ErrQueryArgLength")
+	}
+
+	if err != ErrQueryArgLength {
+		t.Fatalf("'INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an ErrQueryArgLength, but returned: %s", err)
+	}
+
 }
 
 // TestNotEnoughQueryArgs tests to make sure the library correctly handles the application level bug
@@ -348,6 +361,18 @@ func TestNotEnoughQueryArgs(t *testing.T) {
 	if err != ErrQueryArgLength {
 		t.Fatalf("'`SELECT * FROM too_few_query_args WHERE id = ? and cluster = ?`, 1' should return an ErrQueryArgLength, but returned: %s", err)
 	}
+
+	batch := session.NewBatch(UnloggedBatch)
+	batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
+	err = session.ExecuteBatch(batch)
+
+	if err == nil {
+		t.Fatal("'`INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an ErrQueryArgLength")
+	}
+
+	if err != ErrQueryArgLength {
+		t.Fatalf("'INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an ErrQueryArgLength, but returned: %s", err)
+	}
 }
 
 // TestCreateSessionTimeout tests to make sure the CreateSession function timeouts out correctly

+ 5 - 0
conn.go

@@ -476,6 +476,11 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		if len(entry.Args) > 0 {
 			var err error
 			info, err = c.prepareStatement(entry.Stmt, nil)
+
+			if len(entry.Args) != len(info.args) {
+				return ErrQueryArgLength
+			}
+
 			stmts[string(info.id)] = entry.Stmt
 			if err != nil {
 				return err