Browse Source

query/batch: improve context usage (#1228)

This also fixes an issue where using WithContext mutated the original
query, it now does what http.Request.WithContext does and returns a
shallow copy of the query.

Rename GetContext() to Context() and have it default o returning
context.Background() when no context has been set.

Use context to cancel speculated queries so they dont leak or run beyond
when one wins.

Fix hiding errors when speculating queries.
Chris Bannister 7 years ago
parent
commit
271c061c7f
9 changed files with 169 additions and 197 deletions
  1. 16 9
      cassandra_test.go
  2. 26 25
      conn.go
  3. 8 37
      conn_test.go
  4. 3 3
      control.go
  5. 4 0
      frame.go
  6. 3 2
      host_source.go
  7. 1 1
      policies.go
  8. 58 82
      query_executor.go
  9. 50 38
      session.go

+ 16 - 9
cassandra_test.go

@@ -264,7 +264,7 @@ func TestObserve_Pagination(t *testing.T) {
 		Iter().Scanner()
 	for i := 0; i < 50; i++ {
 		if !scanner.Next() {
-			t.Fatalf("next: should still be true: %d", i)
+			t.Fatalf("next: should still be true: %d: %v", i, scanner.Err())
 		}
 		if i%10 == 0 {
 			if observedRows != 10 {
@@ -1354,15 +1354,15 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string
 }
 
 func TestPrepare_MissingSchemaPrepare(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	s := createSession(t)
 	conn := getRandomConn(t, s)
 	defer s.Close()
 
-	insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons,
-		session: s, pageSize: s.pageSize, trace: s.trace,
-		prefetch: s.prefetch, rt: s.cfg.RetryPolicy}
-
-	if err := conn.executeQuery(insertQry).err; err == nil {
+	insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5)
+	if err := conn.executeQuery(ctx, insertQry).err; err == nil {
 		t.Fatal("expected error, but got nil.")
 	}
 
@@ -1370,22 +1370,29 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) {
 		t.Fatal("create table:", err)
 	}
 
-	if err := conn.executeQuery(insertQry).err; err != nil {
+	if err := conn.executeQuery(ctx, insertQry).err; err != nil {
 		t.Fatal(err) // unconfigured columnfamily
 	}
 }
 
 func TestPrepare_ReprepareStatement(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	session := createSession(t)
 	defer session.Close()
+
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement")
 	query := session.Query(stmt, "bar")
-	if err := conn.executeQuery(query).Close(); err != nil {
+	if err := conn.executeQuery(ctx, query).Close(); err != nil {
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 	}
 }
 
 func TestPrepare_ReprepareBatch(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	session := createSession(t)
 	defer session.Close()
 
@@ -1396,7 +1403,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
 	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
 	batch := session.NewBatch(UnloggedBatch)
 	batch.Query(stmt, "bar")
-	if err := conn.executeBatch(batch).Close(); err != nil {
+	if err := conn.executeBatch(ctx, batch).Close(); err != nil {
 		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
 	}
 }

+ 26 - 25
conn.go

@@ -964,7 +964,7 @@ func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error
 	return nil
 }
 
-func (c *Conn) executeQuery(qry *Query) *Iter {
+func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
 	params := queryParams{
 		consistency: qry.cons,
 	}
@@ -992,7 +992,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	if qry.shouldPrepare() {
 		// Prepare all DML queries. Other queries can not be prepared.
 		var err error
-		info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace)
+		info, err = c.prepareStatement(ctx, qry.stmt, qry.trace)
 		if err != nil {
 			return &Iter{err: err}
 		}
@@ -1043,7 +1043,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 	}
 
-	framer, err := c.exec(qry.context, frame, qry.trace)
+	framer, err := c.exec(ctx, frame, qry.trace)
 	if err != nil {
 		return &Iter{err: err}
 	}
@@ -1070,7 +1070,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		if params.skipMeta {
 			if info != nil {
 				iter.meta = info.response
-				iter.meta.pagingState = x.meta.pagingState
+				iter.meta.pagingState = copyBytes(x.meta.pagingState)
 			} else {
 				return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
 			}
@@ -1078,11 +1078,10 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 			iter.meta = x.meta
 		}
 
-		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
+		if x.meta.morePages() && !qry.disableAutoPage {
 			iter.next = &nextIter{
-				qry:  *qry,
-				pos:  int((1 - qry.prefetch) * float64(x.numRows)),
-				conn: c,
+				qry: qry,
+				pos: int((1 - qry.prefetch) * float64(x.numRows)),
 			}
 
 			iter.next.qry.pageState = copyBytes(x.meta.pagingState)
@@ -1096,7 +1095,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		return &Iter{framer: framer}
 	case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType:
 		iter := &Iter{framer: framer}
-		if err := c.awaitSchemaAgreement(); err != nil {
+		if err := c.awaitSchemaAgreement(ctx); err != nil {
 			// TODO: should have this behind a flag
 			Logger.Println(err)
 		}
@@ -1107,7 +1106,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	case *RequestErrUnprepared:
 		stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
 		if c.session.stmtsLRU.remove(stmtCacheKey) {
-			return c.executeQuery(qry)
+			return c.executeQuery(ctx, qry)
 		}
 
 		return &Iter{err: x, framer: framer}
@@ -1167,7 +1166,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	return nil
 }
 
-func (c *Conn) executeBatch(batch *Batch) *Iter {
+func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
 	if c.version == protoVersion1 {
 		return &Iter{err: ErrUnsupported}
 	}
@@ -1190,7 +1189,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 		b := &req.statements[i]
 
 		if len(entry.Args) > 0 || entry.binding != nil {
-			info, err := c.prepareStatement(batch.context, entry.Stmt, nil)
+			info, err := c.prepareStatement(batch.Context(), entry.Stmt, nil)
 			if err != nil {
 				return &Iter{err: err}
 			}
@@ -1233,7 +1232,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 	}
 
 	// TODO: should batch support tracing?
-	framer, err := c.exec(batch.context, req, nil)
+	framer, err := c.exec(batch.Context(), req, nil)
 	if err != nil {
 		return &Iter{err: err}
 	}
@@ -1254,7 +1253,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 		}
 
 		if found {
-			return c.executeBatch(batch)
+			return c.executeBatch(ctx, batch)
 		} else {
 			return &Iter{err: x, framer: framer}
 		}
@@ -1273,13 +1272,13 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 	}
 }
 
-func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
+func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
 	q := c.session.Query(statement, values...).Consistency(One)
 	q.trace = nil
-	return c.executeQuery(q)
+	return c.executeQuery(ctx, q)
 }
 
-func (c *Conn) awaitSchemaAgreement() (err error) {
+func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
 	const (
 		peerSchemas  = "SELECT schema_version, peer FROM system.peers"
 		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
@@ -1289,7 +1288,7 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 
 	endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
 	for time.Now().Before(endDeadline) {
-		iter := c.query(peerSchemas)
+		iter := c.query(ctx, peerSchemas)
 
 		versions = make(map[string]struct{})
 
@@ -1309,7 +1308,7 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 			goto cont
 		}
 
-		iter = c.query(localSchemas)
+		iter = c.query(ctx, localSchemas)
 		for iter.Scan(&schemaVersion) {
 			versions[schemaVersion] = struct{}{}
 			schemaVersion = ""
@@ -1324,11 +1323,15 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 		}
 
 	cont:
-		time.Sleep(200 * time.Millisecond)
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(200 * time.Millisecond):
+		}
 	}
 
 	if err != nil {
-		return
+		return err
 	}
 
 	schemas := make([]string, 0, len(versions))
@@ -1340,10 +1343,8 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
 }
 
-const localHostInfo = "SELECT * FROM system.local WHERE key='local'"
-
-func (c *Conn) localHostInfo() (*HostInfo, error) {
-	row, err := c.query(localHostInfo).rowMap()
+func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
+	row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap()
 	if err != nil {
 		return nil, err
 	}

+ 8 - 37
conn_test.go

@@ -299,7 +299,7 @@ func TestCancel(t *testing.T) {
 	}
 	defer db.Close()
 
-	qry := db.Query("timeout")
+	qry := db.Query("timeout").WithContext(ctx)
 
 	// Make sure we finish the query without leftovers
 	var wg sync.WaitGroup
@@ -313,7 +313,7 @@ func TestCancel(t *testing.T) {
 	}()
 
 	// The query will timeout after about 1 seconds, so cancel it after a short pause
-	time.AfterFunc(20*time.Millisecond, qry.Cancel)
+	time.AfterFunc(20*time.Millisecond, cancel)
 	wg.Wait()
 }
 
@@ -780,41 +780,11 @@ func TestStream0(t *testing.T) {
 	}
 }
 
-func TestConnClosedBlocked(t *testing.T) {
-	t.Skip("FLAKE: skipping test flake see https://github.com/gocql/gocql/issues/1088")
-	// issue 664
-	const proto = 3
-
-	srv := NewTestServer(t, proto, context.Background())
-	defer srv.Stop()
-	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
-		t.Log(err)
-	})
-
-	s, err := srv.session()
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer s.Close()
-
-	conn, err := s.connect(srv.host(), errorHandler)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	if err := conn.conn.Close(); err != nil {
-		t.Fatal(err)
-	}
-
-	// This will block indefintaly if #664 is not fixed
-	err = conn.executeQuery(&Query{stmt: "void"}).Close()
-	if !strings.HasSuffix(err.Error(), "use of closed network connection") {
-		t.Fatalf("expected to get use of closed networking connection error got: %v\n", err)
-	}
-}
-
 func TestContext_Timeout(t *testing.T) {
-	srv := NewTestServer(t, defaultProto, context.Background())
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
 	defer srv.Stop()
 
 	cluster := testCluster(defaultProto, srv.Address)
@@ -825,8 +795,9 @@ func TestContext_Timeout(t *testing.T) {
 	}
 	defer db.Close()
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel = context.WithCancel(ctx)
 	cancel()
+
 	err = db.Query("timeout").WithContext(ctx).Exec()
 	if err != context.Canceled {
 		t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)

+ 3 - 3
control.go

@@ -271,7 +271,7 @@ func (c *controlConn) setupConn(conn *Conn) error {
 
 	// TODO(zariel): do we need to fetch host info everytime
 	// the control conn connects? Surely we have it cached?
-	host, err := conn.localHostInfo()
+	host, err := conn.localHostInfo(context.TODO())
 	if err != nil {
 		return err
 	}
@@ -446,7 +446,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 
 	for {
 		iter = c.withConn(func(conn *Conn) *Iter {
-			return conn.executeQuery(q)
+			return conn.executeQuery(context.TODO(), q)
 		})
 
 		if gocqlDebug && iter.err != nil {
@@ -464,7 +464,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
 
 func (c *controlConn) awaitSchemaAgreement() error {
 	return c.withConn(func(conn *Conn) *Iter {
-		return &Iter{err: conn.awaitSchemaAgreement()}
+		return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
 	}).err
 }
 

+ 4 - 0
frame.go

@@ -1011,6 +1011,10 @@ type resultMetadata struct {
 	actualColCount int
 }
 
+func (r *resultMetadata) morePages() bool {
+	return r.flags&flagHasMorePages == flagHasMorePages
+}
+
 func (r resultMetadata) String() string {
 	return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns)
 }

+ 3 - 2
host_source.go

@@ -1,6 +1,7 @@
 package gocql
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"net"
@@ -555,7 +556,7 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
 	var hosts []*HostInfo
 	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
 		hosts = append(hosts, ch.host)
-		return ch.conn.query("SELECT * FROM system.peers")
+		return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
 	})
 
 	if iter == nil {
@@ -622,7 +623,7 @@ func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
 			return nil
 		}
 
-		return ch.conn.query("SELECT * FROM system.peers")
+		return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
 	})
 
 	if iter != nil {

+ 1 - 1
policies.go

@@ -132,7 +132,7 @@ type RetryableQuery interface {
 	Attempts() int
 	SetConsistency(c Consistency)
 	GetConsistency() Consistency
-	GetContext() context.Context
+	Context() context.Context
 }
 
 type RetryType uint16

+ 58 - 82
query_executor.go

@@ -1,19 +1,21 @@
 package gocql
 
 import (
-	"sync"
+	"context"
 	"time"
 )
 
 type ExecutableQuery interface {
-	execute(conn *Conn) *Iter
+	execute(ctx context.Context, conn *Conn) *Iter
 	attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
 	retryPolicy() RetryPolicy
 	speculativeExecutionPolicy() SpeculativeExecutionPolicy
 	GetRoutingKey() ([]byte, error)
 	Keyspace() string
-	Cancel()
 	IsIdempotent() bool
+
+	withContext(context.Context) ExecutableQuery
+
 	RetryableQuery
 }
 
@@ -22,14 +24,9 @@ type queryExecutor struct {
 	policy HostSelectionPolicy
 }
 
-type queryResponse struct {
-	iter *Iter
-	err  error
-}
-
-func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
+func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter {
 	start := time.Now()
-	iter := qry.execute(conn)
+	iter := qry.execute(ctx, conn)
 	end := time.Now()
 
 	qry.attempt(q.pool.keyspace, end, start, iter, conn.host)
@@ -38,7 +35,6 @@ func (q *queryExecutor) attemptQuery(qry ExecutableQuery, conn *Conn) *Iter {
 }
 
 func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
-
 	// check if the query is not marked as idempotent, if
 	// it is, we force the policy to NonSpeculative
 	sp := qry.speculativeExecutionPolicy()
@@ -46,27 +42,18 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 		sp = NonSpeculativeExecution{}
 	}
 
-	results := make(chan queryResponse, 1)
-	stop := make(chan struct{})
-	defer close(stop)
-	var specWG sync.WaitGroup
+	ctx, cancel := context.WithCancel(qry.Context())
+	defer cancel()
+
+	results := make(chan *Iter, 1)
 
 	// Launch the main execution
-	specWG.Add(1)
-	go q.run(qry, &specWG, results, stop)
+	go q.run(ctx, qry, results)
 
 	// The speculative executions are launched _in addition_ to the main
 	// execution, on a timer. So Speculation{2} would make 3 executions running
 	// in total.
 	go func() {
-		// Handle the closing of the resources. We do it here because it's
-		// right after we finish launching executions. Otherwise clearing the
-		// wait group is complicated.
-		defer func() {
-			specWG.Wait()
-			close(results)
-		}()
-
 		// setup a ticker
 		ticker := time.NewTicker(sp.Delay())
 		defer ticker.Stop()
@@ -75,34 +62,28 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
 			select {
 			case <-ticker.C:
 				// Launch the additional execution
-				specWG.Add(1)
-				go q.run(qry, &specWG, results, stop)
-			case <-qry.GetContext().Done():
-				// not starting additional executions
-				return
-			case <-stop:
+				go q.run(ctx, qry, results)
+			case <-ctx.Done():
 				// not starting additional executions
 				return
 			}
 		}
 	}()
 
-	res := <-results
-	if res.iter == nil && res.err == nil {
-		// if we're here, the results channel was closed, so no more hosts
-		return nil, ErrNoConnections
+	select {
+	case iter := <-results:
+		return iter, nil
+	case <-ctx.Done():
+		return &Iter{err: ctx.Err()}, nil
 	}
-	return res.iter, res.err
 }
 
-func (q *queryExecutor) run(qry ExecutableQuery, specWG *sync.WaitGroup, results chan queryResponse, stop chan struct{}) {
-	// Handle the wait group
-	defer specWG.Done()
-
+func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
 	hostIter := q.policy.Pick(qry)
 	selectedHost := hostIter()
 	rt := qry.retryPolicy()
 
+	var lastErr error
 	var iter *Iter
 	for selectedHost != nil {
 		host := selectedHost.Info()
@@ -123,50 +104,45 @@ func (q *queryExecutor) run(qry ExecutableQuery, specWG *sync.WaitGroup, results
 			continue
 		}
 
-		select {
-		case <-stop:
-			// stop this execution and return
-			return
-		default:
-			// Run the query
-			iter = q.attemptQuery(qry, conn)
-			iter.host = selectedHost.Info()
-			// Update host
-			selectedHost.Mark(iter.err)
-
-			// Exit if the query was successful
-			// or no retry policy defined or retry attempts were reached
-			if iter.err == nil || rt == nil || !rt.Attempt(qry) {
-				results <- queryResponse{iter: iter}
-				return
-			}
+		iter = q.attemptQuery(ctx, qry, conn)
+		iter.host = selectedHost.Info()
+		// Update host
+		selectedHost.Mark(iter.err)
 
-			// If query is unsuccessful, check the error with RetryPolicy to retry
-			switch rt.GetRetryType(iter.err) {
-			case Retry:
-				// retry on the same host
-				continue
-			case Rethrow:
-				results <- queryResponse{err: iter.err}
-				return
-			case Ignore:
-				results <- queryResponse{iter: iter}
-				return
-			case RetryNextHost:
-				// retry on the next host
-				selectedHost = hostIter()
-				if selectedHost == nil {
-					results <- queryResponse{iter: iter}
-					return
-				}
-				continue
-			default:
-				// Undefined? Return nil and error, this will panic in the requester
-				results <- queryResponse{iter: nil, err: ErrUnknownRetryType}
-				return
-			}
+		// Exit if the query was successful
+		// or no retry policy defined or retry attempts were reached
+		if iter.err == nil || rt == nil || !rt.Attempt(qry) {
+			return iter
+		}
+		lastErr = iter.err
+
+		// If query is unsuccessful, check the error with RetryPolicy to retry
+		switch rt.GetRetryType(iter.err) {
+		case Retry:
+			// retry on the same host
+			continue
+		case Rethrow, Ignore:
+			return iter
+		case RetryNextHost:
+			// retry on the next host
+			selectedHost = hostIter()
+			continue
+		default:
+			// Undefined? Return nil and error, this will panic in the requester
+			return &Iter{err: ErrUnknownRetryType}
 		}
+	}
+
+	if lastErr != nil {
+		return &Iter{err: lastErr}
+	}
+
+	return &Iter{err: ErrNoConnections}
+}
 
+func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan *Iter) {
+	select {
+	case results <- q.do(ctx, qry):
+	case <-ctx.Done():
 	}
-	// All hosts are exhausted, return nothing
 }

+ 50 - 38
session.go

@@ -570,8 +570,8 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
 	return routingKeyInfo, nil
 }
 
-func (b *Batch) execute(conn *Conn) *Iter {
-	return conn.executeBatch(b)
+func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
+	return conn.executeBatch(ctx, b)
 }
 
 func (s *Session) executeBatch(batch *Batch) *Iter {
@@ -689,7 +689,6 @@ type Query struct {
 	defaultTimestampValue int64
 	disableSkipMetadata   bool
 	context               context.Context
-	cancelQuery           func()
 	idempotent            bool
 	customPayload         map[string][]byte
 	metrics               *queryMetrics
@@ -712,9 +711,6 @@ func (q *Query) defaultsFromSession() {
 	q.idempotent = s.cfg.DefaultIdempotence
 	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
 
-	// Initiate an empty context with a cancel call
-	q.WithContext(context.Background())
-
 	q.spec = &NonSpeculativeExecution{}
 	s.mu.RUnlock()
 }
@@ -808,7 +804,10 @@ func (q *Query) CustomPayload(customPayload map[string][]byte) *Query {
 	return q
 }
 
-func (q *Query) GetContext() context.Context {
+func (q *Query) Context() context.Context {
+	if q.context == nil {
+		return context.Background()
+	}
 	return q.context
 }
 
@@ -865,20 +864,30 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 	return q
 }
 
-// WithContext will set the context to use during a query, it will be used to
-// timeout when waiting for responses from Cassandra. Additionally it adds
-// the cancel function so that it can be called whenever necessary.
+func (q *Query) withContext(ctx context.Context) ExecutableQuery {
+	// I really wish go had covariant types
+	return q.WithContext(ctx)
+}
+
+// WithContext returns a shallow copy of q with its context
+// set to ctx.
+//
+// The provided context controls the entire lifetime of executing a
+// query, queries will be canceled and return once the context is
+// canceled.
 func (q *Query) WithContext(ctx context.Context) *Query {
-	q.context, q.cancelQuery = context.WithCancel(ctx)
-	return q
+	q2 := *q
+	q2.context = ctx
+	return &q2
 }
 
+// Deprecate: does nothing, cancel the context passed to WithContext
 func (q *Query) Cancel() {
-	q.cancelQuery()
+	// TODO: delete
 }
 
-func (q *Query) execute(conn *Conn) *Iter {
-	return conn.executeQuery(q)
+func (q *Query) execute(ctx context.Context, conn *Conn) *Iter {
+	return conn.executeQuery(ctx, q)
 }
 
 func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
@@ -886,7 +895,7 @@ func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host
 	q.AddLatency(end.Sub(start).Nanoseconds(), host)
 
 	if q.observer != nil {
-		q.observer.ObserveQuery(q.context, ObservedQuery{
+		q.observer.ObserveQuery(q.Context(), ObservedQuery{
 			Keyspace:  keyspace,
 			Statement: q.stmt,
 			Start:     start,
@@ -930,7 +939,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
 	}
 
 	// try to determine the routing key
-	routingKeyInfo, err := q.session.routingKeyInfo(q.context, q.stmt)
+	routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt)
 	if err != nil {
 		return nil, err
 	}
@@ -1351,7 +1360,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 		return false
 	}
 
-	if iter.next != nil && iter.pos == iter.next.pos {
+	if iter.next != nil && iter.pos >= iter.next.pos {
 		go iter.next.fetch()
 	}
 
@@ -1447,21 +1456,15 @@ func (iter *Iter) NumRows() int {
 }
 
 type nextIter struct {
-	qry  Query
+	qry  *Query
 	pos  int
 	once sync.Once
 	next *Iter
-	conn *Conn
 }
 
 func (n *nextIter) fetch() *Iter {
 	n.once.Do(func() {
-		iter := n.qry.session.executor.attemptQuery(&n.qry, n.conn)
-		if iter != nil && iter.err == nil {
-			n.next = iter
-		} else {
-			n.next = n.qry.session.executeQuery(&n.qry)
-		}
+		n.next = n.qry.session.executeQuery(n.qry)
 	})
 	return n.next
 }
@@ -1509,9 +1512,6 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 		spec:             &NonSpeculativeExecution{},
 	}
 
-	// Initiate an empty context with a cancel call
-	batch.WithContext(context.Background())
-
 	s.mu.RUnlock()
 	return batch
 }
@@ -1597,7 +1597,10 @@ func (b *Batch) SetConsistency(c Consistency) {
 	b.Cons = c
 }
 
-func (b *Batch) GetContext() context.Context {
+func (b *Batch) Context() context.Context {
+	if b.context == nil {
+		return context.Background()
+	}
 	return b.context
 }
 
@@ -1641,16 +1644,25 @@ func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	return b
 }
 
-// WithContext will set the context to use during a query, it will be used to
-// timeout when waiting for responses from Cassandra. Additionally it adds
-// the cancel function so that it can be called whenever necessary.
+func (b *Batch) withContext(ctx context.Context) ExecutableQuery {
+	return b.WithContext(ctx)
+}
+
+// WithContext returns a shallow copy of b with its context
+// set to ctx.
+//
+// The provided context controls the entire lifetime of executing a
+// query, queries will be canceled and return once the context is
+// canceled.
 func (b *Batch) WithContext(ctx context.Context) *Batch {
-	b.context, b.cancelBatch = context.WithCancel(ctx)
-	return b
+	b2 := *b
+	b2.context = ctx
+	return &b2
 }
 
-func (b *Batch) Cancel() {
-	b.cancelBatch()
+// Deprecate: does nothing, cancel the context passed to WithContext
+func (*Batch) Cancel() {
+	// TODO: delete
 }
 
 // Size returns the number of batch statements to be executed by the batch operation.
@@ -1706,7 +1718,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host
 		statements[i] = entry.Stmt
 	}
 
-	b.observer.ObserveBatch(b.context, ObservedBatch{
+	b.observer.ObserveBatch(b.Context(), ObservedBatch{
 		Keyspace:   keyspace,
 		Statements: statements,
 		Start:      start,