浏览代码

Allow to Cancel() query at any time (#1174)

* Expose a Cancel() call in RetryableQuery

Solving #1173

* Add a Cancel() call to the RetryableQuery interface
* Add cancelQuery and cancelBatch fields
* Initiate Query/Batch context
* Implement Cancel() call by calling the cancelQuery/cancelBatch
functions
* Add TestCancel to verify that the query is being cancelled by the
cancel call before it's finished.

Signed-off-by: Alex Lourie <alex@instaclustr.com>

* Fixing the interfaces and bad test

* The Cancel() call moved from RetryableQuery to ExecutableQuery
* Simplify context and cancel functions initialisation
* Improve TestCancel with the following:
  * the test is now actually testing that the query was cancelled.
  * doesn't need "veryslow" server query, just using "timeout" is enough
  * using a waitgroup to cleanup after the test
  * cut the runtime to about 20ms, so elapsed times measure is not
  required

Signed-off-by: Alex Lourie <alex@instaclustr.com>
Alex Lourie 7 年之前
父节点
当前提交
08a3a27c42
共有 3 个文件被更改,包括 58 次插入4 次删除
  1. 33 0
      conn_test.go
  2. 1 0
      query_executor.go
  3. 24 4
      session.go

+ 33 - 0
conn_test.go

@@ -282,6 +282,39 @@ func TestTimeout(t *testing.T) {
 	wg.Wait()
 }
 
+func TestCancel(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	srv := NewTestServer(t, defaultProto, ctx)
+	defer srv.Stop()
+
+	cluster := testCluster(defaultProto, srv.Address)
+	cluster.Timeout = 1 * time.Second
+	db, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatalf("NewCluster: %v", err)
+	}
+	defer db.Close()
+
+	qry := db.Query("timeout")
+
+	// Make sure we finish the query without leftovers
+	var wg sync.WaitGroup
+	wg.Add(1)
+
+	go func() {
+		if err := qry.Exec(); err != context.Canceled {
+			t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err)
+		}
+		wg.Done()
+	}()
+
+	// The query will timeout after about 1 seconds, so cancel it after a short pause
+	time.AfterFunc(20 * time.Millisecond, qry.Cancel)
+	wg.Wait()
+}
+
 type testQueryObserver struct {
 	metrics map[string]*queryMetrics
 	verbose bool

+ 1 - 0
query_executor.go

@@ -10,6 +10,7 @@ type ExecutableQuery interface {
 	retryPolicy() RetryPolicy
 	GetRoutingKey() ([]byte, error)
 	Keyspace() string
+	Cancel()
 	RetryableQuery
 }
 

+ 24 - 4
session.go

@@ -683,6 +683,7 @@ type Query struct {
 	defaultTimestampValue int64
 	disableSkipMetadata   bool
 	context               context.Context
+	cancelQuery           func()
 	idempotent            bool
 	metrics               map[string]*queryMetrics
 
@@ -703,6 +704,10 @@ func (q *Query) defaultsFromSession() {
 	q.defaultTimestamp = s.cfg.DefaultTimestamp
 	q.idempotent = s.cfg.DefaultIdempotence
 	q.metrics = make(map[string]*queryMetrics)
+
+	// Initiate an empty context with a cancel call
+	q.WithContext(context.Background())
+
 	s.mu.RUnlock()
 }
 
@@ -823,12 +828,17 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 }
 
 // WithContext will set the context to use during a query, it will be used to
-// timeout when waiting for responses from Cassandra.
+// timeout when waiting for responses from Cassandra. Additionally it adds
+// the cancel function so that it can be called whenever necessary.
 func (q *Query) WithContext(ctx context.Context) *Query {
-	q.context = ctx
+	q.context, q.cancelQuery = context.WithCancel(ctx)
 	return q
 }
 
+func (q *Query) Cancel() {
+	q.cancelQuery()
+}
+
 func (q *Query) execute(conn *Conn) *Iter {
 	return conn.executeQuery(q)
 }
@@ -1418,6 +1428,7 @@ type Batch struct {
 	defaultTimestamp      bool
 	defaultTimestampValue int64
 	context               context.Context
+	cancelBatch           func()
 	keyspace              string
 	metrics               map[string]*queryMetrics
 }
@@ -1442,6 +1453,10 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
 		keyspace:         s.cfg.Keyspace,
 		metrics:          make(map[string]*queryMetrics),
 	}
+
+	// Initiate an empty context with a cancel call
+	batch.WithContext(context.Background())
+
 	s.mu.RUnlock()
 	return batch
 }
@@ -1526,12 +1541,17 @@ func (b *Batch) RetryPolicy(r RetryPolicy) *Batch {
 }
 
 // WithContext will set the context to use during a query, it will be used to
-// timeout when waiting for responses from Cassandra.
+// timeout when waiting for responses from Cassandra. Additionally it adds
+// the cancel function so that it can be called whenever necessary.
 func (b *Batch) WithContext(ctx context.Context) *Batch {
-	b.context = ctx
+	b.context, b.cancelBatch = context.WithCancel(ctx)
 	return b
 }
 
+func (b *Batch) Cancel() {
+	b.cancelBatch()
+}
+
 // Size returns the number of batch statements to be executed by the batch operation.
 func (b *Batch) Size() int {
 	return len(b.Entries)