Browse Source

Merge pull request #491 from ChannelMeter/feature/executeBatchCAS

Add ExecuteBatchCAS and MapExecuteBatchCAS functions for using batches with LWT. Fixes #453
Chris Bannister 10 years ago
parent
commit
231332cc32
3 changed files with 129 additions and 19 deletions
  1. 55 1
      cassandra_test.go
  2. 18 11
      conn.go
  3. 56 7
      session.go

+ 55 - 1
cassandra_test.go

@@ -353,6 +353,60 @@ func TestCAS(t *testing.T) {
 	} else if !applied {
 		t.Fatal("delete should have been applied")
 	}
+
+	if err := session.Query(`TRUNCATE cas_table`).Exec(); err != nil {
+		t.Fatal("truncate:", err)
+	}
+
+	successBatch := session.NewBatch(LoggedBatch)
+	successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
+	if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
+		t.Fatal("insert:", err)
+	} else if !applied {
+		t.Fatal("insert should have been applied")
+	}
+
+	successBatch = session.NewBatch(LoggedBatch)
+	successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
+	casMap := make(map[string]interface{})
+	if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
+		t.Fatal("insert:", err)
+	} else if !applied {
+		t.Fatal("insert should have been applied")
+	}
+
+	failBatch := session.NewBatch(LoggedBatch)
+	failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
+	if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
+		t.Fatal("insert:", err)
+	} else if applied {
+		t.Fatal("insert shouldn't have been applied")
+	}
+
+	insertBatch := session.NewBatch(LoggedBatch)
+	insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
+	insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
+	if err := session.ExecuteBatch(insertBatch); err != nil {
+		t.Fatal("insert:", err)
+	}
+
+	failBatch = session.NewBatch(LoggedBatch)
+	failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
+	failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
+	if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
+		t.Fatal("insert:", err)
+	} else if applied {
+		t.Fatal("insert shouldn't have been applied")
+	} else {
+		if scan := iter.Scan(&applied, &titleCAS, &revidCAS, &modifiedCAS); scan && applied {
+			t.Fatal("insert shouldn't have been applied")
+		} else if !scan {
+			t.Fatal("should have scanned another row")
+		}
+		if err := iter.Close(); err != nil {
+			t.Fatal("scan:", err)
+		}
+	}
 }
 
 func TestMapScanCAS(t *testing.T) {
@@ -1143,7 +1197,7 @@ func TestReprepareBatch(t *testing.T) {
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
 	batch := session.NewBatch(UnloggedBatch)
 	batch.Query(stmt, "bar")
-	if err := conn.executeBatch(batch); err != nil {
+	if _, err := conn.executeBatch(batch); err != nil {
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 	}
 

+ 18 - 11
conn.go

@@ -728,9 +728,9 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	return nil
 }
 
-func (c *Conn) executeBatch(batch *Batch) error {
+func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 	if c.version == protoVersion1 {
-		return ErrUnsupported
+		return nil, ErrUnsupported
 	}
 
 	n := len(batch.Entries)
@@ -750,7 +750,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		if len(entry.Args) > 0 || entry.binding != nil {
 			info, err := c.prepareStatement(entry.Stmt, nil)
 			if err != nil {
-				return err
+				return nil, err
 			}
 
 			var args []interface{}
@@ -764,12 +764,12 @@ func (c *Conn) executeBatch(batch *Batch) error {
 				}
 				args, err = entry.binding(binding)
 				if err != nil {
-					return err
+					return nil, err
 				}
 			}
 
 			if len(args) != len(info.reqMeta.columns) {
-				return ErrQueryArgLength
+				return nil, ErrQueryArgLength
 			}
 
 			b.preparedID = info.preparedID
@@ -780,7 +780,7 @@ func (c *Conn) executeBatch(batch *Batch) error {
 			for j := 0; j < len(info.reqMeta.columns); j++ {
 				val, err := Marshal(info.reqMeta.columns[j].TypeInfo, args[j])
 				if err != nil {
-					return err
+					return nil, err
 				}
 
 				b.values[j].value = val
@@ -794,12 +794,12 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	// TODO: should batch support tracing?
 	resp, err := c.exec(req, nil)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	switch x := resp.(type) {
 	case *resultVoidFrame:
-		return nil
+		return nil, nil
 	case *RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
@@ -810,12 +810,19 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		if found {
 			return c.executeBatch(batch)
 		} else {
-			return x
+			return nil, x
 		}
+	case *resultRowsFrame:
+		iter := &Iter{
+			meta: x.meta,
+			rows: x.rows,
+		}
+
+		return iter, nil
 	case error:
-		return x
+		return nil, x
 	default:
-		return NewErrProtocol("Unknown type in response to batch statement: %s", x)
+		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
 	}
 }
 

+ 56 - 7
session.go

@@ -371,22 +371,21 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	return routingKeyInfo, nil
 }
 
-// ExecuteBatch executes a batch operation and returns nil if successful
-// otherwise an error is returned describing the failure.
-func (s *Session) ExecuteBatch(batch *Batch) error {
+func (s *Session) executeBatch(batch *Batch) (*Iter, error) {
 	// fail fast
 	if s.Closed() {
-		return ErrSessionClosed
+		return nil, ErrSessionClosed
 	}
 
 	// Prevent the execution of the batch if greater than the limit
 	// Currently batches have a limit of 65536 queries.
 	// https://datastax-oss.atlassian.net/browse/JAVA-229
 	if batch.Size() > BatchSizeMaximum {
-		return ErrTooManyStmts
+		return nil, ErrTooManyStmts
 	}
 
 	var err error
+	var iter *Iter
 	batch.attempts = 0
 	batch.totalLatency = 0
 	for {
@@ -398,12 +397,12 @@ func (s *Session) ExecuteBatch(batch *Batch) error {
 			break
 		}
 		t := time.Now()
-		err = conn.executeBatch(batch)
+		iter, err = conn.executeBatch(batch)
 		batch.totalLatency += time.Now().Sub(t).Nanoseconds()
 		batch.attempts++
 		//Exit loop if operation executed correctly
 		if err == nil {
-			return nil
+			return iter, err
 		}
 
 		if batch.rt == nil || !batch.rt.Attempt(batch) {
@@ -411,9 +410,59 @@ func (s *Session) ExecuteBatch(batch *Batch) error {
 		}
 	}
 
+	return nil, err
+}
+
+// ExecuteBatch executes a batch operation and returns nil if successful
+// otherwise an error is returned describing the failure.
+func (s *Session) ExecuteBatch(batch *Batch) error {
+	_, err := s.executeBatch(batch)
 	return err
 }
 
+// ExecuteBatchCAS executes a batch operation and returns nil if successful and
+// an iterator (to scan aditional rows if more than one conditional statement)
+// was sent, otherwise an error is returned describing the failure.
+// Further scans on the interator must also remember to include
+// the applied boolean as the first argument to *Iter.Scan
+func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bool, iter *Iter, err error) {
+	if iter, err := s.executeBatch(batch); err == nil {
+		if err := iter.checkErrAndNotFound(); err != nil {
+			return false, nil, err
+		}
+		if len(iter.Columns()) > 1 {
+			dest = append([]interface{}{&applied}, dest...)
+			iter.Scan(dest...)
+		} else {
+			iter.Scan(&applied)
+		}
+		return applied, iter, nil
+	} else {
+		return false, nil, err
+	}
+}
+
+// MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS,
+// however it accepts a map rather than a list of arguments for the initial
+// scan.
+func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) (applied bool, iter *Iter, err error) {
+	if iter, err := s.executeBatch(batch); err == nil {
+		if err := iter.checkErrAndNotFound(); err != nil {
+			return false, nil, err
+		}
+		iter.MapScan(dest)
+		applied = dest["[applied]"].(bool)
+		delete(dest, "[applied]")
+
+		// we usually close here, but instead of closing, just returin an error
+		// if MapScan failed. Although Close just returns err, using Close
+		// here might be confusing as we are not actually closing the iter
+		return applied, iter, iter.err
+	} else {
+		return false, nil, err
+	}
+}
+
 // Query represents a CQL statement that can be executed.
 type Query struct {
 	stmt             string