Explorar o código

Merge pull request #365 from Zariel/fix-unprepared-batch

Always set the number of values in a batch
Ben Hood %!s(int64=10) %!d(string=hai) anos
pai
achega
75dba14e0f
Modificáronse 2 ficheiros con 44 adicións e 3 borrados
  1. 41 0
      cassandra_test.go
  2. 3 3
      frame.go

+ 41 - 0
cassandra_test.go

@@ -409,6 +409,47 @@ func TestBatch(t *testing.T) {
 	}
 }
 
+func TestUnpreparedBatch(t *testing.T) {
+	if *flagProto == 1 {
+		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, `CREATE TABLE batch_unprepared (id int primary key, c counter)`); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	var batch *Batch
+	if *flagProto == 2 {
+		batch = NewBatch(CounterBatch)
+	} else {
+		batch = NewBatch(UnloggedBatch)
+	}
+
+	for i := 0; i < 100; i++ {
+		batch.Query(`UPDATE batch_unprepared SET c = c + 1 WHERE id = 1`)
+	}
+
+	if err := session.ExecuteBatch(batch); err != nil {
+		t.Fatal("execute batch:", err)
+	}
+
+	count := 0
+	if err := session.Query(`SELECT COUNT(*) FROM batch_unprepared`).Scan(&count); err != nil {
+		t.Fatal("select count:", err)
+	} else if count != 1 {
+		t.Fatalf("count: expected %d, got %d\n", 100, count)
+	}
+
+	if err := session.Query(`SELECT c FROM batch_unprepared`).Scan(&count); err != nil {
+		t.Fatal("select count:", err)
+	} else if count != 100 {
+		t.Fatalf("count: expected %d, got %d\n", 100, count)
+	}
+}
+
 // TestBatchLimit tests gocql to make sure batch operations larger than the maximum
 // statement limit are not submitted to a cassandra node.
 func TestBatchLimit(t *testing.T) {

+ 3 - 3
frame.go

@@ -1087,11 +1087,11 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error {
 		if len(b.preparedID) == 0 {
 			f.writeByte(0)
 			f.writeLongString(b.statement)
-			continue
+		} else {
+			f.writeByte(1)
+			f.writeShortBytes(b.preparedID)
 		}
 
-		f.writeByte(1)
-		f.writeShortBytes(b.preparedID)
 		f.writeShort(uint16(len(b.values)))
 		for j := range b.values {
 			col := &b.values[j]