Browse Source

cleaned up CAS branch a bit

Christoph Hack 12 năm trước cách đây
mục cha
commit
92ccf5fec6
2 tập tin đã thay đổi với 22 bổ sung46 xóa
  1. 7 6
      gocql_test/main.go
  2. 15 40
      session.go

+ 7 - 6
gocql_test/main.go

@@ -150,7 +150,7 @@ func insertCAS() error {
 	applied, err := session.Query(
 		`INSERT INTO cas_table (title, revid)
         VALUES (?,?) IF NOT EXISTS`,
-		title, revid).ScanCas(&titleCAS, &revidCAS)
+		title, revid).ScanCAS(&titleCAS, &revidCAS)
 
 	if err != nil {
 		return err
@@ -163,7 +163,7 @@ func insertCAS() error {
 	applied, err = session.Query(
 		`INSERT INTO cas_table (title, revid)
         VALUES (?,?) IF NOT EXISTS`,
-		title, revid).ScanCas(&titleCAS, &revidCAS)
+		title, revid).ScanCAS(&titleCAS, &revidCAS)
 
 	if err != nil {
 		return err
@@ -212,10 +212,6 @@ func main() {
 		log.Fatal("insertTestData: ", err)
 	}
 
-	if err := insertCAS(); err != nil {
-		log.Fatal("insertCAS: ", err)
-	}
-
 	var count int
 	if err := session.Query("SELECT COUNT(*) FROM page").Scan(&count); err != nil {
 		log.Fatal("getCount: ", err)
@@ -284,5 +280,10 @@ func main() {
 		if err := insertBatch(); err != nil {
 			log.Fatal("insertBatch: ", err)
 		}
+
+		// CAS
+		if err := insertCAS(); err != nil {
+			log.Fatal("insertCAS: ", err)
+		}
 	}
 }

+ 15 - 40
session.go

@@ -171,50 +171,25 @@ func (q *Query) Scan(dest ...interface{}) error {
 	return iter.Close()
 }
 
-// If the CAS operation was applied, this function
-// will bind the result to the dest interface and return false.
-// Otherwise the dest interface will not be bound and the function
-// will return true.
-func (q *Query) ScanCas(dest ...interface{}) (bool, error) {
-	result := q.session.executeQuery(q)
-	if result.err != nil {
-		return false, result.err
+// ScanCAS executes a lightweight transaction (i.e. an UPDATE or INSERT
+// statement containing an IF clause). If the transaction fails because
+// the existing values did not match, the previos values will be stored
+// in dest.
+func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
+	iter := q.Iter()
+	if iter.err != nil {
+		return false, iter.err
 	}
-	if len(result.rows) == 0 {
+	if len(iter.rows) == 0 {
 		return false, ErrNotFound
 	}
-
-	switch len(result.columns) {
-	case 1:
-		{
-			// The CAS operation was applied
-			return true, nil
-		}
-	case len(dest) + 1:
-		{
-			// The CAS operation was NOT applied
-			// In this case, the result will return the entire row from the database
-			// in addition a flag in indicating the
-			var applied bool
-			Unmarshal(result.columns[0].TypeInfo, result.rows[result.pos][0], &applied)
-
-			if applied {
-				return applied, errors.New("Expected unapplied CAS statement, but received applied CAS statement")
-			}
-
-			for i := 1; i < len(result.columns); i++ {
-				err := Unmarshal(result.columns[i].TypeInfo, result.rows[result.pos][i], dest[i-1])
-				if err != nil {
-					return false, err
-				}
-			}
-			return false, result.err
-		}
-	default:
-		{
-			return false, fmt.Errorf("Expected %d + 1 columns, but received %d columns", len(dest), len(result.columns))
-		}
+	if len(iter.Columns()) > 1 {
+		dest = append([]interface{}{&applied}, dest...)
+		iter.Scan(dest...)
+	} else {
+		iter.Scan(&applied)
 	}
+	return applied, iter.Close()
 }
 
 // Iter represents an iterator that can be used to iterate over all rows that