Browse Source

clientv3: do not reconnect on request context cancellation

Anthony Romano 9 years ago
parent
commit
16c35167df

+ 6 - 2
clientv3/client.go

@@ -22,6 +22,7 @@ import (
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
+	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc"
 	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc"
 	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc/credentials"
 	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc/credentials"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/transport"
@@ -200,6 +201,9 @@ func dialEndpointList(c *Client) (*grpc.ClientConn, error) {
 	return nil, err
 	return nil, err
 }
 }
 
 
-func isRPCError(err error) bool {
-	return strings.HasPrefix(grpc.ErrorDesc(err), "etcdserver: ")
+// isHalted returns true if the given error and context indicate no forward
+// progress can be made, even after reconnecting.
+func isHalted(ctx context.Context, err error) bool {
+	isRPCError := strings.HasPrefix(grpc.ErrorDesc(err), "etcdserver: ")
+	return isRPCError || ctx.Err() != nil
 }
 }

+ 16 - 0
clientv3/client_test.go

@@ -15,9 +15,11 @@
 package clientv3
 package clientv3
 
 
 import (
 import (
+	"fmt"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc"
 	"github.com/coreos/etcd/Godeps/_workspace/src/google.golang.org/grpc"
 )
 )
 
 
@@ -52,3 +54,17 @@ func TestDialTimeout(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestIsHalted(t *testing.T) {
+	if !isHalted(nil, fmt.Errorf("etcdserver: some etcdserver error")) {
+		t.Errorf(`error prefixed with "etcdserver: " should be Halted`)
+	}
+	ctx, cancel := context.WithCancel(context.TODO())
+	if isHalted(ctx, nil) {
+		t.Errorf("no error and active context should not be Halted")
+	}
+	cancel()
+	if !isHalted(ctx, nil) {
+		t.Errorf("cancel on context should be Halted")
+	}
+}

+ 4 - 4
clientv3/cluster.go

@@ -73,7 +73,7 @@ func (c *cluster) MemberAdd(ctx context.Context, peerAddrs []string) (*MemberAdd
 		return (*MemberAddResponse)(resp), nil
 		return (*MemberAddResponse)(resp), nil
 	}
 	}
 
 
-	if isRPCError(err) {
+	if isHalted(ctx, err) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
@@ -88,7 +88,7 @@ func (c *cluster) MemberRemove(ctx context.Context, id uint64) (*MemberRemoveRes
 		return (*MemberRemoveResponse)(resp), nil
 		return (*MemberRemoveResponse)(resp), nil
 	}
 	}
 
 
-	if isRPCError(err) {
+	if isHalted(ctx, err) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
@@ -105,7 +105,7 @@ func (c *cluster) MemberUpdate(ctx context.Context, id uint64, peerAddrs []strin
 			return (*MemberUpdateResponse)(resp), nil
 			return (*MemberUpdateResponse)(resp), nil
 		}
 		}
 
 
-		if isRPCError(err) {
+		if isHalted(ctx, err) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
@@ -124,7 +124,7 @@ func (c *cluster) MemberList(ctx context.Context) (*MemberListResponse, error) {
 			return (*MemberListResponse)(resp), nil
 			return (*MemberListResponse)(resp), nil
 		}
 		}
 
 
-		if isRPCError(err) {
+		if isHalted(ctx, err) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 

+ 23 - 0
clientv3/integration/kv_test.go

@@ -454,3 +454,26 @@ func TestKVPutFailGetRetry(t *testing.T) {
 	case <-donec:
 	case <-donec:
 	}
 	}
 }
 }
+
+// TestKVGetCancel tests that a context cancel on a Get terminates as expected.
+func TestKVGetCancel(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+
+	oldconn := clus.Client(0).ActiveConnection()
+	kv := clientv3.NewKV(clus.Client(0))
+
+	ctx, cancel := context.WithCancel(context.TODO())
+	cancel()
+
+	resp, err := kv.Get(ctx, "abc")
+	if err == nil {
+		t.Fatalf("cancel on get response %v, expected context error", resp)
+	}
+	newconn := clus.Client(0).ActiveConnection()
+	if oldconn != newconn {
+		t.Fatalf("cancel on get broke client connection")
+	}
+}

+ 2 - 2
clientv3/kv.go

@@ -116,7 +116,7 @@ func (kv *kv) Compact(ctx context.Context, rev int64) error {
 		return nil
 		return nil
 	}
 	}
 
 
-	if isRPCError(err) {
+	if isHalted(ctx, err) {
 		return err
 		return err
 	}
 	}
 
 
@@ -166,7 +166,7 @@ func (kv *kv) Do(ctx context.Context, op Op) (OpResponse, error) {
 			panic("Unknown op")
 			panic("Unknown op")
 		}
 		}
 
 
-		if isRPCError(err) {
+		if isHalted(ctx, err) {
 			return OpResponse{}, err
 			return OpResponse{}, err
 		}
 		}
 
 

+ 3 - 5
clientv3/lease.go

@@ -112,8 +112,7 @@ func (l *lessor) Create(ctx context.Context, ttl int64) (*LeaseCreateResponse, e
 		if err == nil {
 		if err == nil {
 			return (*LeaseCreateResponse)(resp), nil
 			return (*LeaseCreateResponse)(resp), nil
 		}
 		}
-
-		if isRPCError(err) {
+		if isHalted(cctx, err) {
 			return nil, err
 			return nil, err
 		}
 		}
 		if nerr := l.switchRemoteAndStream(err); nerr != nil {
 		if nerr := l.switchRemoteAndStream(err); nerr != nil {
@@ -134,8 +133,7 @@ func (l *lessor) Revoke(ctx context.Context, id lease.LeaseID) (*LeaseRevokeResp
 		if err == nil {
 		if err == nil {
 			return (*LeaseRevokeResponse)(resp), nil
 			return (*LeaseRevokeResponse)(resp), nil
 		}
 		}
-
-		if isRPCError(err) {
+		if isHalted(ctx, err) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
@@ -261,7 +259,7 @@ func (l *lessor) recvKeepAliveLoop() {
 	for serr == nil {
 	for serr == nil {
 		resp, err := stream.Recv()
 		resp, err := stream.Recv()
 		if err != nil {
 		if err != nil {
-			if isRPCError(err) {
+			if isHalted(l.stopCtx, err) {
 				return
 				return
 			}
 			}
 			stream, serr = l.resetRecv()
 			stream, serr = l.resetRecv()

+ 1 - 1
clientv3/txn.go

@@ -144,7 +144,7 @@ func (txn *txn) Commit() (*TxnResponse, error) {
 			return (*TxnResponse)(resp), nil
 			return (*TxnResponse)(resp), nil
 		}
 		}
 
 
-		if isRPCError(err) {
+		if isHalted(txn.ctx, err) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 

+ 1 - 1
clientv3/watch.go

@@ -452,7 +452,7 @@ func (w *watcher) openWatchClient() (ws pb.Watch_WatchClient, err error) {
 	for {
 	for {
 		if ws, err = w.remote.Watch(w.ctx); ws != nil {
 		if ws, err = w.remote.Watch(w.ctx); ws != nil {
 			break
 			break
-		} else if isRPCError(err) {
+		} else if isHalted(w.ctx, err) {
 			return nil, err
 			return nil, err
 		}
 		}
 		newConn, nerr := w.c.retryConnection(w.conn, nil)
 		newConn, nerr := w.c.retryConnection(w.conn, nil)