Browse Source

query: add context support (#710)

Add WithContext on Batch and Query to supply a context which will cause
queries to abort when the context is cancelled/timeouted or done.
Chris Bannister 9 năm trước cách đây
mục cha
commit
437c5ceaa1
5 tập tin đã thay đổi với 73 bổ sung24 xóa
  1. 6 6
      cassandra_test.go
  2. 19 10
      conn.go
  3. 22 1
      conn_test.go
  4. 6 4
      control.go
  5. 20 3
      session.go

+ 6 - 6
cassandra_test.go

@@ -4,6 +4,7 @@ package gocql
 
 import (
 	"bytes"
+	"golang.org/x/net/context"
 	"io"
 	"math"
 	"math/big"
@@ -1118,7 +1119,7 @@ func TestQueryInfo(t *testing.T) {
 	defer session.Close()
 
 	conn := getRandomConn(t, session)
-	info, err := conn.prepareStatement("SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
+	info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
 
 	if err != nil {
 		t.Fatalf("Failed to execute query for preparing statement: %v", err)
@@ -1833,7 +1834,7 @@ func TestRoutingKey(t *testing.T) {
 		t.Fatalf("failed to create table with error '%v'", err)
 	}
 
-	routingKeyInfo, err := session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+	routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
 	if err != nil {
 		t.Fatalf("failed to get routing key info due to error: %v", err)
 	}
@@ -1857,7 +1858,7 @@ func TestRoutingKey(t *testing.T) {
 	}
 
 	// verify the cache is working
-	routingKeyInfo, err = session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+	routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
 	if err != nil {
 		t.Fatalf("failed to get routing key info due to error: %v", err)
 	}
@@ -1891,7 +1892,7 @@ func TestRoutingKey(t *testing.T) {
 		t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey)
 	}
 
-	routingKeyInfo, err = session.routingKeyInfo("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
+	routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
 	if err != nil {
 		t.Fatalf("failed to get routing key info due to error: %v", err)
 	}
@@ -1999,7 +2000,7 @@ func TestNegativeStream(t *testing.T) {
 		return f.finishWrite()
 	})
 
-	frame, err := conn.exec(writer, nil)
+	frame, err := conn.exec(context.Background(), writer, nil)
 	if err == nil {
 		t.Fatalf("expected to get an error on stream %d", stream)
 	} else if frame != nil {
@@ -2411,5 +2412,4 @@ func TestSchemaReset(t *testing.T) {
 	if val != expVal {
 		t.Errorf("expected to get val=%q got=%q", expVal, val)
 	}
-
 }

+ 19 - 10
conn.go

@@ -9,6 +9,7 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"golang.org/x/net/context"
 	"io"
 	"io/ioutil"
 	"log"
@@ -284,7 +285,7 @@ func (c *Conn) startup(frameTicker chan struct{}) error {
 	}
 
 	frameTicker <- struct{}{}
-	framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
+	framer, err := c.exec(context.Background(), &writeStartupFrame{opts: m}, nil)
 	if err != nil {
 		return err
 	}
@@ -320,7 +321,7 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame, frameTicker c
 
 	for {
 		frameTicker <- struct{}{}
-		framer, err := c.exec(req, nil)
+		framer, err := c.exec(context.Background(), req, nil)
 		if err != nil {
 			return err
 		}
@@ -531,7 +532,7 @@ type callReq struct {
 	timer *time.Timer
 }
 
-func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
+func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
 	// TODO: move tracer onto conn
 	stream, ok := c.streams.GetStream()
 	if !ok {
@@ -593,6 +594,11 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 		timeoutCh = call.timer.C
 	}
 
+	var ctxDone <-chan struct{}
+	if ctx != nil {
+		ctxDone = ctx.Done()
+	}
+
 	select {
 	case err := <-call.resp:
 		close(call.timeout)
@@ -610,6 +616,9 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 		close(call.timeout)
 		c.handleTimeout()
 		return nil, ErrTimeoutNoResponse
+	case <-ctxDone:
+		close(call.timeout)
+		return nil, ctx.Err()
 	case <-c.quit:
 		return nil, ErrConnectionClosed
 	}
@@ -642,7 +651,7 @@ type inflightPrepare struct {
 	preparedStatment *preparedStatment
 }
 
-func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment, error) {
+func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
 	stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
 	flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
 		flight := new(inflightPrepare)
@@ -660,7 +669,7 @@ func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*preparedStatment,
 		statement: stmt,
 	}
 
-	framer, err := c.exec(prep, tracer)
+	framer, err := c.exec(ctx, prep, tracer)
 	if err != nil {
 		flight.err = err
 		flight.wg.Done()
@@ -732,7 +741,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 	if qry.shouldPrepare() {
 		// Prepare all DML queries. Other queries can not be prepared.
 		var err error
-		info, err = c.prepareStatement(qry.stmt, qry.trace)
+		info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace)
 		if err != nil {
 			return &Iter{err: err}
 		}
@@ -783,7 +792,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 	}
 
-	framer, err := c.exec(frame, qry.trace)
+	framer, err := c.exec(qry.context, frame, qry.trace)
 	if err != nil {
 		return &Iter{err: err}
 	}
@@ -883,7 +892,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
 	q.params.consistency = Any
 
-	framer, err := c.exec(q, nil)
+	framer, err := c.exec(context.Background(), q, nil)
 	if err != nil {
 		return err
 	}
@@ -926,7 +935,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 		entry := &batch.Entries[i]
 		b := &req.statements[i]
 		if len(entry.Args) > 0 || entry.binding != nil {
-			info, err := c.prepareStatement(entry.Stmt, nil)
+			info, err := c.prepareStatement(batch.context, entry.Stmt, nil)
 			if err != nil {
 				return &Iter{err: err}
 			}
@@ -970,7 +979,7 @@ func (c *Conn) executeBatch(batch *Batch) *Iter {
 	}
 
 	// TODO: should batch support tracing?
-	framer, err := c.exec(req, nil)
+	framer, err := c.exec(batch.context, req, nil)
 	if err != nil {
 		return &Iter{err: err}
 	}

+ 22 - 1
conn_test.go

@@ -9,6 +9,7 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"fmt"
+	"golang.org/x/net/context"
 	"io"
 	"io/ioutil"
 	"net"
@@ -482,7 +483,7 @@ func TestStream0(t *testing.T) {
 	})
 
 	// need to write out an invalid frame, which we need a connection to do
-	framer, err := conn.exec(writer, nil)
+	framer, err := conn.exec(context.Background(), writer, nil)
 	if err == nil {
 		t.Fatal("expected to get an error on stream 0")
 	} else if !strings.HasPrefix(err.Error(), expErr) {
@@ -523,6 +524,26 @@ func TestConnClosedBlocked(t *testing.T) {
 	}
 }
 
+func TestContext_Timeout(t *testing.T) {
+	srv := NewTestServer(t, defaultProto)
+	defer srv.Stop()
+
+	cluster := testCluster(srv.Address, defaultProto)
+	cluster.Timeout = 5 * time.Second
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	err = db.Query("timeout").WithContext(ctx).Exec()
+	if err != context.Canceled {
+		t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {

+ 6 - 4
control.go

@@ -4,6 +4,7 @@ import (
 	crand "crypto/rand"
 	"errors"
 	"fmt"
+	"golang.org/x/net/context"
 	"log"
 	"math/rand"
 	"net"
@@ -193,9 +194,10 @@ func (c *controlConn) registerEvents(conn *Conn) error {
 		return nil
 	}
 
-	framer, err := conn.exec(&writeRegisterFrame{
-		events: events,
-	}, nil)
+	framer, err := conn.exec(context.Background(),
+		&writeRegisterFrame{
+			events: events,
+		}, nil)
 	if err != nil {
 		return err
 	}
@@ -282,7 +284,7 @@ func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
 		return nil, errNoControl
 	}
 
-	framer, err := conn.exec(w, nil)
+	framer, err := conn.exec(context.Background(), w, nil)
 	if err != nil {
 		return nil, err
 	}

+ 20 - 3
session.go

@@ -9,6 +9,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
+	"golang.org/x/net/context"
 	"io"
 	"net"
 	"strconv"
@@ -359,7 +360,7 @@ func (s *Session) getConn() *Conn {
 }
 
 // returns routing key indexes and type info
-func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
+func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Lock()
 
 	entry, cached := s.routingKeyInfoCache.lru.Get(stmt)
@@ -402,7 +403,7 @@ func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	}
 
 	// get the query info for the statement
-	info, inflight.err = conn.prepareStatement(stmt, nil)
+	info, inflight.err = conn.prepareStatement(ctx, stmt, nil)
 	if inflight.err != nil {
 		// don't cache this error
 		s.routingKeyInfoCache.Remove(stmt)
@@ -587,6 +588,7 @@ type Query struct {
 	defaultTimestamp      bool
 	defaultTimestampValue int64
 	disableSkipMetadata   bool
+	context               context.Context
 
 	disableAutoPage bool
 }
@@ -669,6 +671,13 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 	return q
 }
 
+// WithContext will set the context to use during a query, it will be used to
+// timeout when waiting for responses from Cassandra.
+func (q *Query) WithContext(ctx context.Context) *Query {
+	q.context = ctx
+	return q
+}
+
 func (q *Query) execute(conn *Conn) *Iter {
 	return conn.executeQuery(q)
 }
@@ -700,7 +709,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
 	}
 
 	// try to determine the routing key
-	routingKeyInfo, err := q.session.routingKeyInfo(q.stmt)
+	routingKeyInfo, err := q.session.routingKeyInfo(q.context, q.stmt)
 	if err != nil {
 		return nil, err
 	}
@@ -1078,6 +1087,7 @@ type Batch struct {
 	totalLatency     int64
 	serialCons       SerialConsistency
 	defaultTimestamp bool
+	context          context.Context
 }
 
 // NewBatch creates a new batch operation without defaults from the cluster
@@ -1135,6 +1145,13 @@ func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 	return b
 }
 
+// WithContext will set the context to use during a query, it will be used to
+// timeout when waiting for responses from Cassandra.
+func (b *Batch) WithContext(ctx context.Context) *Batch {
+	b.context = ctx
+	return b
+}
+
 // Size returns the number of batch statements to be executed by the batch operation.
 func (b *Batch) Size() int {
 	return len(b.Entries)