Explorar o código

batch: return an iter containing an error like queries

To be consistent with queries batches should return an error inside and
iter instead of directly. To access to the error callers should use
iter.Close().
Chris Bannister %!s(int64=10) %!d(string=hai) anos
pai
achega
d16cdd272b
Modificáronse 3 ficheiros con 67 adicións e 59 borrados
  1. 1 1
      cassandra_test.go
  2. 15 16
      conn.go
  3. 51 42
      session.go

+ 1 - 1
cassandra_test.go

@@ -1068,7 +1068,7 @@ func TestReprepareBatch(t *testing.T) {
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
 	batch := session.NewBatch(UnloggedBatch)
 	batch.Query(stmt, "bar")
-	if _, err := conn.executeBatch(batch); err != nil {
+	if err := conn.executeBatch(batch).Close(); err != nil {
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 	}
 

+ 15 - 16
conn.go

@@ -817,9 +817,9 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	return nil
 }
 
-func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
+func (c *Conn) executeBatch(batch *Batch) *Iter {
 	if c.version == protoVersion1 {
-		return nil, ErrUnsupported
+		return &Iter{err: ErrUnsupported}
 	}
 
 	n := len(batch.Entries)
@@ -831,7 +831,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		defaultTimestamp:  batch.defaultTimestamp,
 	}
 
-	stmts := make(map[string]string)
+	stmts := make(map[string]string, len(batch.Entries))
 
 	for i := 0; i < n; i++ {
 		entry := &batch.Entries[i]
@@ -839,7 +839,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		if len(entry.Args) > 0 || entry.binding != nil {
 			info, err := c.prepareStatement(entry.Stmt, nil)
 			if err != nil {
-				return nil, err
+				return &Iter{err: err}
 			}
 
 			var args []interface{}
@@ -848,12 +848,12 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			} else {
 				args, err = entry.binding(info)
 				if err != nil {
-					return nil, err
+					return &Iter{err: err}
 				}
 			}
 
 			if len(args) != len(info.Args) {
-				return nil, ErrQueryArgLength
+				return &Iter{err: ErrQueryArgLength}
 			}
 
 			b.preparedID = info.Id
@@ -864,7 +864,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			for j := 0; j < len(info.Args); j++ {
 				val, err := Marshal(info.Args[j].TypeInfo, args[j])
 				if err != nil {
-					return nil, err
+					return &Iter{err: err}
 				}
 
 				b.values[j].value = val
@@ -878,18 +878,18 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 	// TODO: should batch support tracing?
 	framer, err := c.exec(req, nil)
 	if err != nil {
-		return nil, err
+		return &Iter{err: err}
 	}
 
 	resp, err := framer.parseFrame()
 	if err != nil {
-		return nil, err
+		return &Iter{err: err, framer: framer}
 	}
 
 	switch x := resp.(type) {
 	case *resultVoidFrame:
 		framerPool.Put(framer)
-		return nil, nil
+		return &Iter{}
 	case *RequestErrUnprepared:
 		stmt, found := stmts[string(x.StatementId)]
 		if found {
@@ -903,7 +903,7 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 		if found {
 			return c.executeBatch(batch)
 		} else {
-			return nil, x
+			return &Iter{err: err, framer: framer}
 		}
 	case *resultRowsFrame:
 		iter := &Iter{
@@ -912,13 +912,12 @@ func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
 			framer: framer,
 		}
 
-		return iter, nil
+		return iter
 	case error:
-		framerPool.Put(framer)
-		return nil, x
+
+		return &Iter{err: err, framer: framer}
 	default:
-		framerPool.Put(framer)
-		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
+		return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer}
 	}
 }
 

+ 51 - 42
session.go

@@ -452,100 +452,109 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	return routingKeyInfo, nil
 }
 
-func (s *Session) executeBatch(batch *Batch) (*Iter, error) {
+func (s *Session) executeBatch(batch *Batch) *Iter {
 	// fail fast
 	if s.Closed() {
-		return nil, ErrSessionClosed
+		return &Iter{err: ErrSessionClosed}
 	}
 
 	// Prevent the execution of the batch if greater than the limit
 	// Currently batches have a limit of 65536 queries.
 	// https://datastax-oss.atlassian.net/browse/JAVA-229
 	if batch.Size() > BatchSizeMaximum {
-		return nil, ErrTooManyStmts
+		return &Iter{err: ErrTooManyStmts}
 	}
 
-	var err error
 	var iter *Iter
 	batch.attempts = 0
 	batch.totalLatency = 0
 	for {
 		host, conn := s.pool.Pick(nil)
 
-		//Assign the error unavailable and break loop
+		batch.attempts++
 		if conn == nil {
-			err = ErrNoConnections
+			if batch.rt == nil || !batch.rt.Attempt(batch) {
+				// Assign the error unavailable and break loop
+				iter = &Iter{err: ErrNoConnections}
+				break
+			}
+
+			continue
+		}
+
+		if conn == nil {
+			iter = &Iter{err: ErrNoConnections}
 			break
 		}
+
 		t := time.Now()
-		iter, err = conn.executeBatch(batch)
-		batch.totalLatency += time.Now().Sub(t).Nanoseconds()
-		batch.attempts++
 
-		// Update host
-		host.Mark(err)
+		iter = conn.executeBatch(batch)
 
+		batch.totalLatency += time.Since(t).Nanoseconds()
 		// Exit loop if operation executed correctly
-		if err == nil {
+		if iter.err == nil {
+			host.Mark(nil)
 			break
 		}
 
+		// Mark host with error if returned from Close
+		host.Mark(iter.Close())
+
 		if batch.rt == nil || !batch.rt.Attempt(batch) {
 			break
 		}
 	}
 
-	return iter, err
+	return iter
 }
 
 // ExecuteBatch executes a batch operation and returns nil if successful
 // otherwise an error is returned describing the failure.
 func (s *Session) ExecuteBatch(batch *Batch) error {
-	_, err := s.executeBatch(batch)
-	return err
+	iter := s.executeBatch(batch)
+	return iter.Close()
 }
 
-// ExecuteBatchCAS executes a batch operation and returns nil if successful and
+// ExecuteBatchCAS executes a batch operation and returns true if successful and
 // an iterator (to scan aditional rows if more than one conditional statement)
-// was sent, otherwise an error is returned describing the failure.
+// was sent.
 // Further scans on the interator must also remember to include
 // the applied boolean as the first argument to *Iter.Scan
 func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bool, iter *Iter, err error) {
-	if iter, err := s.executeBatch(batch); err == nil {
-		if err := iter.checkErrAndNotFound(); err != nil {
-			return false, nil, err
-		}
-		if len(iter.Columns()) > 1 {
-			dest = append([]interface{}{&applied}, dest...)
-			iter.Scan(dest...)
-		} else {
-			iter.Scan(&applied)
-		}
-		return applied, iter, nil
-	} else {
+	iter = s.executeBatch(batch)
+	if err := iter.checkErrAndNotFound(); err != nil {
+		iter.Close()
 		return false, nil, err
 	}
+
+	if len(iter.Columns()) > 1 {
+		dest = append([]interface{}{&applied}, dest...)
+		iter.Scan(dest...)
+	} else {
+		iter.Scan(&applied)
+	}
+
+	return applied, iter, nil
 }
 
 // MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS,
 // however it accepts a map rather than a list of arguments for the initial
 // scan.
 func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) (applied bool, iter *Iter, err error) {
-	if iter, err := s.executeBatch(batch); err == nil {
-		if err := iter.checkErrAndNotFound(); err != nil {
-			return false, nil, err
-		}
-		iter.MapScan(dest)
-		applied = dest["[applied]"].(bool)
-		delete(dest, "[applied]")
-
-		// we usually close here, but instead of closing, just returin an error
-		// if MapScan failed. Although Close just returns err, using Close
-		// here might be confusing as we are not actually closing the iter
-		return applied, iter, iter.err
-	} else {
+	iter = s.executeBatch(batch)
+	if err := iter.checkErrAndNotFound(); err != nil {
+		iter.Close()
 		return false, nil, err
 	}
+	iter.MapScan(dest)
+	applied = dest["[applied]"].(bool)
+	delete(dest, "[applied]")
+
+	// we usually close here, but instead of closing, just returin an error
+	// if MapScan failed. Although Close just returns err, using Close
+	// here might be confusing as we are not actually closing the iter
+	return applied, iter, iter.err
 }
 
 func (s *Session) connect(addr string, errorHandler ConnErrorHandler) (*Conn, error) {