Просмотр исходного кода

query/batch: improve context usage (#1228)

This also fixes an issue where using WithContext mutated the original
query, it now does what http.Request.WithContext does and returns a
shallow copy of the query.

Rename GetContext() to Context() and have it default o returning
context.Background() when no context has been set.

Use context to cancel speculated queries so they dont leak or run beyond
when one wins.

Fix hiding errors when speculating queries.
Chris Bannister 7 лет назад
Родитель
Сommit
271c061c7f
9 измененных файлов с 169 добавлено и 197 удалено
  1. 16 9
      cassandra_test.go
  2. 26 25
      conn.go
  3. 8 37
      conn_test.go
  4. 3 3
      control.go
  5. 4 0
      frame.go
  6. 3 2
      host_source.go
  7. 1 1
      policies.go
  8. 58 82
      query_executor.go
  9. 50 38
      session.go

+ 16 - 9
cassandra_test.go

@@ -264,7 +264,7 @@ func TestObserve_Pagination(t *testing.T) {
 		Iter().Scanner()
 		Iter().Scanner()
 	for i := 0; i < 50; i++ {
 	for i := 0; i < 50; i++ {
 		if !scanner.Next() {
 		if !scanner.Next() {
-			t.Fatalf("next: should still be true: %d", i)
+			t.Fatalf("next: should still be true: %d: %v", i, scanner.Err())
 		}
 		}
 		if i%10 == 0 {
 		if i%10 == 0 {
 			if observedRows != 10 {
 			if observedRows != 10 {
@@ -1354,15 +1354,15 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 }
 }
 
 
 func TestPrepare_MissingSchemaPrepare(t *testing.T) {
 func TestPrepare_MissingSchemaPrepare(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	s := createSession(t)
 	s := createSession(t)
 	conn := getRandomConn(t, s)
 	conn := getRandomConn(t, s)
 	defer s.Close()
 	defer s.Close()
 
 
-	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
-		session: s, pageSize: s.pageSize, trace: s.trace,
-		prefetch: s.prefetch, rt: s.cfg.RetryPolicy}
-
-	if err := conn.executeQuery(insertQry).err; err == nil {
+	insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5)
+	if err := conn.executeQuery(ctx, insertQry).err; err == nil {
 		t.Fatal("expected error, but got nil.")
 		t.Fatal("expected error, but got nil.")
 	}
 	}
 
 
@@ -1370,22 +1370,29 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) {
 		t.Fatal("create table:", err)
 		t.Fatal("create table:", err)
 	}
 	}
 
 
-	if err := conn.executeQuery(insertQry).err; err != nil {
+	if err := conn.executeQuery(ctx, insertQry).err; err != nil {
 		t.Fatal(err) // unconfigured columnfamily
 		t.Fatal(err) // unconfigured columnfamily
 	}
 	}
 }
 }
 
 
 func TestPrepare_ReprepareStatement(t *testing.T) {
 func TestPrepare_ReprepareStatement(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()
+
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement")
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement")
 	query := session.Query(stmt, "bar")
 	query := session.Query(stmt, "bar")
-	if err := conn.executeQuery(query).Close(); err != nil {
+	if err := conn.executeQuery(ctx, query).Close(); err != nil {
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 	}
 	}
 }
 }
 
 
 func TestPrepare_ReprepareBatch(t *testing.T) {
 func TestPrepare_ReprepareBatch(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()
 
 
@@ -1396,7 +1403,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
 	batch := session.NewBatch(UnloggedBatch)
 	batch := session.NewBatch(UnloggedBatch)
 	batch.Query(stmt, "bar")
 	batch.Query(stmt, "bar")
-	if err := conn.executeBatch(batch).Close(); err != nil {
+	if err := conn.executeBatch(ctx, batch).Close(); err != nil {
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 	}
 	}
 }
 }

+ 26 - 25
conn.go

@@ -964,7 +964,7 @@ func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error
 	return nil
 	return nil
 }
 }
 
 
-func (c *Conn) executeQuery(qry *Query) *Iter {
+func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
 	params := queryParams{
 	params := queryParams{
 		consistency: qry.cons,
 		consistency: qry.cons,
 	}
 	}
@@ -992,7 +992,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	if qry.shouldPrepare() {
 	if qry.shouldPrepare() {
 		// Prepare all DML queries. Other queries can not be prepared.
 		// Prepare all DML queries. Other queries can not be prepared.
 		var err error
 		var err error
-		info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace)
+		info, err = c.prepareStatement(ctx, qry.stmt, qry.trace)
 		if err != nil {
 		if err != nil {
 			return &Iter{err: err}
 			return &Iter{err: err}
 		}
 		}
@@ -1043,7 +1043,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 		}
 	}
 	}
 
 
-	framer, err := c.exec(qry.context, frame, qry.trace)
+	framer, err := c.exec(ctx, frame, qry.trace)
 	if err != nil {
 	if err != nil {
 		return &Iter{err: err}
 		return &Iter{err: err}
 	}
 	}
@@ -1070,7 +1070,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		if params.skipMeta {
 		if params.skipMeta {
 			if info != nil {
 			if info != nil {
 				iter.meta = info.response
 				iter.meta = info.response
-				iter.meta.pagingState = x.meta.pagingState
+				iter.meta.pagingState = copyBytes(x.meta.pagingState)
 			} else {
 			} else {
 				return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
 				return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
 			}
 			}
@@ -1078,11 +1078,10 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 			iter.meta = x.meta
 			iter.meta = x.meta
 		}
 		}
 
 
-		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
+		if x.meta.morePages() && !qry.disableAutoPage {
 			iter.next = &nextIter{
 			iter.next = &nextIter{
-				qry:  *qry,
-				pos:  int((1 - qry.prefetch) * float64(x.numRows)),
-				conn: c,
+				qry: qry,
+				pos: int((1 - qry.prefetch) * float64(x.numRows)),
 			}
 			}
 
 
 			iter.next.qry.pageState = copyBytes(x.meta.pagingState)
 			iter.next.qry.pageState = copyBytes(x.meta.pagingState)
@@ -1096,7 +1095,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		return &Iter{framer: framer}
 		return &Iter{framer: framer}
 	case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType:
 	case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType:
 		iter := &Iter{framer: framer}
 		iter := &Iter{framer: framer}
-		if err := c.awaitSchemaAgreement(); err != nil {
+		if err := c.awaitSchemaAgreement(ctx); err != nil {
 			// TODO: should have this behind a flag
 			// TODO: should have this behind a flag
 			Logger.Println(err)
 			Logger.Println(err)
 		}
 		}
@@ -1107,7 +1106,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case *RequestErrUnprepared:
 	case *RequestErrUnprepared:
 		stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
 		stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
 		if c.session.stmtsLRU.remove(stmtCacheKey) {
 		if c.session.stmtsLRU.remove(stmtCacheKey) {
-			return c.executeQuery(qry)
+			return c.executeQuery(ctx, qry)
 		}
 		}
 
 
 		return &Iter{err: x, framer: framer}
 		return &Iter{err: x, framer: framer}
@@ -1167,7 +1166,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	return nil
 	return nil
 }
 }
 
 
-func (c *Conn) executeBatch(batch *Batch) *Iter {
+func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
 	if c.version == protoVersion1 {
 	if c.version == protoVersion1 {
 		return &Iter{err: ErrUnsupported}
 		return &Iter{err: ErrUnsupported}
 	}
 	}
@@ -1190,7 +1189,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 		b := &req.statements[i]
 		b := &req.statements[i]
 
 
 		if len(entry.Args) > 0 || entry.binding != nil {
 		if len(entry.Args) > 0 || entry.binding != nil {
-			info, err := c.prepareStatement(batch.context, entry.Stmt, nil)
+			info, err := c.prepareStatement(batch.Context(), entry.Stmt, nil)
 			if err != nil {
 			if err != nil {
 				return &Iter{err: err}
 				return &Iter{err: err}
 			}
 			}
@@ -1233,7 +1232,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 	}
 	}
 
 
 	// TODO: should batch support tracing?
 	// TODO: should batch support tracing?
-	framer, err := c.exec(batch.context, req, nil)
+	framer, err := c.exec(batch.Context(), req, nil)
 	if err != nil {
 	if err != nil {
 		return &Iter{err: err}
 		return &Iter{err: err}
 	}
 	}
@@ -1254,7 +1253,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 		}
 		}
 
 
 		if found {
 		if found {
-			return c.executeBatch(batch)
+			return c.executeBatch(ctx, batch)
 		} else {
 		} else {
 			return &Iter{err: x, framer: framer}
 			return &Iter{err: x, framer: framer}
 		}
 		}
@@ -1273,13 +1272,13 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 	}
 	}
 }
 }
 
 
-func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
+func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
 	q := c.session.Query(statement, values...).Consistency(One)
 	q := c.session.Query(statement, values...).Consistency(One)
 	q.trace = nil
 	q.trace = nil
-	return c.executeQuery(q)
+	return c.executeQuery(ctx, q)
 }
 }
 
 
-func (c *Conn) awaitSchemaAgreement() (err error) {
+func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
 	const (
 	const (
 		peerSchemas  = "SELECT schema_version, peer FROM system.peers"
 		peerSchemas  = "SELECT schema_version, peer FROM system.peers"
 		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
 		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
@@ -1289,7 +1288,7 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 
 
 	endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
 	endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
 	for time.Now().Before(endDeadline) {
 	for time.Now().Before(endDeadline) {
-		iter := c.query(peerSchemas)
+		iter := c.query(ctx, peerSchemas)
 
 
 		versions = make(map[string]struct{})
 		versions = make(map[string]struct{})
 
 
@@ -1309,7 +1308,7 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 			goto cont
 			goto cont
 		}
 		}
 
 
-		iter = c.query(localSchemas)
+		iter = c.query(ctx, localSchemas)
 		for iter.Scan(&schemaVersion) {
 		for iter.Scan(&schemaVersion) {
 			versions[schemaVersion] = struct{}{}
 			versions[schemaVersion] = struct{}{}
 			schemaVersion = ""
 			schemaVersion = ""
@@ -1324,11 +1323,15 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 		}
 		}
 
 
 	cont:
 	cont:
-		time.Sleep(200 * time.Millisecond)
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(200 * time.Millisecond):
+		}
 	}
 	}
 
 
 	if err != nil {
 	if err != nil {
-		return
+		return err
 	}
 	}
 
 
 	schemas := make([]string, 0, len(versions))
 	schemas := make([]string, 0, len(versions))
@@ -1340,10 +1343,8 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 }
 }
 
 
-const localHostInfo = "SELECT * FROM system.local WHERE key='local'"
-
-func (c *Conn) localHostInfo() (*HostInfo, error) {
-	row, err := c.query(localHostInfo).rowMap()
+func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
+	row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 8 - 37
conn_test.go

@@ -299,7 +299,7 @@ func TestCancel(t *testing.T) {
 	}
 	}
 	defer db.Close()
 	defer db.Close()
 
 
-	qry := db.Query("timeout")
+	qry := db.Query("timeout").WithContext(ctx)
 
 
 	// Make sure we finish the query without leftovers
 	// Make sure we finish the query without leftovers
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
@@ -313,7 +313,7 @@ func TestCancel(t *testing.T) {
 	}()
 	}()
 
 
 	// The query will timeout after about 1 seconds, so cancel it after a short pause
 	// The query will timeout after about 1 seconds, so cancel it after a short pause
-	time.AfterFunc(20*time.Millisecond, qry.Cancel)
+	time.AfterFunc(20*time.Millisecond, cancel)
 	wg.Wait()
 	wg.Wait()
 }
 }
 
 
@@ -780,41 +780,11 @@ func TestStream0(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestConnClosedBlocked(t *testing.T) {
-	t.Skip("FLAKE: skipping test flake see https://github.com/gocql/gocql/issues/1088")
-	// issue 664
-	const proto = 3
-
-	srv := NewTestServer(t, proto, context.Background())
-	defer srv.Stop()
-	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
-		t.Log(err)
-	})
-
-	s, err := srv.session()
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer s.Close()
-
-	conn, err := s.connect(srv.host(), errorHandler)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	if err := conn.conn.Close(); err != nil {
-		t.Fatal(err)
-	}
-
-	// This will block indefintaly if #664 is not fixed
-	err = conn.executeQuery(&Query{stmt: "void"}).Close()
-	if !strings.HasSuffix(err.Error(), "use of closed network connection") {
-		t.Fatalf("expected to get use of closed networking connection error got: %v\n", err)
-	}
-}
-
 func TestContext_Timeout(t *testing.T) {
 func TestContext_Timeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto, context.Background())
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 	defer srv.Stop()
 
 
 	cluster := testCluster(defaultProto, srv.Address)
 	cluster := testCluster(defaultProto, srv.Address)
@@ -825,8 +795,9 @@ func TestContext_Timeout(t *testing.T) {
 	}
 	}
 	defer db.Close()
 	defer db.Close()
 
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel = context.WithCancel(ctx)
 	cancel()
 	cancel()
+
 	err = db.Query("timeout").WithContext(ctx).Exec()
 	err = db.Query("timeout").WithContext(ctx).Exec()
 	if err != context.Canceled {
 	if err != context.Canceled {
 		t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
 		t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)

+ 3 - 3
control.go

@@ -271,7 +271,7 @@ func (c *controlConn) setupConn(conn *Conn) error {
 
 
 	// TODO(zariel): do we need to fetch host info everytime
 	// TODO(zariel): do we need to fetch host info everytime
 	// the control conn connects? Surely we have it cached?
 	// the control conn connects? Surely we have it cached?
-	host, err := conn.localHostInfo()
+	host, err := conn.localHostInfo(context.TODO())
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -446,7 +446,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 
 
 	for {
 	for {
 		iter = c.withConn(func(conn *Conn) *Iter {
 		iter = c.withConn(func(conn *Conn) *Iter {
-			return conn.executeQuery(q)
+			return conn.executeQuery(context.TODO(), q)
 		})
 		})
 
 
 		if gocqlDebug && iter.err != nil {
 		if gocqlDebug && iter.err != nil {
@@ -464,7 +464,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 
 
 func (c *controlConn) awaitSchemaAgreement() error {
 func (c *controlConn) awaitSchemaAgreement() error {
 	return c.withConn(func(conn *Conn) *Iter {
 	return c.withConn(func(conn *Conn) *Iter {
-		return &Iter{err: conn.awaitSchemaAgreement()}
+		return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
 	}).err
 	}).err
 }
 }
 
 

+ 4 - 0
frame.go

@@ -1011,6 +1011,10 @@ type resultMetadata struct {
 	actualColCount int
 	actualColCount int
 }
 }
 
 
+func (r *resultMetadata) morePages() bool {
+	return r.flags&flagHasMorePages == flagHasMorePages
+}
+
 func (r resultMetadata) String() string {
 func (r resultMetadata) String() string {
 	return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns)
 	return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns)
 }
 }

+ 3 - 2
host_source.go

@@ -1,6 +1,7 @@
 package gocql
 package gocql
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
@@ -555,7 +556,7 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
 	var hosts []*HostInfo
 	var hosts []*HostInfo
 	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
 	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
 		hosts = append(hosts, ch.host)
 		hosts = append(hosts, ch.host)
-		return ch.conn.query("SELECT * FROM system.peers")
+		return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
 	})
 	})
 
 
 	if iter == nil {
 	if iter == nil {
@@ -622,7 +623,7 @@ func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
 			return nil
 			return nil
 		}
 		}
 
 
-		return ch.conn.query("SELECT * FROM system.peers")
+		return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
 	})
 	})
 
 
 	if iter != nil {
 	if iter != nil {

+ 1 - 1
policies.go

@@ -132,7 +132,7 @@ type RetryableQuery interface {
 	Attempts() int
 	Attempts() int
 	SetConsistency(c Consistency)
 	SetConsistency(c Consistency)
 	GetConsistency() Consistency
 	GetConsistency() Consistency
-	GetContext() context.Context
+	Context() context.Context
 }
 }
 
 
 type RetryType uint16
 type RetryType uint16

+ 58 - 82
query_executor.go

@@ -1,19 +1,21 @@
 package gocql
 package gocql
 
 
 import (
 import (
-	"sync"
+	"context"
 	"time"
 	"time"
 )
 )
 
 
 type ExecutableQuery interface {
 type ExecutableQuery interface {
-	execute(conn *Conn) *Iter
+	execute(ctx context.Context, conn *Conn) *Iter
 	attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
 	attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
 	retryPolicy() RetryPolicy
 	retryPolicy() RetryPolicy
 	speculativeExecutionPolicy() SpeculativeExecutionPolicy
 	speculativeExecutionPolicy() SpeculativeExecutionPolicy
 	GetRoutingKey() ([]byte, error)
 	GetRoutingKey() ([]byte, error)
 	Keyspace() string
 	Keyspace() string
-	Cancel()
 	IsIdempotent() bool
 	IsIdempotent() bool
+
+	withContext(context.Context) ExecutableQuery
+
 	RetryableQuery
 	RetryableQuery
 }
 }
 
 
@@ -22,14 +24,9 @@ type queryExecutor struct {
 	policy HostSelectionPolicy
 	policy HostSelectionPolicy
 }
 }
 
 
-type queryResponse struct {
-	iter *Iter
-	err  error
-}
-
-func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
+func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter {
 	start := time.Now()
 	start := time.Now()
-	iter := qry.execute(conn)
+	iter := qry.execute(ctx, conn)
 	end := time.Now()
 	end := time.Now()
 
 
 	qry.attempt(q.pool.keyspace, end, start, iter, conn.host)
 	qry.attempt(q.pool.keyspace, end, start, iter, conn.host)
@@ -38,7 +35,6 @@ func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
 }
 }
 
 
 func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
-
 	// check if the query is not marked as idempotent, if
 	// check if the query is not marked as idempotent, if
 	// it is, we force the policy to NonSpeculative
 	// it is, we force the policy to NonSpeculative
 	sp := qry.speculativeExecutionPolicy()
 	sp := qry.speculativeExecutionPolicy()
@@ -46,27 +42,18 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 		sp = NonSpeculativeExecution{}
 		sp = NonSpeculativeExecution{}
 	}
 	}
 
 
-	results := make(chan queryResponse, 1)
-	stop := make(chan struct{})
-	defer close(stop)
-	var specWG sync.WaitGroup
+	ctx, cancel := context.WithCancel(qry.Context())
+	defer cancel()
+
+	results := make(chan *Iter, 1)
 
 
 	// Launch the main execution
 	// Launch the main execution
-	specWG.Add(1)
-	go q.run(qry, &specWG, results, stop)
+	go q.run(ctx, qry, results)
 
 
 	// The speculative executions are launched _in addition_ to the main
 	// The speculative executions are launched _in addition_ to the main
 	// execution, on a timer. So Speculation{2} would make 3 executions running
 	// execution, on a timer. So Speculation{2} would make 3 executions running
 	// in total.
 	// in total.
 	go func() {
 	go func() {
-		// Handle the closing of the resources. We do it here because it's
-		// right after we finish launching executions. Otherwise clearing the
-		// wait group is complicated.
-		defer func() {
-			specWG.Wait()
-			close(results)
-		}()
-
 		// setup a ticker
 		// setup a ticker
 		ticker := time.NewTicker(sp.Delay())
 		ticker := time.NewTicker(sp.Delay())
 		defer ticker.Stop()
 		defer ticker.Stop()
@@ -75,34 +62,28 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			select {
 			select {
 			case <-ticker.C:
 			case <-ticker.C:
 				// Launch the additional execution
 				// Launch the additional execution
-				specWG.Add(1)
-				go q.run(qry, &specWG, results, stop)
-			case <-qry.GetContext().Done():
-				// not starting additional executions
-				return
-			case <-stop:
+				go q.run(ctx, qry, results)
+			case <-ctx.Done():
 				// not starting additional executions
 				// not starting additional executions
 				return
 				return
 			}
 			}
 		}
 		}
 	}()
 	}()
 
 
-	res := <-results
-	if res.iter == nil && res.err == nil {
-		// if we're here, the results channel was closed, so no more hosts
-		return nil, ErrNoConnections
+	select {
+	case iter := <-results:
+		return iter, nil
+	case <-ctx.Done():
+		return &Iter{err: ctx.Err()}, nil
 	}
 	}
-	return res.iter, res.err
 }
 }
 
 
-func (q *queryExecutor) run(qry ExecutableQuery, specWG *sync.WaitGroup, results chan queryResponse, stop chan struct{}) {
-	// Handle the wait group
-	defer specWG.Done()
-
+func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
 	hostIter := q.policy.Pick(qry)
 	hostIter := q.policy.Pick(qry)
 	selectedHost := hostIter()
 	selectedHost := hostIter()
 	rt := qry.retryPolicy()
 	rt := qry.retryPolicy()
 
 
+	var lastErr error
 	var iter *Iter
 	var iter *Iter
 	for selectedHost != nil {
 	for selectedHost != nil {
 		host := selectedHost.Info()
 		host := selectedHost.Info()
@@ -123,50 +104,45 @@ func (q *queryExecutor) run(qry ExecutableQuery, specWG *sync.WaitGroup, results
 			continue
 			continue
 		}
 		}
 
 
-		select {
-		case <-stop:
-			// stop this execution and return
-			return
-		default:
-			// Run the query
-			iter = q.attemptQuery(qry, conn)
-			iter.host = selectedHost.Info()
-			// Update host
-			selectedHost.Mark(iter.err)
-
-			// Exit if the query was successful
-			// or no retry policy defined or retry attempts were reached
-			if iter.err == nil || rt == nil || !rt.Attempt(qry) {
-				results <- queryResponse{iter: iter}
-				return
-			}
+		iter = q.attemptQuery(ctx, qry, conn)
+		iter.host = selectedHost.Info()
+		// Update host
+		selectedHost.Mark(iter.err)
 
 
-			// If query is unsuccessful, check the error with RetryPolicy to retry
-			switch rt.GetRetryType(iter.err) {
-			case Retry:
-				// retry on the same host
-				continue
-			case Rethrow:
-				results <- queryResponse{err: iter.err}
-				return
-			case Ignore:
-				results <- queryResponse{iter: iter}
-				return
-			case RetryNextHost:
-				// retry on the next host
-				selectedHost = hostIter()
-				if selectedHost == nil {
-					results <- queryResponse{iter: iter}
-					return
-				}
-				continue
-			default:
-				// Undefined? Return nil and error, this will panic in the requester
-				results <- queryResponse{iter: nil, err: ErrUnknownRetryType}
-				return
-			}
+		// Exit if the query was successful
+		// or no retry policy defined or retry attempts were reached
+		if iter.err == nil || rt == nil || !rt.Attempt(qry) {
+			return iter
+		}
+		lastErr = iter.err
+
+		// If query is unsuccessful, check the error with RetryPolicy to retry
+		switch rt.GetRetryType(iter.err) {
+		case Retry:
+			// retry on the same host
+			continue
+		case Rethrow, Ignore:
+			return iter
+		case RetryNextHost:
+			// retry on the next host
+			selectedHost = hostIter()
+			continue
+		default:
+			// Undefined? Return nil and error, this will panic in the requester
+			return &Iter{err: ErrUnknownRetryType}
 		}
 		}
+	}
+
+	if lastErr != nil {
+		return &Iter{err: lastErr}
+	}
+
+	return &Iter{err: ErrNoConnections}
+}
 
 
+func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan *Iter) {
+	select {
+	case results <- q.do(ctx, qry):
+	case <-ctx.Done():
 	}
 	}
-	// All hosts are exhausted, return nothing
 }
 }

+ 50 - 38
session.go

@@ -570,8 +570,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
 	return routingKeyInfo, nil
 	return routingKeyInfo, nil
 }
 }
 
 
-func (b *Batch) execute(conn *Conn) *Iter {
-	return conn.executeBatch(b)
+func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
+	return conn.executeBatch(ctx, b)
 }
 }
 
 
 func (s *Session) executeBatch(batch *Batch) *Iter {
 func (s *Session) executeBatch(batch *Batch) *Iter {
@@ -689,7 +689,6 @@ type Query struct {
 	defaultTimestampValue int64
 	defaultTimestampValue int64
 	disableSkipMetadata   bool
 	disableSkipMetadata   bool
 	context               context.Context
 	context               context.Context
-	cancelQuery           func()
 	idempotent            bool
 	idempotent            bool
 	customPayload         map[string][]byte
 	customPayload         map[string][]byte
 	metrics               *queryMetrics
 	metrics               *queryMetrics
@@ -712,9 +711,6 @@ func (q *Query) defaultsFromSession() {
 	q.idempotent = s.cfg.DefaultIdempotence
 	q.idempotent = s.cfg.DefaultIdempotence
 	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 
 
-	// Initiate an empty context with a cancel call
-	q.WithContext(context.Background())
-
 	q.spec = &NonSpeculativeExecution{}
 	q.spec = &NonSpeculativeExecution{}
 	s.mu.RUnlock()
 	s.mu.RUnlock()
 }
 }
@@ -808,7 +804,10 @@ func (q *Query) CustomPayload(customPayload map[string][]byte) *Query {
 	return q
 	return q
 }
 }
 
 
-func (q *Query) GetContext() context.Context {
+func (q *Query) Context() context.Context {
+	if q.context == nil {
+		return context.Background()
+	}
 	return q.context
 	return q.context
 }
 }
 
 
@@ -865,20 +864,30 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 	return q
 	return q
 }
 }
 
 
-// WithContext will set the context to use during a query, it will be used to
-// timeout when waiting for responses from Cassandra. Additionally it adds
-// the cancel function so that it can be called whenever necessary.
+func (q *Query) withContext(ctx context.Context) ExecutableQuery {
+	// I really wish go had covariant types
+	return q.WithContext(ctx)
+}
+
+// WithContext returns a shallow copy of q with its context
+// set to ctx.
+//
+// The provided context controls the entire lifetime of executing a
+// query, queries will be canceled and return once the context is
+// canceled.
 func (q *Query) WithContext(ctx context.Context) *Query {
 func (q *Query) WithContext(ctx context.Context) *Query {
-	q.context, q.cancelQuery = context.WithCancel(ctx)
-	return q
+	q2 := *q
+	q2.context = ctx
+	return &q2
 }
 }
 
 
+// Deprecate: does nothing, cancel the context passed to WithContext
 func (q *Query) Cancel() {
 func (q *Query) Cancel() {
-	q.cancelQuery()
+	// TODO: delete
 }
 }
 
 
-func (q *Query) execute(conn *Conn) *Iter {
-	return conn.executeQuery(q)
+func (q *Query) execute(ctx context.Context, conn *Conn) *Iter {
+	return conn.executeQuery(ctx, q)
 }
 }
 
 
 func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
 func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
@@ -886,7 +895,7 @@ func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host
 	q.AddLatency(end.Sub(start).Nanoseconds(), host)
 	q.AddLatency(end.Sub(start).Nanoseconds(), host)
 
 
 	if q.observer != nil {
 	if q.observer != nil {
-		q.observer.ObserveQuery(q.context, ObservedQuery{
+		q.observer.ObserveQuery(q.Context(), ObservedQuery{
 			Keyspace:  keyspace,
 			Keyspace:  keyspace,
 			Statement: q.stmt,
 			Statement: q.stmt,
 			Start:     start,
 			Start:     start,
@@ -930,7 +939,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
 	}
 	}
 
 
 	// try to determine the routing key
 	// try to determine the routing key
-	routingKeyInfo, err := q.session.routingKeyInfo(q.context, q.stmt)
+	routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -1351,7 +1360,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 		return false
 		return false
 	}
 	}
 
 
-	if iter.next != nil && iter.pos == iter.next.pos {
+	if iter.next != nil && iter.pos >= iter.next.pos {
 		go iter.next.fetch()
 		go iter.next.fetch()
 	}
 	}
 
 
@@ -1447,21 +1456,15 @@ func (iter *Iter) NumRows() int {
 }
 }
 
 
 type nextIter struct {
 type nextIter struct {
-	qry  Query
+	qry  *Query
 	pos  int
 	pos  int
 	once sync.Once
 	once sync.Once
 	next *Iter
 	next *Iter
-	conn *Conn
 }
 }
 
 
 func (n *nextIter) fetch() *Iter {
 func (n *nextIter) fetch() *Iter {
 	n.once.Do(func() {
 	n.once.Do(func() {
-		iter := n.qry.session.executor.attemptQuery(&n.qry, n.conn)
-		if iter != nil && iter.err == nil {
-			n.next = iter
-		} else {
-			n.next = n.qry.session.executeQuery(&n.qry)
-		}
+		n.next = n.qry.session.executeQuery(n.qry)
 	})
 	})
 	return n.next
 	return n.next
 }
 }
@@ -1509,9 +1512,6 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 		spec:             &NonSpeculativeExecution{},
 		spec:             &NonSpeculativeExecution{},
 	}
 	}
 
 
-	// Initiate an empty context with a cancel call
-	batch.WithContext(context.Background())
-
 	s.mu.RUnlock()
 	s.mu.RUnlock()
 	return batch
 	return batch
 }
 }
@@ -1597,7 +1597,10 @@ func (b *Batch) SetConsistency(c Consistency) {
 	b.Cons = c
 	b.Cons = c
 }
 }
 
 
-func (b *Batch) GetContext() context.Context {
+func (b *Batch) Context() context.Context {
+	if b.context == nil {
+		return context.Background()
+	}
 	return b.context
 	return b.context
 }
 }
 
 
@@ -1641,16 +1644,25 @@ func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	return b
 	return b
 }
 }
 
 
-// WithContext will set the context to use during a query, it will be used to
-// timeout when waiting for responses from Cassandra. Additionally it adds
-// the cancel function so that it can be called whenever necessary.
+func (b *Batch) withContext(ctx context.Context) ExecutableQuery {
+	return b.WithContext(ctx)
+}
+
+// WithContext returns a shallow copy of b with its context
+// set to ctx.
+//
+// The provided context controls the entire lifetime of executing a
+// query, queries will be canceled and return once the context is
+// canceled.
 func (b *Batch) WithContext(ctx context.Context) *Batch {
 func (b *Batch) WithContext(ctx context.Context) *Batch {
-	b.context, b.cancelBatch = context.WithCancel(ctx)
-	return b
+	b2 := *b
+	b2.context = ctx
+	return &b2
 }
 }
 
 
-func (b *Batch) Cancel() {
-	b.cancelBatch()
+// Deprecate: does nothing, cancel the context passed to WithContext
+func (*Batch) Cancel() {
+	// TODO: delete
 }
 }
 
 
 // Size returns the number of batch statements to be executed by the batch operation.
 // Size returns the number of batch statements to be executed by the batch operation.
@@ -1706,7 +1718,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host
 		statements[i] = entry.Stmt
 		statements[i] = entry.Stmt
 	}
 	}
 
 
-	b.observer.ObserveBatch(b.context, ObservedBatch{
+	b.observer.ObserveBatch(b.Context(), ObservedBatch{
 		Keyspace:   keyspace,
 		Keyspace:   keyspace,
 		Statements: statements,
 		Statements: statements,
 		Start:      start,
 		Start:      start,