
clientv3: respect up/down notifications from grpc

Fixes #5842
Anthony Romano, 9 years ago
parent commit 46765ad79c
5 changed files with 194 additions and 23 deletions
  1. clientv3/balancer.go             +93 -16
  2. clientv3/client.go               +23 -6
  3. clientv3/client_test.go           +3 -0
  4. clientv3/integration/kv_test.go  +73 -0
  5. integration/v3_grpc_test.go       +2 -1

clientv3/balancer.go   +93 -16

@@ -17,7 +17,7 @@ package clientv3
 import (
 	"net/url"
 	"strings"
-	"sync/atomic"
+	"sync"
 
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
@@ -26,32 +26,109 @@ import (
 // simpleBalancer does the bare minimum to expose multiple eps
 // to the grpc reconnection code path
 type simpleBalancer struct {
-	// eps are the client's endpoints stripped of any URL scheme
-	eps     []string
-	ch      chan []grpc.Address
-	numGets uint32
+	// addrs are the client's endpoints for grpc
+	addrs []grpc.Address
+	// notifyCh notifies grpc of the set of addresses for connecting
+	notifyCh chan []grpc.Address
+
+	// readyc closes once the first connection is up
+	readyc    chan struct{}
+	readyOnce sync.Once
+
+	// mu protects upEps, upc, and pinAddr
+	mu sync.RWMutex
+	// upEps holds the current endpoints that have an active connection
+	upEps map[string]struct{}
+	// upc closes when upEps transitions from empty to non-empty, or when the balancer closes.
+	upc chan struct{}
+
+	// pinAddr is the currently pinned address; set to the empty string on
+	// initialization and shutdown.
+	pinAddr string
 }
 
-func newSimpleBalancer(eps []string) grpc.Balancer {
-	ch := make(chan []grpc.Address, 1)
+func newSimpleBalancer(eps []string) *simpleBalancer {
+	notifyCh := make(chan []grpc.Address, 1)
 	addrs := make([]grpc.Address, len(eps))
 	for i := range eps {
 		addrs[i].Addr = getHost(eps[i])
 	}
-	ch <- addrs
-	return &simpleBalancer{eps: eps, ch: ch}
+	notifyCh <- addrs
+	sb := &simpleBalancer{
+		addrs:    addrs,
+		notifyCh: notifyCh,
+		readyc:   make(chan struct{}),
+		upEps:    make(map[string]struct{}),
+		upc:      make(chan struct{}),
+	}
+	return sb
+}
+
+func (b *simpleBalancer) Start(target string) error { return nil }
+
+func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
+	b.mu.Lock()
+	if len(b.upEps) == 0 {
+		// notify waiting Get()s and pin first connected address
+		close(b.upc)
+		b.pinAddr = addr.Addr
+	}
+	b.upEps[addr.Addr] = struct{}{}
+	b.mu.Unlock()
+	// notify client that a connection is up
+	b.readyOnce.Do(func() { close(b.readyc) })
+	return func(err error) {
+		b.mu.Lock()
+		delete(b.upEps, addr.Addr)
+		if len(b.upEps) == 0 && b.pinAddr != "" {
+			b.upc = make(chan struct{})
+		} else if b.pinAddr == addr.Addr {
+			// choose new random up endpoint
+			for k := range b.upEps {
+				b.pinAddr = k
+				break
+			}
+		}
+		b.mu.Unlock()
+	}
 }
 
-func (b *simpleBalancer) Start(target string) error        { return nil }
-func (b *simpleBalancer) Up(addr grpc.Address) func(error) { return func(error) {} }
 func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
-	v := atomic.AddUint32(&b.numGets, 1)
-	ep := b.eps[v%uint32(len(b.eps))]
-	return grpc.Address{Addr: getHost(ep)}, func() {}, nil
+	var addr string
+	for {
+		b.mu.RLock()
+		ch := b.upc
+		b.mu.RUnlock()
+		select {
+		case <-ch:
+		case <-ctx.Done():
+			return grpc.Address{Addr: ""}, nil, ctx.Err()
+		}
+		b.mu.RLock()
+		addr = b.pinAddr
+		upEps := len(b.upEps)
+		b.mu.RUnlock()
+		if addr == "" {
+			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
+		}
+		if upEps > 0 {
+			break
+		}
+	}
+	return grpc.Address{Addr: addr}, func() {}, nil
 }
-func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.ch }
+
+func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
+
 func (b *simpleBalancer) Close() error {
-	close(b.ch)
+	b.mu.Lock()
+	close(b.notifyCh)
+	// terminate all waiting Get()s
+	b.pinAddr = ""
+	if len(b.upEps) == 0 {
+		close(b.upc)
+	}
+	b.mu.Unlock()
 	return nil
 }
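
The Up/down flow above is driven entirely by grpc. Below is a minimal sketch of that lifecycle, not part of the commit, assuming the legacy grpc-go Balancer API used here (grpc.Address, grpc.BalancerGetOptions, Up returning a down closure); the exerciseBalancer helper and its addresses are hypothetical and would sit alongside the code above in package clientv3.

    // Hypothetical helper in package clientv3 (not from the commit), walking
    // through the lifecycle grpc drives via the Balancer interface of this era.
    package clientv3

    import (
        "errors"
        "fmt"

        "golang.org/x/net/context"
        "google.golang.org/grpc"
    )

    func exerciseBalancer() {
        b := newSimpleBalancer([]string{"http://127.0.0.1:2379", "http://127.0.0.1:22379"})

        // grpc calls Up when an address connects; the first Up pins the address,
        // closes upc so blocked Get()s proceed, and closes readyc for newClient.
        down := b.Up(grpc.Address{Addr: "127.0.0.1:2379"})

        // Get returns the pinned address instead of rotating through endpoints.
        addr, _, err := b.Get(context.TODO(), grpc.BalancerGetOptions{})
        fmt.Println(addr.Addr, err)

        // grpc calls the returned closure when the connection drops; with no
        // endpoints left up, upc is re-armed so the next Get blocks until Up.
        down(errors.New("connection lost"))
    }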
 

clientv3/client.go   +23 -6

@@ -141,10 +141,7 @@ func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *
 // dialSetupOpts gives the dial opts prior to any authentication
 func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts []grpc.DialOption) {
 	if c.cfg.DialTimeout > 0 {
-		opts = []grpc.DialOption{
-			grpc.WithTimeout(c.cfg.DialTimeout),
-			grpc.WithBlock(),
-		}
+		opts = []grpc.DialOption{grpc.WithTimeout(c.cfg.DialTimeout)}
 	}
 	opts = append(opts, dopts...)
 
@@ -249,6 +246,23 @@ func newClient(cfg *Config) (*Client, error) {
 	}
 	client.conn = conn
 
+	// wait for a connection
+	if cfg.DialTimeout > 0 {
+		hasConn := false
+		waitc := time.After(cfg.DialTimeout)
+		select {
+		case <-client.balancer.readyc:
+			hasConn = true
+		case <-ctx.Done():
+		case <-waitc:
+		}
+		if !hasConn {
+			client.cancel()
+			conn.Close()
+			return nil, grpc.ErrClientConnTimeout
+		}
+	}
+
 	client.Cluster = NewCluster(client)
 	client.KV = NewKV(client)
 	client.Lease = NewLease(client)
@@ -291,9 +305,12 @@ func toErr(ctx context.Context, err error) error {
 		return nil
 	}
 	err = rpctypes.Error(err)
-	if ctx.Err() != nil && strings.Contains(err.Error(), "context") {
+	switch {
+	case ctx.Err() != nil && strings.Contains(err.Error(), "context"):
 		err = ctx.Err()
-	} else if strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()) {
+	case strings.Contains(err.Error(), ErrNoAvailableEndpoints.Error()):
+		err = ErrNoAvailableEndpoints
+	case strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()):
 		err = grpc.ErrClientConnClosing
 	}
 	return err
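
For callers, the visible effect is that clientv3.New no longer relies on grpc.WithBlock: it waits on the balancer's readyc for up to DialTimeout and fails the dial itself if nothing connects. A minimal usage sketch, not from the commit; the endpoints are placeholders.

    // Hypothetical standalone program illustrating the new dial behavior.
    package main

    import (
        "fmt"
        "time"

        "github.com/coreos/etcd/clientv3"
    )

    func main() {
        cli, err := clientv3.New(clientv3.Config{
            // a dead member in this list no longer blocks the dial forever
            Endpoints:   []string{"127.0.0.1:2379", "127.0.0.1:22379"},
            DialTimeout: 2 * time.Second,
        })
        if err != nil {
            // with DialTimeout set, New returns here if no endpoint came up in time
            fmt.Println("dial failed:", err)
            return
        }
        defer cli.Close()
        fmt.Println("connected")
    }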

clientv3/client_test.go   +3 -0

@@ -20,11 +20,14 @@ import (
 	"time"
 
 	"github.com/coreos/etcd/etcdserver"
+	"github.com/coreos/etcd/pkg/testutil"
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 )
 
 func TestDialTimeout(t *testing.T) {
+	defer testutil.AfterTest(t)
+
 	donec := make(chan error)
 	go func() {
 		// without timeout, grpc keeps redialing if connection refused

clientv3/integration/kv_test.go   +73 -0

@@ -16,6 +16,7 @@ package integration
 
 import (
 	"bytes"
+	"math/rand"
 	"reflect"
 	"strings"
 	"testing"
@@ -662,3 +663,75 @@ func TestKVPutStoppedServerAndClose(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+// TestKVPutOneEndpointDown ensures a client can connect and get if one endpoint is down
+func TestKVPutOneEndpointDown(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	// get endpoint list
+	eps := make([]string, 3)
+	for i := range eps {
+		eps[i] = clus.Members[i].GRPCAddr()
+	}
+
+	// make a dead node
+	clus.Members[rand.Intn(len(eps))].Stop(t)
+
+	// try to connect with dead node in the endpoint list
+	cfg := clientv3.Config{Endpoints: eps, DialTimeout: 1 * time.Second}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+	ctx, cancel := context.WithTimeout(context.TODO(), 3*time.Second)
+	if _, err := cli.Get(ctx, "abc", clientv3.WithSerializable()); err != nil {
+		t.Fatal(err)
+	}
+	cancel()
+}
+
+// TestKVGetResetLoneEndpoint ensures that if an endpoint resets and all other
+// endpoints are down, then it will reconnect.
+func TestKVGetResetLoneEndpoint(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 2})
+	defer clus.Terminate(t)
+
+	// get endpoint list
+	eps := make([]string, 2)
+	for i := range eps {
+		eps[i] = clus.Members[i].GRPCAddr()
+	}
+
+	cfg := clientv3.Config{Endpoints: eps, DialTimeout: 500 * time.Millisecond}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+
+	// disconnect everything
+	clus.Members[0].Stop(t)
+	clus.Members[1].Stop(t)
+
+	// have Get try to reconnect
+	donec := make(chan struct{})
+	go func() {
+		ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
+		if _, err := cli.Get(ctx, "abc", clientv3.WithSerializable()); err != nil {
+			t.Error(err) // t.Fatal must not be called from a non-test goroutine
+		}
+		cancel()
+		close(donec)
+	}()
+	time.Sleep(500 * time.Millisecond)
+	clus.Members[0].Restart(t)
+	select {
+	case <-time.After(10 * time.Second):
+		t.Fatalf("timed out waiting for Get")
+	case <-donec:
+	}
+}

integration/v3_grpc_test.go   +2 -1

@@ -21,6 +21,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
@@ -960,7 +961,7 @@ func TestTLSGRPCRejectSecureClient(t *testing.T) {
 	client, err := NewClientV3(clus.Members[0])
 	if client != nil || err == nil {
 		t.Fatalf("expected no client")
-	} else if err != grpc.ErrClientConnTimeout {
+	} else if err != clientv3.ErrNoAvailableEndpoints {
 		t.Fatalf("unexpected error (%v)", err)
 	}
 }