Browse Source

First cut at adding CAS statements

Ben Hood 12 years ago
parent
commit
f1f71b5f46
2 changed files with 98 additions and 0 deletions
  1. 50 0
      gocql_test/main.go
  2. 48 0
      session.go

+ 50 - 0
gocql_test/main.go

@@ -5,6 +5,7 @@
 package main
 
 import (
+	"fmt"
 	"log"
 	"os"
 	"reflect"
@@ -80,6 +81,14 @@ func initSchema() error {
 		return err
 	}
 
+	if err := session.Query(`CREATE TABLE cas_table (
+            title   varchar,
+            revid   timeuuid,
+            PRIMARY KEY (title, revid)
+        )`).Exec(); err != nil {
+		return err
+	}
+
 	return nil
 }
 
@@ -131,6 +140,43 @@ func insertBatch() error {
 	return nil
 }
 
+func insertCAS() error {
+	title := "baz"
+	revid := uuid.TimeUUID()
+
+	var titleCAS string
+	var revidCAS uuid.UUID
+	var casApplied bool
+
+	applied, err := session.Query(
+		`INSERT INTO cas_table (title, revid)
+        VALUES (?,?) IF NOT EXISTS`,
+		title, revid).ScanCas(&casApplied, &titleCAS, &revidCAS)
+
+	if err != nil {
+		return err
+	}
+
+	if !applied {
+		return fmt.Errorf("Should have applied update for new random title %s", title)
+	}
+
+	applied, err = session.Query(
+		`INSERT INTO cas_table (title, revid)
+        VALUES (?,?) IF NOT EXISTS`,
+		title, revid).ScanCas(&casApplied, &titleCAS, &revidCAS)
+
+	if applied {
+		return fmt.Errorf("Should NOT have applied update for existing random title %s", title)
+	}
+
+	if title != titleCAS || revid != revidCAS {
+		return fmt.Errorf("Expected %s/%v but got %s/%v", title, revid, titleCAS, revidCAS)
+	}
+
+	return nil
+}
+
 func getPage(title string, revid uuid.UUID) (*Page, error) {
 	p := new(Page)
 	err := session.Query(`SELECT title, revid, body, views, protected, modified,
@@ -163,6 +209,10 @@ func main() {
 		log.Fatal("insertTestData: ", err)
 	}
 
+	if err := insertCAS(); err != nil {
+		log.Fatal("insertCAS: ", err)
+	}
+
 	var count int
 	if err := session.Query("SELECT COUNT(*) FROM page").Scan(&count); err != nil {
 		log.Fatal("getCount: ", err)

+ 48 - 0
session.go

@@ -171,6 +171,54 @@ func (q *Query) Scan(dest ...interface{}) error {
 	return iter.Close()
 }
 
+// If the CAS operation was applied, this function
+// will bind the result to the dest interface and return false.
+// Otherwise the dest interface will not be bound and the function
+// will return true.
+func (q *Query) ScanCas(dest ...interface{}) (bool, error) {
+
+	// Copy and paste start
+	iter := q.Iter()
+	if iter.err != nil {
+		return false, iter.err
+	}
+	if len(iter.rows) == 0 {
+		return false, ErrNotFound
+	}
+	// Copy and paste end
+
+	if iter.next != nil {
+		go iter.next.fetch()
+	}
+
+	switch len(iter.columns) {
+	case 1:
+		{
+			// The CAS operation was applied
+			return true, nil
+		}
+	case len(dest):
+		{
+			// Copy and paste start - should this get merged into the upstream, this should
+			// be factored out
+			for i := 0; i < len(iter.columns); i++ {
+				err := Unmarshal(iter.columns[i].TypeInfo, iter.rows[iter.pos][i], dest[i])
+				if err != nil {
+					return false, err
+				}
+			}
+			// Copy and paste end
+			return false, nil
+		}
+	default:
+		{
+			return false, errors.New("count mismatch")
+		}
+	}
+
+	return false, iter.Close()
+}
+
 // Iter represents an iterator that can be used to iterate over all rows that
 // were returned by a query. The iterator might send additional queries to the
 // database during the iteration if paging was enabled.