Browse Source

Merge pull request #250 from matope/MapScanCAS

Add MapScanCAS() to capture a prev row safety.
Ben Hood 11 years ago
parent
commit
f68d4c093c
3 changed files with 77 additions and 10 deletions
  1. 1 0
      AUTHORS
  2. 42 0
      cassandra_test.go
  3. 34 10
      session.go

+ 1 - 0
AUTHORS

@@ -31,3 +31,4 @@ Muir Manders <muir@retailnext.net>
 Sankar P <sankar.curiosity@gmail.com>
 Julien Da Silva <julien.dasilva@gmail.com>
 Dan Kennedy <daniel@firstcs.co.uk>
+Yasuharu Goto <matope.ono@gmail.com>

+ 42 - 0
cassandra_test.go

@@ -298,6 +298,48 @@ func TestCAS(t *testing.T) {
 	}
 }
 
+func TestMapScanCAS(t *testing.T) {
+	if *flagProto == 1 {
+		t.Skip("lightweight transactions not supported. Please use Cassandra >= 2.0")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, `CREATE TABLE cas_table2 (
+			title         varchar,
+			revid   	  timeuuid,
+			last_modified timestamp,
+			deleted boolean,
+			PRIMARY KEY (title, revid)
+		)`); err != nil {
+		t.Fatal("create:", err)
+	}
+
+	title, revid, modified, deleted := "baz", TimeUUID(), time.Now(), false
+	mapCAS := map[string]interface{}{}
+
+	if applied, err := session.Query(`INSERT INTO cas_table2 (title, revid, last_modified, deleted)
+		VALUES (?, ?, ?, ?) IF NOT EXISTS`,
+		title, revid, modified, deleted).MapScanCAS(mapCAS); err != nil {
+		t.Fatal("insert:", err)
+	} else if !applied {
+		t.Fatal("insert should have been applied")
+	}
+
+	mapCAS = map[string]interface{}{}
+	if applied, err := session.Query(`INSERT INTO cas_table2 (title, revid, last_modified, deleted)
+		VALUES (?, ?, ?, ?) IF NOT EXISTS`,
+		title, revid, modified, deleted).MapScanCAS(mapCAS); err != nil {
+		t.Fatal("insert:", err)
+	} else if applied {
+		t.Fatal("insert should not have been applied")
+	} else if title != mapCAS["title"] || revid != mapCAS["revid"] || deleted != mapCAS["deleted"] {
+		t.Fatalf("expected %s/%v/%v/%v but got %s/%v/%v%v", title, revid, modified, false, mapCAS["title"], mapCAS["revid"], mapCAS["last_modified"], mapCAS["deleted"])
+	}
+
+}
+
 func TestBatch(t *testing.T) {
 	if *flagProto == 1 {
 		t.Skip("atomic batches not supported. Please use Cassandra >= 2.0")

+ 34 - 10
session.go

@@ -322,11 +322,8 @@ func (q *Query) Iter() *Iter {
 // were selected, ErrNotFound is returned.
 func (q *Query) Scan(dest ...interface{}) error {
 	iter := q.Iter()
-	if iter.err != nil {
-		return iter.err
-	}
-	if len(iter.rows) == 0 {
-		return ErrNotFound
+	if err := iter.checkErrAndNotFound(); err != nil {
+		return err
 	}
 	iter.Scan(dest...)
 	return iter.Close()
@@ -338,11 +335,8 @@ func (q *Query) Scan(dest ...interface{}) error {
 // in dest.
 func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
 	iter := q.Iter()
-	if iter.err != nil {
-		return false, iter.err
-	}
-	if len(iter.rows) == 0 {
-		return false, ErrNotFound
+	if err := iter.checkErrAndNotFound(); err != nil {
+		return false, err
 	}
 	if len(iter.Columns()) > 1 {
 		dest = append([]interface{}{&applied}, dest...)
@@ -353,6 +347,26 @@ func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
 	return applied, iter.Close()
 }
 
+// MapScanCAS executes a lightweight transaction (i.e. an UPDATE or INSERT
+// statement containing an IF clause). If the transaction fails because
+// the existing values did not match, the previos values will be stored
+// in dest map.
+//
+// As for INSERT .. IF NOT EXISTS, previous values will be returned as if
+// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
+// column mismatching. MapScanCAS is added to capture them safely.
+func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error) {
+	iter := q.Iter()
+	if err := iter.checkErrAndNotFound(); err != nil {
+		return false, err
+	}
+	iter.MapScan(dest)
+	applied = dest["[applied]"].(bool)
+	delete(dest, "[applied]")
+
+	return applied, iter.Close()
+}
+
 // Iter represents an iterator that can be used to iterate over all rows that
 // were returned by a query. The iterator might send additional queries to the
 // database during the iteration if paging was enabled.
@@ -415,6 +429,16 @@ func (iter *Iter) Close() error {
 	return iter.err
 }
 
+// checkErrAndNotFound handle error and NotFound in one method.
+func (iter *Iter) checkErrAndNotFound() error {
+	if iter.err != nil {
+		return iter.err
+	} else if len(iter.rows) == 0 {
+		return ErrNotFound
+	}
+	return nil
+}
+
 type nextIter struct {
 	qry  Query
 	pos  int