فهرست منبع

No need to make the caller supply the CAS flag in the binding statement

Ben Hood 12 سال پیش
والد
کامیت
8d9a57e01e
2فایلهای تغییر یافته به همراه26 افزوده شده و 25 حذف شده
  1. 6 3
      gocql_test/main.go
  2. 20 22
      session.go

+ 6 - 3
gocql_test/main.go

@@ -146,12 +146,11 @@ func insertCAS() error {
 
 	var titleCAS string
 	var revidCAS uuid.UUID
-	var casApplied bool
 
 	applied, err := session.Query(
 		`INSERT INTO cas_table (title, revid)
         VALUES (?,?) IF NOT EXISTS`,
-		title, revid).ScanCas(&casApplied, &titleCAS, &revidCAS)
+		title, revid).ScanCas(&titleCAS, &revidCAS)
 
 	if err != nil {
 		return err
@@ -164,7 +163,11 @@ func insertCAS() error {
 	applied, err = session.Query(
 		`INSERT INTO cas_table (title, revid)
         VALUES (?,?) IF NOT EXISTS`,
-		title, revid).ScanCas(&casApplied, &titleCAS, &revidCAS)
+		title, revid).ScanCas(&titleCAS, &revidCAS)
+
+	if err != nil {
+		return err
+	}
 
 	if applied {
 		return fmt.Errorf("Should NOT have applied update for existing random title %s", title)

+ 20 - 22
session.go

@@ -176,47 +176,45 @@ func (q *Query) Scan(dest ...interface{}) error {
 // Otherwise the dest interface will not be bound and the function
 // will return true.
 func (q *Query) ScanCas(dest ...interface{}) (bool, error) {
-
-	// Copy and paste start
-	iter := q.Iter()
-	if iter.err != nil {
-		return false, iter.err
+	result := q.session.executeQuery(q)
+	if result.err != nil {
+		return false, result.err
 	}
-	if len(iter.rows) == 0 {
+	if len(result.rows) == 0 {
 		return false, ErrNotFound
 	}
-	// Copy and paste end
-
-	if iter.next != nil {
-		go iter.next.fetch()
-	}
 
-	switch len(iter.columns) {
+	switch len(result.columns) {
 	case 1:
 		{
 			// The CAS operation was applied
 			return true, nil
 		}
-	case len(dest):
+	case len(dest) + 1:
 		{
-			// Copy and paste start - should this get merged into the upstream, this should
-			// be factored out
-			for i := 0; i < len(iter.columns); i++ {
-				err := Unmarshal(iter.columns[i].TypeInfo, iter.rows[iter.pos][i], dest[i])
+			// The CAS operation was NOT applied
+			// In this case, the result will return the entire row from the database
+			// in addition a flag in indicating the
+			var applied bool
+			Unmarshal(result.columns[0].TypeInfo, result.rows[result.pos][0], &applied)
+
+			if applied {
+				return applied, errors.New("Expected unapplied CAS statement, but received applied CAS statement")
+			}
+
+			for i := 1; i < len(result.columns); i++ {
+				err := Unmarshal(result.columns[i].TypeInfo, result.rows[result.pos][i], dest[i-1])
 				if err != nil {
 					return false, err
 				}
 			}
-			// Copy and paste end
-			return false, nil
+			return false, result.err
 		}
 	default:
 		{
-			return false, errors.New("count mismatch")
+			return false, fmt.Errorf("Expected %d + 1 columns, but received %d columns", len(dest), len(result.columns))
 		}
 	}
-
-	return false, iter.Close()
 }
 
 // Iter represents an iterator that can be used to iterate over all rows that