
clientv3: Fix TLS test failures by returning DeadlineExceeded error from dial without any additional wrapping

Joe Betz, 7 years ago · commit 9304d1abd1

+ 2 - 3
clientv3/auth.go

@@ -21,7 +21,6 @@ import (
 
 	"github.com/coreos/etcd/auth/authpb"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
-
 	"google.golang.org/grpc"
 )
 
@@ -216,8 +215,8 @@ func (auth *authenticator) close() {
 	auth.conn.Close()
 }
 
-func newAuthenticator(ctx context.Context, endpoint string, opts []grpc.DialOption, c *Client) (*authenticator, error) {
-	conn, err := grpc.DialContext(ctx, endpoint, opts...)
+func newAuthenticator(ctx context.Context, target string, opts []grpc.DialOption, c *Client) (*authenticator, error) {
+	conn, err := grpc.DialContext(ctx, target, opts...)
 	if err != nil {
 		return nil, err
 	}

+ 20 - 20
clientv3/balancer/resolver/endpoint/endpoint.go

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// resolves to etcd entpoints for grpc targets of the form 'endpoint://<cluster-name>/<endpoint>'.
+// Package endpoint resolves etcd entpoints using grpc targets of the form 'endpoint://<clientId>/<endpoint>'.
 package endpoint
 
 import (
@@ -36,13 +36,13 @@ var (
 
 func init() {
 	bldr = &builder{
-		clusterResolvers: make(map[string]*Resolver),
+		clientResolvers: make(map[string]*Resolver),
 	}
 	resolver.Register(bldr)
 }
 
 type builder struct {
-	clusterResolvers map[string]*Resolver
+	clientResolvers map[string]*Resolver
 	sync.RWMutex
 }
 
@@ -59,16 +59,16 @@ func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts res
 	return r, nil
 }
 
-func (b *builder) getResolver(clusterName string) *Resolver {
+func (b *builder) getResolver(clientId string) *Resolver {
 	b.RLock()
-	r, ok := b.clusterResolvers[clusterName]
+	r, ok := b.clientResolvers[clientId]
 	b.RUnlock()
 	if !ok {
 		r = &Resolver{
-			clusterName: clusterName,
+			clientId: clientId,
 		}
 		b.Lock()
-		b.clusterResolvers[clusterName] = r
+		b.clientResolvers[clientId] = r
 		b.Unlock()
 	}
 	return r
@@ -76,13 +76,13 @@ func (b *builder) getResolver(clusterName string) *Resolver {
 
 func (b *builder) addResolver(r *Resolver) {
 	bldr.Lock()
-	bldr.clusterResolvers[r.clusterName] = r
+	bldr.clientResolvers[r.clientId] = r
 	bldr.Unlock()
 }
 
 func (b *builder) removeResolver(r *Resolver) {
 	bldr.Lock()
-	delete(bldr.clusterResolvers, r.clusterName)
+	delete(bldr.clientResolvers, r.clientId)
 	bldr.Unlock()
 }
 
@@ -91,15 +91,15 @@ func (r *builder) Scheme() string {
 }
 
 // EndpointResolver gets the resolver for  given etcd cluster name.
-func EndpointResolver(clusterName string) *Resolver {
-	return bldr.getResolver(clusterName)
+func EndpointResolver(clientId string) *Resolver {
+	return bldr.getResolver(clientId)
 }
 
 // Resolver provides a resolver for a single etcd cluster, identified by name.
 type Resolver struct {
-	clusterName string
-	cc          resolver.ClientConn
-	addrs       []resolver.Address
+	clientId string
+	cc       resolver.ClientConn
+	addrs    []resolver.Address
 	sync.RWMutex
 }
 
@@ -146,14 +146,14 @@ func (r *Resolver) Close() {
 	bldr.removeResolver(r)
 }
 
-// Target constructs a endpoint target with current resolver's clusterName.
+// Target constructs a endpoint target with current resolver's clientId.
 func (r *Resolver) Target(endpoint string) string {
-	return Target(r.clusterName, endpoint)
+	return Target(r.clientId, endpoint)
 }
 
 // Target constructs a endpoint resolver target.
-func Target(clusterName, endpoint string) string {
-	return fmt.Sprintf("%s://%s/%s", scheme, clusterName, endpoint)
+func Target(clientId, endpoint string) string {
+	return fmt.Sprintf("%s://%s/%s", scheme, clientId, endpoint)
 }
 
 // IsTarget checks if a given target string in an endpoint resolver target.
@@ -185,7 +185,7 @@ func ParseEndpoint(endpoint string) (proto string, host string, scheme string) {
 	return proto, host, scheme
 }
 
-// ParseTarget parses a endpoint://<clusterName>/<endpoint> string and returns the parsed clusterName and endpoint.
+// ParseTarget parses a endpoint://<clientId>/<endpoint> string and returns the parsed clientId and endpoint.
 // If the target is malformed, an error is returned.
 func ParseTarget(target string) (string, string, error) {
 	noPrefix := strings.TrimPrefix(target, targetPrefix)
@@ -194,7 +194,7 @@ func ParseTarget(target string) (string, string, error) {
 	}
 	parts := strings.SplitN(noPrefix, "/", 2)
 	if len(parts) != 2 {
-		return "", "", fmt.Errorf("malformed target, expected %s://<clusterName>/<endpoint>, but got %s", scheme, target)
+		return "", "", fmt.Errorf("malformed target, expected %s://<clientId>/<endpoint>, but got %s", scheme, target)
 	}
	return parts[0], parts[1], nil
 }

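Note: the resolver changes above swap the per-cluster key for a per-client id in the target format. A minimal usage sketch of the two exported helpers touched by this file; the import path is inferred from the file path, and the client id and endpoint values are made up:

package main

import (
	"fmt"

	"github.com/coreos/etcd/clientv3/balancer/resolver/endpoint"
)

func main() {
	// Build a resolver target for a hypothetical client id "c1" and a raw etcd endpoint.
	target := endpoint.Target("c1", "127.0.0.1:2379") // expected form: endpoint://c1/127.0.0.1:2379

	// ParseTarget should round-trip the same two components, or error on malformed input.
	clientId, ep, err := endpoint.ParseTarget(target)
	fmt.Println(clientId, ep, err)
}
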
+ 7 - 5
clientv3/client.go

@@ -297,15 +297,17 @@ func (c *Client) getToken(ctx context.Context) error {
 	var auth *authenticator
 
 	for i := 0; i < len(c.cfg.Endpoints); i++ {
-		endpoint := c.cfg.Endpoints[i]
+		ep := c.cfg.Endpoints[i]
 		// use dial options without dopts to avoid reusing the client balancer
 		var dOpts []grpc.DialOption
-		dOpts, err = c.dialSetupOpts(c.resolver.Target(endpoint), c.cfg.DialOptions...)
+		_, host, _ := endpoint.ParseEndpoint(ep)
+		target := c.resolver.Target(host)
+		dOpts, err = c.dialSetupOpts(target, c.cfg.DialOptions...)
 		if err != nil {
 			err = fmt.Errorf("failed to configure auth dialer: %v", err)
 			continue
 		}
-		auth, err = newAuthenticator(ctx, endpoint, dOpts, c)
+		auth, err = newAuthenticator(ctx, target, dOpts, c)
 		if err != nil {
 			continue
 		}
@@ -369,7 +371,7 @@ func (c *Client) dial(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, er
 	if c.cfg.DialTimeout > 0 {
 		var cancel context.CancelFunc
 		dctx, cancel = context.WithTimeout(c.ctx, c.cfg.DialTimeout)
-		defer cancel()
+		defer cancel() // TODO: Is this right for cases where grpc.WithBlock() is not set on the dial options?
 	}
 
 	conn, err := grpc.DialContext(dctx, target, opts...)
@@ -456,7 +458,7 @@ func newClient(cfg *Config) (*Client, error) {
 	if err != nil {
 		client.cancel()
 		client.resolver.Close()
-		return nil, fmt.Errorf("failed to dial initial client connection: %v", err)
+		return nil, err
 	}
 	// TODO: With the old grpc balancer interface, we waited until the dial timeout
 	// for the balancer to be ready. Is there an equivalent wait we should do with the new grpc balancer interface?

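Note: the last hunk above is the heart of the commit title. Wrapping the dial error with fmt.Errorf destroys the context.DeadlineExceeded sentinel that callers compare against. A standalone illustration, not etcd code:

package main

import (
	"context"
	"fmt"
)

func main() {
	err := context.DeadlineExceeded

	wrapped := fmt.Errorf("failed to dial initial client connection: %v", err)
	fmt.Println(wrapped == context.DeadlineExceeded) // false: the sentinel identity is lost

	// Returning the error as-is keeps the comparison working, which is what the
	// TLS dial tests (and the isClientTimeout helper added later in this commit) rely on.
	fmt.Println(err == context.DeadlineExceeded) // true
}
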
+ 9 - 6
clientv3/integration/black_hole_test.go

@@ -25,6 +25,7 @@ import (
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	"github.com/coreos/etcd/integration"
 	"github.com/coreos/etcd/pkg/testutil"
+	"google.golang.org/grpc"
 )
 
 // TestBalancerUnderBlackholeKeepAliveWatch tests when watch discovers it cannot talk to
@@ -44,6 +45,7 @@ func TestBalancerUnderBlackholeKeepAliveWatch(t *testing.T) {
 	ccfg := clientv3.Config{
 		Endpoints:            []string{eps[0]},
 		DialTimeout:          1 * time.Second,
+		DialOptions:          []grpc.DialOption{grpc.WithBlock()},
 		DialKeepAliveTime:    1 * time.Second,
 		DialKeepAliveTimeout: 500 * time.Millisecond,
 	}
@@ -106,7 +108,7 @@ func TestBalancerUnderBlackholeKeepAliveWatch(t *testing.T) {
 func TestBalancerUnderBlackholeNoKeepAlivePut(t *testing.T) {
 	testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
 		_, err := cli.Put(ctx, "foo", "bar")
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
 			return errExpected
 		}
 		return err
@@ -116,7 +118,7 @@ func TestBalancerUnderBlackholeNoKeepAlivePut(t *testing.T) {
 func TestBalancerUnderBlackholeNoKeepAliveDelete(t *testing.T) {
 	testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
 		_, err := cli.Delete(ctx, "foo")
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
 			return errExpected
 		}
 		return err
@@ -129,7 +131,7 @@ func TestBalancerUnderBlackholeNoKeepAliveTxn(t *testing.T) {
 			If(clientv3.Compare(clientv3.Version("foo"), "=", 0)).
 			Then(clientv3.OpPut("foo", "bar")).
 			Else(clientv3.OpPut("foo", "baz")).Commit()
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
 			return errExpected
 		}
 		return err
@@ -139,7 +141,7 @@ func TestBalancerUnderBlackholeNoKeepAliveTxn(t *testing.T) {
 func TestBalancerUnderBlackholeNoKeepAliveLinearizableGet(t *testing.T) {
 	testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
 		_, err := cli.Get(ctx, "a")
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
 			return errExpected
 		}
 		return err
@@ -149,7 +151,7 @@ func TestBalancerUnderBlackholeNoKeepAliveLinearizableGet(t *testing.T) {
 func TestBalancerUnderBlackholeNoKeepAliveSerializableGet(t *testing.T) {
 	testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
 		_, err := cli.Get(ctx, "a", clientv3.WithSerializable())
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) {
+		if isClientTimeout(err) || isServerCtxTimeout(err) {
 			return errExpected
 		}
 		return err
@@ -172,6 +174,7 @@ func testBalancerUnderBlackholeNoKeepAlive(t *testing.T, op func(*clientv3.Clien
 	ccfg := clientv3.Config{
 		Endpoints:   []string{eps[0]},
 		DialTimeout: 1 * time.Second,
+		DialOptions: []grpc.DialOption{grpc.WithBlock()},
 	}
 	cli, err := clientv3.New(ccfg)
 	if err != nil {
@@ -193,7 +196,7 @@ func testBalancerUnderBlackholeNoKeepAlive(t *testing.T, op func(*clientv3.Clien
 	// TODO: first operation can succeed
 	// when gRPC supports better retry on non-delivered request
 	for i := 0; i < 2; i++ {
-		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
 		err = op(cli, ctx)
 		cancel()
 		if err == nil {

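Note: every test configuration touched in this commit gains grpc.WithBlock(), as in the hunks above. A self-contained sketch of that configuration; the endpoint address is illustrative, and a TODO in dial_test.go below notes that the blocking dial does not yet cover every failure mode:

package main

import (
	"log"
	"time"

	"github.com/coreos/etcd/clientv3"
	"google.golang.org/grpc"
)

func main() {
	// grpc.WithBlock() is intended to make clientv3.New wait for the first
	// connection (or fail once DialTimeout expires) instead of returning
	// immediately, which is the behavior these integration tests assume.
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"127.0.0.1:2379"}, // illustrative endpoint
		DialTimeout: time.Second,
		DialOptions: []grpc.DialOption{grpc.WithBlock()},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()
}
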
+ 21 - 6
clientv3/integration/dial_test.go

@@ -26,6 +26,7 @@ import (
 	"github.com/coreos/etcd/integration"
 	"github.com/coreos/etcd/pkg/testutil"
 	"github.com/coreos/etcd/pkg/transport"
+	"google.golang.org/grpc"
 )
 
 var (
@@ -58,10 +59,11 @@ func TestDialTLSExpired(t *testing.T) {
 	_, err = clientv3.New(clientv3.Config{
 		Endpoints:   []string{clus.Members[0].GRPCAddr()},
 		DialTimeout: 3 * time.Second,
+		DialOptions: []grpc.DialOption{grpc.WithBlock()},
 		TLS:         tls,
 	})
-	if err != context.DeadlineExceeded {
-		t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err)
+	if !isClientTimeout(err) {
+		t.Fatalf("expected dial timeout error, got %v", err)
 	}
 }
 
@@ -72,12 +74,19 @@ func TestDialTLSNoConfig(t *testing.T) {
 	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1, ClientTLS: &testTLSInfo, SkipCreatingClient: true})
 	defer clus.Terminate(t)
 	// expect "signed by unknown authority"
-	_, err := clientv3.New(clientv3.Config{
+	c, err := clientv3.New(clientv3.Config{
 		Endpoints:   []string{clus.Members[0].GRPCAddr()},
 		DialTimeout: time.Second,
+		DialOptions: []grpc.DialOption{grpc.WithBlock()},
 	})
-	if err != context.DeadlineExceeded {
-		t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err)
+	defer c.Close()
+
+	// TODO: this should not be required when we set grpc.WithBlock()
+	if c != nil {
+		_, err = c.KV.Get(context.Background(), "/")
+	}
+	if !isClientTimeout(err) {
+		t.Fatalf("expected dial timeout error, got %v", err)
 	}
 }
 
@@ -104,7 +113,11 @@ func testDialSetEndpoints(t *testing.T, setBefore bool) {
 	}
 	toKill := rand.Intn(len(eps))
 
-	cfg := clientv3.Config{Endpoints: []string{eps[toKill]}, DialTimeout: 1 * time.Second}
+	cfg := clientv3.Config{
+		Endpoints:   []string{eps[toKill]},
+		DialTimeout: 1 * time.Second,
+		DialOptions: []grpc.DialOption{grpc.WithBlock()},
+	}
 	cli, err := clientv3.New(cfg)
 	if err != nil {
 		t.Fatal(err)
@@ -121,6 +134,7 @@ func testDialSetEndpoints(t *testing.T, setBefore bool) {
 	if !setBefore {
 		cli.SetEndpoints(eps[toKill%3], eps[(toKill+1)%3])
 	}
+	time.Sleep(time.Second * 2)
 	ctx, cancel := context.WithTimeout(context.Background(), integration.RequestWaitTimeout)
 	if _, err = cli.Get(ctx, "foo", clientv3.WithSerializable()); err != nil {
 		t.Fatal(err)
@@ -158,6 +172,7 @@ func TestRejectOldCluster(t *testing.T) {
 	cfg := clientv3.Config{
 		Endpoints:        []string{clus.Members[0].GRPCAddr(), clus.Members[1].GRPCAddr()},
 		DialTimeout:      5 * time.Second,
+		DialOptions:      []grpc.DialOption{grpc.WithBlock()},
 		RejectOldCluster: true,
 	}
 	cli, err := clientv3.New(cfg)

+ 8 - 6
clientv3/integration/kv_test.go

@@ -708,6 +708,7 @@ func TestKVGetRetry(t *testing.T) {
 
 	time.Sleep(100 * time.Millisecond)
 	clus.Members[fIdx].Restart(t)
+	clus.Members[fIdx].WaitOK(t)
 
 	select {
 	case <-time.After(5 * time.Second):
@@ -792,7 +793,7 @@ func TestKVGetStoppedServerAndClose(t *testing.T) {
 	// this Get fails and triggers an asynchronous connection retry
 	_, err := cli.Get(ctx, "abc")
 	cancel()
-	if err != nil && err != context.DeadlineExceeded {
+	if err != nil && !isServerUnavailable(err) {
 		t.Fatal(err)
 	}
 }
@@ -814,14 +815,15 @@ func TestKVPutStoppedServerAndClose(t *testing.T) {
 	// grpc finds out the original connection is down due to the member shutdown.
 	_, err := cli.Get(ctx, "abc")
 	cancel()
-	if err != nil && err != context.DeadlineExceeded {
+	if err != nil && !isServerUnavailable(err) {
 		t.Fatal(err)
 	}
 
+	ctx, cancel = context.WithTimeout(context.TODO(), time.Second) // TODO: How was this test not consistently failing with context canceled errors?
 	// this Put fails and triggers an asynchronous connection retry
 	_, err = cli.Put(ctx, "abc", "123")
 	cancel()
-	if err != nil && err != context.DeadlineExceeded {
+	if err != nil && !isServerUnavailable(err) {
 		t.Fatal(err)
 	}
 }
@@ -906,7 +908,7 @@ func TestKVLargeRequests(t *testing.T) {
 			maxCallSendBytesClient: 10 * 1024 * 1024,
 			maxCallRecvBytesClient: 0,
 			valueSize:              10 * 1024 * 1024,
-			expectError:            grpc.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max "),
+			expectError:            grpc.Errorf(codes.ResourceExhausted, "trying to send message larger than max "),
 		},
 		{
 			maxRequestBytesServer:  10 * 1024 * 1024,
@@ -920,7 +922,7 @@ func TestKVLargeRequests(t *testing.T) {
 			maxCallSendBytesClient: 10 * 1024 * 1024,
 			maxCallRecvBytesClient: 0,
 			valueSize:              10*1024*1024 + 5,
-			expectError:            grpc.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max "),
+			expectError:            grpc.Errorf(codes.ResourceExhausted, "trying to send message larger than max "),
 		},
 	}
 	for i, test := range tests {
@@ -940,7 +942,7 @@ func TestKVLargeRequests(t *testing.T) {
 				t.Errorf("#%d: expected %v, got %v", i, test.expectError, err)
 			}
 		} else if err != nil && !strings.HasPrefix(err.Error(), test.expectError.Error()) {
-			t.Errorf("#%d: expected %v, got %v", i, test.expectError, err)
+			t.Errorf("#%d: expected error starting with '%s', got '%s'", i, test.expectError.Error(), err.Error())
 		}
 
 		// put request went through, now expects large response back

+ 34 - 30
clientv3/integration/leasing_test.go

@@ -1920,10 +1920,6 @@ func TestLeasingSessionExpire(t *testing.T) {
 }
 
 func TestLeasingSessionExpireCancel(t *testing.T) {
-	defer testutil.AfterTest(t)
-	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
-	defer clus.Terminate(t)
-
 	tests := []func(context.Context, clientv3.KV) error{
 		func(ctx context.Context, kv clientv3.KV) error {
 			_, err := kv.Get(ctx, "abc")
@@ -1960,37 +1956,43 @@ func TestLeasingSessionExpireCancel(t *testing.T) {
 		},
 	}
 	for i := range tests {
-		lkv, closeLKV, err := leasing.NewKV(clus.Client(0), "foo/", concurrency.WithTTL(1))
-		testutil.AssertNil(t, err)
-		defer closeLKV()
+		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
+			defer testutil.AfterTest(t)
+			clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+			defer clus.Terminate(t)
 
-		if _, err = lkv.Get(context.TODO(), "abc"); err != nil {
-			t.Fatal(err)
-		}
+			lkv, closeLKV, err := leasing.NewKV(clus.Client(0), "foo/", concurrency.WithTTL(1))
+			testutil.AssertNil(t, err)
+			defer closeLKV()
 
-		// down endpoint lkv uses for keepalives
-		clus.Members[0].Stop(t)
-		if err := waitForLeasingExpire(clus.Client(1), "foo/abc"); err != nil {
-			t.Fatal(err)
-		}
-		waitForExpireAck(t, lkv)
+			if _, err = lkv.Get(context.TODO(), "abc"); err != nil {
+				t.Fatal(err)
+			}
 
-		ctx, cancel := context.WithCancel(context.TODO())
-		errc := make(chan error, 1)
-		go func() { errc <- tests[i](ctx, lkv) }()
-		// some delay to get past for ctx.Err() != nil {} loops
-		time.Sleep(100 * time.Millisecond)
-		cancel()
+			// down endpoint lkv uses for keepalives
+			clus.Members[0].Stop(t)
+			if err := waitForLeasingExpire(clus.Client(1), "foo/abc"); err != nil {
+				t.Fatal(err)
+			}
+			waitForExpireAck(t, lkv)
 
-		select {
-		case err := <-errc:
-			if err != ctx.Err() {
-				t.Errorf("#%d: expected %v, got %v", i, ctx.Err(), err)
+			ctx, cancel := context.WithCancel(context.TODO())
+			errc := make(chan error, 1)
+			go func() { errc <- tests[i](ctx, lkv) }()
+			// some delay to get past for ctx.Err() != nil {} loops
+			time.Sleep(100 * time.Millisecond)
+			cancel()
+
+			select {
+			case err := <-errc:
+				if err != ctx.Err() {
+					t.Errorf("#%d: expected %v, got %v", i, ctx.Err(), err)
+				}
+			case <-time.After(5 * time.Second):
+				t.Errorf("#%d: timed out waiting for cancel", i)
 			}
-		case <-time.After(5 * time.Second):
-			t.Errorf("#%d: timed out waiting for cancel", i)
-		}
-		clus.Members[0].Restart(t)
+			clus.Members[0].Restart(t)
+		})
 	}
 }
 
@@ -2016,6 +2018,8 @@ func waitForExpireAck(t *testing.T, kv clientv3.KV) {
 		cancel()
 		if err == ctx.Err() {
 			return
+		} else if err != nil {
+			t.Logf("current error: %v", err)
 		}
 		time.Sleep(time.Second)
 	}

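Note: the leasing change above wraps each case in t.Run so that every case gets its own cluster, setup, and teardown. A reduced sketch of that subtest pattern with placeholder cases (not the etcd test itself):

package main

import (
	"fmt"
	"testing"
)

// TestPerCaseSetup mirrors the structure the leasing test adopts: setup and
// teardown run inside each subtest, so one case's failure or leaked state
// cannot bleed into the next. The cases here are placeholders.
func TestPerCaseSetup(t *testing.T) {
	tests := []func() error{
		func() error { return nil },
		func() error { return nil },
	}
	for i := range tests {
		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
			// per-case setup (the real test creates an etcd cluster here)
			defer func() { /* per-case teardown */ }()
			if err := tests[i](); err != nil {
				t.Errorf("#%d: unexpected error: %v", i, err)
			}
		})
	}
}
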
+ 4 - 5
clientv3/integration/maintenance_test.go

@@ -21,7 +21,6 @@ import (
 	"io"
 	"io/ioutil"
 	"path/filepath"
-	"strings"
 	"testing"
 	"time"
 
@@ -131,8 +130,8 @@ func TestMaintenanceSnapshotError(t *testing.T) {
 	time.Sleep(2 * time.Second)
 
 	_, err = io.Copy(ioutil.Discard, rc2)
-	if err != nil && err != context.DeadlineExceeded {
-		t.Errorf("expected %v, got %v", context.DeadlineExceeded, err)
+	if err != nil && !isClientTimeout(err) {
+		t.Errorf("expected client timeout, got %v", err)
 	}
 }
 
@@ -189,7 +188,7 @@ func TestMaintenanceSnapshotErrorInflight(t *testing.T) {
 	// 300ms left and expect timeout while snapshot reading is in progress
 	time.Sleep(700 * time.Millisecond)
 	_, err = io.Copy(ioutil.Discard, rc2)
-	if err != nil && !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) {
-		t.Errorf("expected %v from gRPC, got %v", context.DeadlineExceeded, err)
+	if err != nil && !isClientTimeout(err) {
+		t.Errorf("expected client timeout, got %v", err)
 	}
 }

+ 9 - 5
clientv3/integration/network_partition_test.go

@@ -26,6 +26,7 @@ import (
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	"github.com/coreos/etcd/integration"
 	"github.com/coreos/etcd/pkg/testutil"
+	"google.golang.org/grpc"
 )
 
 var errExpected = errors.New("expected error")
@@ -36,7 +37,7 @@ var errExpected = errors.New("expected error")
 func TestBalancerUnderNetworkPartitionPut(t *testing.T) {
 	testBalancerUnderNetworkPartition(t, func(cli *clientv3.Client, ctx context.Context) error {
 		_, err := cli.Put(ctx, "a", "b")
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
 			return errExpected
 		}
 		return err
@@ -46,7 +47,7 @@ func TestBalancerUnderNetworkPartitionPut(t *testing.T) {
 func TestBalancerUnderNetworkPartitionDelete(t *testing.T) {
 	testBalancerUnderNetworkPartition(t, func(cli *clientv3.Client, ctx context.Context) error {
 		_, err := cli.Delete(ctx, "a")
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
 			return errExpected
 		}
 		return err
@@ -59,7 +60,7 @@ func TestBalancerUnderNetworkPartitionTxn(t *testing.T) {
 			If(clientv3.Compare(clientv3.Version("foo"), "=", 0)).
 			Then(clientv3.OpPut("foo", "bar")).
 			Else(clientv3.OpPut("foo", "baz")).Commit()
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
 			return errExpected
 		}
 		return err
@@ -82,7 +83,7 @@ func TestBalancerUnderNetworkPartitionLinearizableGetWithLongTimeout(t *testing.
 func TestBalancerUnderNetworkPartitionLinearizableGetWithShortTimeout(t *testing.T) {
 	testBalancerUnderNetworkPartition(t, func(cli *clientv3.Client, ctx context.Context) error {
 		_, err := cli.Get(ctx, "a")
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) {
+		if isClientTimeout(err) || isServerCtxTimeout(err) {
 			return errExpected
 		}
 		return err
@@ -111,6 +112,7 @@ func testBalancerUnderNetworkPartition(t *testing.T, op func(*clientv3.Client, c
 	ccfg := clientv3.Config{
 		Endpoints:   []string{eps[0]},
 		DialTimeout: 3 * time.Second,
+		DialOptions: []grpc.DialOption{grpc.WithBlock()},
 	}
 	cli, err := clientv3.New(ccfg)
 	if err != nil {
@@ -123,6 +125,7 @@ func testBalancerUnderNetworkPartition(t *testing.T, op func(*clientv3.Client, c
 
 	// add other endpoints for later endpoint switch
 	cli.SetEndpoints(eps...)
+	time.Sleep(time.Second * 2)
 	clus.Members[0].InjectPartition(t, clus.Members[1:]...)
 
 	for i := 0; i < 2; i++ {
@@ -133,7 +136,7 @@ func testBalancerUnderNetworkPartition(t *testing.T, op func(*clientv3.Client, c
 			break
 		}
 		if err != errExpected {
-			t.Errorf("#%d: expected %v, got %v", i, errExpected, err)
+			t.Errorf("#%d: expected '%v', got '%v'", i, errExpected, err)
 		}
 		// give enough time for endpoint switch
 		// TODO: remove random sleep by syncing directly with balancer
@@ -166,6 +169,7 @@ func TestBalancerUnderNetworkPartitionLinearizableGetLeaderElection(t *testing.T
 	cli, err := clientv3.New(clientv3.Config{
 		Endpoints:   []string{eps[(lead+1)%2]},
 		DialTimeout: 1 * time.Second,
+		DialOptions: []grpc.DialOption{grpc.WithBlock()},
 	})
 	if err != nil {
 		t.Fatal(err)

+ 33 - 3
clientv3/integration/server_shutdown_test.go

@@ -29,6 +29,7 @@ import (
 
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
+	"google.golang.org/grpc/transport"
 )
 
 // TestBalancerUnderServerShutdownWatch expects that watch client
@@ -105,7 +106,7 @@ func TestBalancerUnderServerShutdownWatch(t *testing.T) {
 		if err == nil {
 			break
 		}
-		if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout || err == rpctypes.ErrTimeoutDueToLeaderFail {
+		if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout || err == rpctypes.ErrTimeoutDueToLeaderFail {
 			continue
 		}
 		t.Fatal(err)
@@ -341,10 +342,10 @@ func testBalancerUnderServerStopInflightRangeOnRestart(t *testing.T, linearizabl
 		_, err := cli.Get(ctx, "abc", gops...)
 		cancel()
 		if err != nil {
-			if linearizable && strings.Contains(err.Error(), "context deadline exceeded") {
+			if linearizable && isServerUnavailable(err) {
 				t.Logf("TODO: FIX THIS after balancer rewrite! %v %v", reflect.TypeOf(err), err)
 			} else {
-				t.Fatal(err)
+				t.Fatalf("expected linearizable=true and a server unavailable error, but got linearizable=%t and '%v'", linearizable, err)
 			}
 		}
 	}()
@@ -373,3 +374,32 @@ func isServerCtxTimeout(err error) bool {
 	code := ev.Code()
 	return code == codes.DeadlineExceeded && strings.Contains(err.Error(), "context deadline exceeded")
 }
+
+// In grpc v1.11.3+ dial timeouts can error out with transport.ErrConnClosing. Previously dial timeouts
+// would always error out with context.DeadlineExceeded.
+func isClientTimeout(err error) bool {
+	if err == nil {
+		return false
+	}
+	if err == context.DeadlineExceeded {
+		return true
+	}
+	ev, ok := status.FromError(err)
+	if !ok {
+		return false
+	}
+	code := ev.Code()
+	return code == codes.DeadlineExceeded || ev.Message() == transport.ErrConnClosing.Desc
+}
+
+func isServerUnavailable(err error) bool {
+	if err == nil {
+		return false
+	}
+	ev, ok := status.FromError(err)
+	if !ok {
+		return false
+	}
+	code := ev.Code()
+	return code == codes.Unavailable
+}

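Note: the two helpers added above centralize how these tests classify dial and request errors. A standalone restatement of that decision logic; it omits the transport.ErrConnClosing message comparison for brevity, and the example errors are made up:

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// classify mirrors the checks made by isClientTimeout and isServerUnavailable;
// it is a sketch, not etcd code.
func classify(err error) string {
	if err == nil {
		return "ok"
	}
	if err == context.DeadlineExceeded {
		return "client timeout"
	}
	if ev, ok := status.FromError(err); ok {
		switch ev.Code() {
		case codes.DeadlineExceeded:
			return "client timeout"
		case codes.Unavailable:
			return "server unavailable"
		}
	}
	return "unexpected"
}

func main() {
	fmt.Println(classify(context.DeadlineExceeded))                          // client timeout
	fmt.Println(classify(status.Error(codes.Unavailable, "member stopped"))) // server unavailable
}
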
+ 7 - 1
clientv3/integration/user_test.go

@@ -17,11 +17,13 @@ package integration
 import (
 	"context"
 	"testing"
+	"time"
 
 	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	"github.com/coreos/etcd/integration"
 	"github.com/coreos/etcd/pkg/testutil"
+	"google.golang.org/grpc"
 )
 
 func TestUserError(t *testing.T) {
@@ -68,7 +70,11 @@ func TestUserErrorAuth(t *testing.T) {
 	}
 
 	// wrong id or password
-	cfg := clientv3.Config{Endpoints: authapi.Endpoints()}
+	cfg := clientv3.Config{
+		Endpoints:   authapi.Endpoints(),
+		DialTimeout: 5 * time.Second,
+		DialOptions: []grpc.DialOption{grpc.WithBlock()},
+	}
 	cfg.Username, cfg.Password = "wrong-id", "123"
 	if _, err := clientv3.New(cfg); err != rpctypes.ErrAuthFailed {
 		t.Fatalf("expected %v, got %v", rpctypes.ErrAuthFailed, err)

+ 2 - 2
clientv3/integration/watch_test.go

@@ -667,8 +667,8 @@ func TestWatchErrConnClosed(t *testing.T) {
 	go func() {
 		defer close(donec)
 		ch := cli.Watch(context.TODO(), "foo")
-		if wr := <-ch; grpc.ErrorDesc(wr.Err()) != grpc.ErrClientConnClosing.Error() {
-			t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, grpc.ErrorDesc(wr.Err()))
+		if wr := <-ch; wr.Err() != grpc.ErrClientConnClosing {
+			t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, wr.Err())
 		}
 	}()
 

+ 1 - 0
integration/cluster.go

@@ -734,6 +734,7 @@ func NewClientV3(m *member) (*clientv3.Client, error) {
 	cfg := clientv3.Config{
 		Endpoints:          []string{m.grpcAddr},
 		DialTimeout:        5 * time.Second,
+		DialOptions:        []grpc.DialOption{grpc.WithBlock()},
 		MaxCallSendMsgSize: m.clientMaxCallSendMsgSize,
 		MaxCallRecvMsgSize: m.clientMaxCallRecvMsgSize,
 	}