Browse Source

Merge pull request #590 from Zariel/fix-batch-iter-errors

batch: return an iter containing an error like queries
Chris Bannister 10 years ago
parent
commit
c4ad114e96
3 changed files with 67 additions and 59 deletions
  1. 1 1
      cassandra_test.go
  2. 15 16
      conn.go
  3. 51 42
      session.go

+ 1 - 1
cassandra_test.go

@@ -1068,7 +1068,7 @@ func TestReprepareBatch(t *testing.T) {
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
 	batch := session.NewBatch(UnloggedBatch)
 	batch.Query(stmt, "bar")
-	if _, err := conn.executeBatch(batch); err != nil {
+	if err := conn.executeBatch(batch).Close(); err != nil {
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 	}
 

+ 15 - 16
conn.go

@@ -817,9 +817,9 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	return nil
 }
 
-func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
+func (c *Conn) executeBatch(batch *Batch) *Iter {
 	if c.version == protoVersion1 {
-		return nil, ErrUnsupported
+		return &Iter{err: ErrUnsupported}
 	}
 
 	n := len(batch.Entries)
@@ -831,7 +831,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		defaultTimestamp:  batch.defaultTimestamp,
 	}
 
-	stmts := make(map[string]string)
+	stmts := make(map[string]string, len(batch.Entries))
 
 	for i := 0; i < n; i++ {
 		entry := &batch.Entries[i]
@@ -839,7 +839,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		if len(entry.Args) > 0 || entry.binding != nil {
 			info, err := c.prepareStatement(entry.Stmt, nil)
 			if err != nil {
-				return nil, err
+				return &Iter{err: err}
 			}
 
 			var args []interface{}
@@ -848,12 +848,12 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			} else {
 				args, err = entry.binding(info)
 				if err != nil {
-					return nil, err
+					return &Iter{err: err}
 				}
 			}
 
 			if len(args) != len(info.Args) {
-				return nil, ErrQueryArgLength
+				return &Iter{err: ErrQueryArgLength}
 			}
 
 			b.preparedID = info.Id
@@ -864,7 +864,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			for j := 0; j < len(info.Args); j++ {
 				val, err := Marshal(info.Args[j].TypeInfo, args[j])
 				if err != nil {
-					return nil, err
+					return &Iter{err: err}
 				}
 
 				b.values[j].value = val
@@ -878,18 +878,18 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 	// TODO: should batch support tracing?
 	framer, err := c.exec(req, nil)
 	if err != nil {
-		return nil, err
+		return &Iter{err: err}
 	}
 
 	resp, err := framer.parseFrame()
 	if err != nil {
-		return nil, err
+		return &Iter{err: err, framer: framer}
 	}
 
 	switch x := resp.(type) {
 	case *resultVoidFrame:
 		framerPool.Put(framer)
-		return nil, nil
+		return &Iter{}
 	case *RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
@@ -903,7 +903,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		if found {
 			return c.executeBatch(batch)
 		} else {
-			return nil, x
+			return &Iter{err: err, framer: framer}
 		}
 	case *resultRowsFrame:
 		iter := &Iter{
@@ -912,13 +912,12 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			framer: framer,
 		}
 
-		return iter, nil
+		return iter
 	case error:
-		framerPool.Put(framer)
-		return nil, x
+
+		return &Iter{err: err, framer: framer}
 	default:
-		framerPool.Put(framer)
-		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
+		return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer}
 	}
 }
 

+ 51 - 42
session.go

@@ -452,100 +452,109 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	return routingKeyInfo, nil
 }
 
-func (s *Session) executeBatch(batch *Batch) (*Iter, error) {
+func (s *Session) executeBatch(batch *Batch) *Iter {
 	// fail fast
 	if s.Closed() {
-		return nil, ErrSessionClosed
+		return &Iter{err: ErrSessionClosed}
 	}
 
 	// Prevent the execution of the batch if greater than the limit
 	// Currently batches have a limit of 65536 queries.
 	// https://datastax-oss.atlassian.net/browse/JAVA-229
 	if batch.Size() > BatchSizeMaximum {
-		return nil, ErrTooManyStmts
+		return &Iter{err: ErrTooManyStmts}
 	}
 
-	var err error
 	var iter *Iter
 	batch.attempts = 0
 	batch.totalLatency = 0
 	for {
 		host, conn := s.pool.Pick(nil)
 
-		//Assign the error unavailable and break loop
+		batch.attempts++
 		if conn == nil {
-			err = ErrNoConnections
+			if batch.rt == nil || !batch.rt.Attempt(batch) {
+				// Assign the error unavailable and break loop
+				iter = &Iter{err: ErrNoConnections}
+				break
+			}
+
+			continue
+		}
+
+		if conn == nil {
+			iter = &Iter{err: ErrNoConnections}
 			break
 		}
+
 		t := time.Now()
-		iter, err = conn.executeBatch(batch)
-		batch.totalLatency += time.Now().Sub(t).Nanoseconds()
-		batch.attempts++
 
-		// Update host
-		host.Mark(err)
+		iter = conn.executeBatch(batch)
 
+		batch.totalLatency += time.Since(t).Nanoseconds()
 		// Exit loop if operation executed correctly
-		if err == nil {
+		if iter.err == nil {
+			host.Mark(nil)
 			break
 		}
 
+		// Mark host with error if returned from Close
+		host.Mark(iter.Close())
+
 		if batch.rt == nil || !batch.rt.Attempt(batch) {
 			break
 		}
 	}
 
-	return iter, err
+	return iter
 }
 
 // ExecuteBatch executes a batch operation and returns nil if successful
 // otherwise an error is returned describing the failure.
 func (s *Session) ExecuteBatch(batch *Batch) error {
-	_, err := s.executeBatch(batch)
-	return err
+	iter := s.executeBatch(batch)
+	return iter.Close()
 }
 
-// ExecuteBatchCAS executes a batch operation and returns nil if successful and
+// ExecuteBatchCAS executes a batch operation and returns true if successful and
 // an iterator (to scan aditional rows if more than one conditional statement)
-// was sent, otherwise an error is returned describing the failure.
+// was sent.
 // Further scans on the interator must also remember to include
 // the applied boolean as the first argument to *Iter.Scan
 func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bool, iter *Iter, err error) {
-	if iter, err := s.executeBatch(batch); err == nil {
-		if err := iter.checkErrAndNotFound(); err != nil {
-			return false, nil, err
-		}
-		if len(iter.Columns()) > 1 {
-			dest = append([]interface{}{&applied}, dest...)
-			iter.Scan(dest...)
-		} else {
-			iter.Scan(&applied)
-		}
-		return applied, iter, nil
-	} else {
+	iter = s.executeBatch(batch)
+	if err := iter.checkErrAndNotFound(); err != nil {
+		iter.Close()
 		return false, nil, err
 	}
+
+	if len(iter.Columns()) > 1 {
+		dest = append([]interface{}{&applied}, dest...)
+		iter.Scan(dest...)
+	} else {
+		iter.Scan(&applied)
+	}
+
+	return applied, iter, nil
 }
 
 // MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS,
 // however it accepts a map rather than a list of arguments for the initial
 // scan.
 func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) (applied bool, iter *Iter, err error) {
-	if iter, err := s.executeBatch(batch); err == nil {
-		if err := iter.checkErrAndNotFound(); err != nil {
-			return false, nil, err
-		}
-		iter.MapScan(dest)
-		applied = dest["[applied]"].(bool)
-		delete(dest, "[applied]")
-
-		// we usually close here, but instead of closing, just returin an error
-		// if MapScan failed. Although Close just returns err, using Close
-		// here might be confusing as we are not actually closing the iter
-		return applied, iter, iter.err
-	} else {
+	iter = s.executeBatch(batch)
+	if err := iter.checkErrAndNotFound(); err != nil {
+		iter.Close()
 		return false, nil, err
 	}
+	iter.MapScan(dest)
+	applied = dest["[applied]"].(bool)
+	delete(dest, "[applied]")
+
+	// we usually close here, but instead of closing, just returin an error
+	// if MapScan failed. Although Close just returns err, using Close
+	// here might be confusing as we are not actually closing the iter
+	return applied, iter, iter.err
 }
 
 func (s *Session) connect(addr string, errorHandler ConnErrorHandler) (*Conn, error) {