Browse Source

clientv3: Fix endpoint resolver to create a new resolver for each grpc client connection

Joe Betz 7 years ago
parent
commit
8569b9c782

+ 25 - 17
clientv3/balancer/balancer_test.go

@@ -30,7 +30,6 @@ import (
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/peer"
-	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/status"
 )
 
@@ -58,14 +57,17 @@ func TestRoundRobinBalancedResolvableNoFailover(t *testing.T) {
 			}
 			defer ms.Stop()
 
-			var resolvedAddrs []resolver.Address
+			var eps []string
 			for _, svr := range ms.Servers {
-				resolvedAddrs = append(resolvedAddrs, svr.ResolverAddress())
+				eps = append(eps, svr.ResolverAddress().Addr)
 			}
 
-			rsv := endpoint.EndpointResolver("nofailover")
+			rsv, err := endpoint.NewResolverGroup("nofailover")
+			if err != nil {
+				t.Fatal(err)
+			}
 			defer rsv.Close()
-			rsv.InitialAddrs(resolvedAddrs)
+			rsv.SetEndpoints(eps)
 
 			name := genName()
 			cfg := Config{
@@ -121,14 +123,17 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
 		t.Fatalf("failed to start mock servers: %s", err)
 	}
 	defer ms.Stop()
-	var resolvedAddrs []resolver.Address
+	var eps []string
 	for _, svr := range ms.Servers {
-		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address})
+		eps = append(eps, svr.ResolverAddress().Addr)
 	}
 
-	rsv := endpoint.EndpointResolver("serverfail")
+	rsv, err := endpoint.NewResolverGroup("serverfail")
+	if err != nil {
+		t.Fatal(err)
+	}
 	defer rsv.Close()
-	rsv.InitialAddrs(resolvedAddrs)
+	rsv.SetEndpoints(eps)
 
 	name := genName()
 	cfg := Config{
@@ -158,7 +163,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
 	ms.StopAt(0)
 	available := make(map[string]struct{})
 	for i := 1; i < serverCount; i++ {
-		available[resolvedAddrs[i].Addr] = struct{}{}
+		available[eps[i]] = struct{}{}
 	}
 
 	reqN := 10
@@ -169,8 +174,8 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
 			continue
 		}
 		if prev == "" { // first failover
-			if resolvedAddrs[0].Addr == picked {
-				t.Fatalf("expected failover from %q, picked %q", resolvedAddrs[0].Addr, picked)
+			if eps[0] == picked {
+				t.Fatalf("expected failover from %q, picked %q", eps[0], picked)
 			}
 			prev = picked
 			continue
@@ -194,7 +199,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
 	time.Sleep(time.Second)
 
 	prev, switches = "", 0
-	recoveredAddr, recovered := resolvedAddrs[0].Addr, 0
+	recoveredAddr, recovered := eps[0], 0
 	available[recoveredAddr] = struct{}{}
 
 	for i := 0; i < 2*reqN; i++ {
@@ -234,15 +239,18 @@ func TestRoundRobinBalancedResolvableFailoverFromRequestFail(t *testing.T) {
 		t.Fatalf("failed to start mock servers: %s", err)
 	}
 	defer ms.Stop()
-	var resolvedAddrs []resolver.Address
+	var eps []string
 	available := make(map[string]struct{})
 	for _, svr := range ms.Servers {
-		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address})
+		eps = append(eps, svr.ResolverAddress().Addr)
 		available[svr.Address] = struct{}{}
 	}
-	rsv := endpoint.EndpointResolver("requestfail")
+	rsv, err := endpoint.NewResolverGroup("requestfail")
+	if err != nil {
+		t.Fatal(err)
+	}
 	defer rsv.Close()
-	rsv.InitialAddrs(resolvedAddrs)
+	rsv.SetEndpoints(eps)
 
 	name := genName()
 	cfg := Config{

+ 104 - 76
clientv3/balancer/resolver/endpoint/endpoint.go

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Package endpoint resolves etcd entpoints using grpc targets of the form 'endpoint://<clientId>/<endpoint>'.
+// Package endpoint resolves etcd entpoints using grpc targets of the form 'endpoint://<id>/<endpoint>'.
 package endpoint
 
 import (
@@ -36,91 +36,140 @@ var (
 
 func init() {
 	bldr = &builder{
-		clientResolvers: make(map[string]*Resolver),
+		resolverGroups: make(map[string]*ResolverGroup),
 	}
 	resolver.Register(bldr)
 }
 
 type builder struct {
-	clientResolvers map[string]*Resolver
+	resolverGroups map[string]*ResolverGroup
 	sync.RWMutex
 }
 
+// NewResolverGroup creates a new ResolverGroup with the given id.
+func NewResolverGroup(id string) (*ResolverGroup, error) {
+	return bldr.newResolverGroup(id)
+}
+
+// ResolverGroup keeps all endpoints of resolvers using a common endpoint://<id>/ target
+// up-to-date.
+type ResolverGroup struct {
+	id        string
+	endpoints []string
+	resolvers []*Resolver
+	sync.RWMutex
+}
+
+func (e *ResolverGroup) addResolver(r *Resolver) {
+	e.Lock()
+	addrs := epsToAddrs(e.endpoints...)
+	e.resolvers = append(e.resolvers, r)
+	e.Unlock()
+	r.cc.NewAddress(addrs)
+}
+
+func (e *ResolverGroup) removeResolver(r *Resolver) {
+	e.Lock()
+	for i, er := range e.resolvers {
+		if er == r {
+			e.resolvers = append(e.resolvers[:i], e.resolvers[i+1:]...)
+			break
+		}
+	}
+	e.Unlock()
+}
+
+// SetEndpoints updates the endpoints for ResolverGroup. All registered resolver are updated
+// immediately with the new endpoints.
+func (e *ResolverGroup) SetEndpoints(endpoints []string) {
+	addrs := epsToAddrs(endpoints...)
+	e.Lock()
+	e.endpoints = endpoints
+	for _, r := range e.resolvers {
+		r.cc.NewAddress(addrs)
+	}
+	e.Unlock()
+}
+
+// Target constructs a endpoint target using the endpoint id of the ResolverGroup.
+func (e *ResolverGroup) Target(endpoint string) string {
+	return Target(e.id, endpoint)
+}
+
+// Target constructs a endpoint resolver target.
+func Target(id, endpoint string) string {
+	return fmt.Sprintf("%s://%s/%s", scheme, id, endpoint)
+}
+
+// IsTarget checks if a given target string in an endpoint resolver target.
+func IsTarget(target string) bool {
+	return strings.HasPrefix(target, "endpoint://")
+}
+
+func (e *ResolverGroup) Close() {
+	bldr.close(e.id)
+}
+
 // Build creates or reuses an etcd resolver for the etcd cluster name identified by the authority part of the target.
 func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
 	if len(target.Authority) < 1 {
 		return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to")
 	}
-	r := b.getResolver(target.Authority)
-	r.cc = cc
-	if r.addrs != nil {
-		r.NewAddress(r.addrs)
+	id := target.Authority
+	es, err := b.getResolverGroup(id)
+	if err != nil {
+		return nil, fmt.Errorf("failed to build resolver: %v", err)
+	}
+	r := &Resolver{
+		endpointId: id,
+		cc:         cc,
 	}
+	es.addResolver(r)
 	return r, nil
 }
 
-func (b *builder) getResolver(clientId string) *Resolver {
+func (b *builder) newResolverGroup(id string) (*ResolverGroup, error) {
 	b.RLock()
-	r, ok := b.clientResolvers[clientId]
+	es, ok := b.resolverGroups[id]
 	b.RUnlock()
 	if !ok {
-		r = &Resolver{
-			clientId: clientId,
-		}
+		es = &ResolverGroup{id: id}
 		b.Lock()
-		b.clientResolvers[clientId] = r
+		b.resolverGroups[id] = es
 		b.Unlock()
+	} else {
+		return nil, fmt.Errorf("Endpoint already exists for id: %s", id)
 	}
-	return r
+	return es, nil
 }
 
-func (b *builder) addResolver(r *Resolver) {
-	bldr.Lock()
-	bldr.clientResolvers[r.clientId] = r
-	bldr.Unlock()
+func (b *builder) getResolverGroup(id string) (*ResolverGroup, error) {
+	b.RLock()
+	es, ok := b.resolverGroups[id]
+	b.RUnlock()
+	if !ok {
+		return nil, fmt.Errorf("ResolverGroup not found for id: %s", id)
+	}
+	return es, nil
 }
 
-func (b *builder) removeResolver(r *Resolver) {
-	bldr.Lock()
-	delete(bldr.clientResolvers, r.clientId)
-	bldr.Unlock()
+func (b *builder) close(id string) {
+	b.Lock()
+	delete(b.resolverGroups, id)
+	b.Unlock()
 }
 
 func (r *builder) Scheme() string {
 	return scheme
 }
 
-// EndpointResolver gets the resolver for  given etcd cluster name.
-func EndpointResolver(clientId string) *Resolver {
-	return bldr.getResolver(clientId)
-}
-
 // Resolver provides a resolver for a single etcd cluster, identified by name.
 type Resolver struct {
-	clientId string
-	cc       resolver.ClientConn
-	addrs    []resolver.Address
+	endpointId string
+	cc         resolver.ClientConn
 	sync.RWMutex
 }
 
-// InitialAddrs sets the initial endpoint addresses for the resolver.
-func (r *Resolver) InitialAddrs(addrs []resolver.Address) {
-	r.Lock()
-	r.addrs = addrs
-	r.Unlock()
-}
-
-// InitialEndpoints sets the initial endpoints to for the resolver.
-// This should be called before dialing. The endpoints may be updated after the dial using NewAddress.
-// At least one endpoint is required.
-func (r *Resolver) InitialEndpoints(eps []string) error {
-	if len(eps) < 1 {
-		return fmt.Errorf("At least one endpoint is required, but got: %v", eps)
-	}
-	r.InitialAddrs(epsToAddrs(eps...))
-	return nil
-}
-
 // TODO: use balancer.epsToAddrs
 func epsToAddrs(eps ...string) (addrs []resolver.Address) {
 	addrs = make([]resolver.Address, 0, len(eps))
@@ -130,35 +179,14 @@ func epsToAddrs(eps ...string) (addrs []resolver.Address) {
 	return addrs
 }
 
-// NewAddress updates the addresses of the resolver.
-func (r *Resolver) NewAddress(addrs []resolver.Address) {
-	r.Lock()
-	r.addrs = addrs
-	r.Unlock()
-	if r.cc != nil {
-		r.cc.NewAddress(addrs)
-	}
-}
-
 func (*Resolver) ResolveNow(o resolver.ResolveNowOption) {}
 
 func (r *Resolver) Close() {
-	bldr.removeResolver(r)
-}
-
-// Target constructs a endpoint target with current resolver's clientId.
-func (r *Resolver) Target(endpoint string) string {
-	return Target(r.clientId, endpoint)
-}
-
-// Target constructs a endpoint resolver target.
-func Target(clientId, endpoint string) string {
-	return fmt.Sprintf("%s://%s/%s", scheme, clientId, endpoint)
-}
-
-// IsTarget checks if a given target string in an endpoint resolver target.
-func IsTarget(target string) bool {
-	return strings.HasPrefix(target, "endpoint://")
+	es, err := bldr.getResolverGroup(r.endpointId)
+	if err != nil {
+		return
+	}
+	es.removeResolver(r)
 }
 
 // Parse endpoint parses a endpoint of the form (http|https)://<host>*|(unix|unixs)://<path>) and returns a
@@ -185,7 +213,7 @@ func ParseEndpoint(endpoint string) (proto string, host string, scheme string) {
 	return proto, host, scheme
 }
 
-// ParseTarget parses a endpoint://<clientId>/<endpoint> string and returns the parsed clientId and endpoint.
+// ParseTarget parses a endpoint://<id>/<endpoint> string and returns the parsed id and endpoint.
 // If the target is malformed, an error is returned.
 func ParseTarget(target string) (string, string, error) {
 	noPrefix := strings.TrimPrefix(target, targetPrefix)
@@ -194,7 +222,7 @@ func ParseTarget(target string) (string, string, error) {
 	}
 	parts := strings.SplitN(noPrefix, "/", 2)
 	if len(parts) != 2 {
-		return "", "", fmt.Errorf("malformed target, expected %s://<clientId>/<endpoint>, but got %s", scheme, target)
+		return "", "", fmt.Errorf("malformed target, expected %s://<id>/<endpoint>, but got %s", scheme, target)
 	}
 	return parts[0], parts[1], nil
 }

+ 16 - 28
clientv3/client.go

@@ -37,7 +37,6 @@ import (
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/keepalive"
 	"google.golang.org/grpc/metadata"
-	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/status"
 )
 
@@ -68,11 +67,11 @@ type Client struct {
 	conn     *grpc.ClientConn
 	dialerrc chan error
 
-	cfg      Config
-	creds    *credentials.TransportCredentials
-	balancer balancer.Balancer
-	resolver *endpoint.Resolver
-	mu       *sync.Mutex
+	cfg           Config
+	creds         *credentials.TransportCredentials
+	balancer      balancer.Balancer
+	resolverGroup *endpoint.ResolverGroup
+	mu            *sync.Mutex
 
 	ctx    context.Context
 	cancel context.CancelFunc
@@ -119,12 +118,12 @@ func (c *Client) Close() error {
 	c.cancel()
 	c.Watcher.Close()
 	c.Lease.Close()
+	if c.resolverGroup != nil {
+		c.resolverGroup.Close()
+	}
 	if c.conn != nil {
 		return toErr(c.ctx, c.conn.Close())
 	}
-	if c.resolver != nil {
-		c.resolver.Close()
-	}
 	return c.ctx.Err()
 }
 
@@ -143,22 +142,10 @@ func (c *Client) Endpoints() (eps []string) {
 
 // SetEndpoints updates client's endpoints.
 func (c *Client) SetEndpoints(eps ...string) {
-	var addrs []resolver.Address
-	for _, ep := range eps {
-		addrs = append(addrs, resolver.Address{Addr: ep})
-	}
-
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	c.cfg.Endpoints = eps
-	c.resolver.NewAddress(addrs)
-	// TODO: Does the new grpc balancer provide a way to block until the endpoint changes are propagated?
-	/*if c.balancer.NeedUpdate() {
-		select {
-		case c.balancer.UpdateAddrsC() <- balancer.NotifyNext:
-		case <-c.balancer.StopC():
-		}
-	}*/
+	c.resolverGroup.SetEndpoints(eps)
 }
 
 // Sync synchronizes client's endpoints with the known endpoints from the etcd membership.
@@ -301,12 +288,13 @@ func (c *Client) getToken(ctx context.Context) error {
 		// use dial options without dopts to avoid reusing the client balancer
 		var dOpts []grpc.DialOption
 		_, host, _ := endpoint.ParseEndpoint(ep)
-		target := c.resolver.Target(host)
+		target := c.resolverGroup.Target(host)
 		dOpts, err = c.dialSetupOpts(target, c.cfg.DialOptions...)
 		if err != nil {
 			err = fmt.Errorf("failed to configure auth dialer: %v", err)
 			continue
 		}
+		dOpts = append(dOpts, grpc.WithBalancerName(roundRobinBalancerName))
 		auth, err = newAuthenticator(ctx, target, dOpts, c)
 		if err != nil {
 			continue
@@ -333,7 +321,7 @@ func (c *Client) dial(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, er
 	// We pass a target to DialContext of the form: endpoint://<clusterName>/<host-part> that
 	// does not include scheme (http/https/unix/unixs) or path parts.
 	_, host, _ := endpoint.ParseEndpoint(ep)
-	target := c.resolver.Target(host)
+	target := c.resolverGroup.Target(host)
 
 	opts, err := c.dialSetupOpts(target, dopts...)
 	if err != nil {
@@ -439,13 +427,13 @@ func newClient(cfg *Config) (*Client, error) {
 
 	// Prepare a 'endpoint://<unique-client-id>/' resolver for the client and create a endpoint target to pass
 	// to dial so the client knows to use this resolver.
-	client.resolver = endpoint.EndpointResolver(fmt.Sprintf("client-%s", strconv.FormatInt(time.Now().UnixNano(), 36)))
-	err := client.resolver.InitialEndpoints(cfg.Endpoints)
+	var err error
+	client.resolverGroup, err = endpoint.NewResolverGroup(fmt.Sprintf("client-%s", strconv.FormatInt(time.Now().UnixNano(), 36)))
 	if err != nil {
 		client.cancel()
-		client.resolver.Close()
 		return nil, err
 	}
+	client.resolverGroup.SetEndpoints(cfg.Endpoints)
 
 	if len(cfg.Endpoints) < 1 {
 		return nil, fmt.Errorf("at least one Endpoint must is required in client config")
@@ -457,7 +445,7 @@ func newClient(cfg *Config) (*Client, error) {
 	conn, err := client.dial(dialEndpoint, grpc.WithBalancerName(roundRobinBalancerName))
 	if err != nil {
 		client.cancel()
-		client.resolver.Close()
+		client.resolverGroup.Close()
 		return nil, err
 	}
 	// TODO: With the old grpc balancer interface, we waited until the dial timeout

+ 15 - 0
clientv3/integration/server_shutdown_test.go

@@ -403,3 +403,18 @@ func isServerUnavailable(err error) bool {
 	code := ev.Code()
 	return code == codes.Unavailable
 }
+
+func isCanceled(err error) bool {
+	if err == nil {
+		return false
+	}
+	if err == context.Canceled {
+		return true
+	}
+	ev, ok := status.FromError(err)
+	if !ok {
+		return false
+	}
+	code := ev.Code()
+	return code == codes.Canceled
+}

+ 5 - 5
clientv3/integration/watch_test.go

@@ -30,7 +30,6 @@ import (
 	mvccpb "github.com/coreos/etcd/mvcc/mvccpb"
 	"github.com/coreos/etcd/pkg/testutil"
 
-	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
 )
 
@@ -667,8 +666,9 @@ func TestWatchErrConnClosed(t *testing.T) {
 	go func() {
 		defer close(donec)
 		ch := cli.Watch(context.TODO(), "foo")
-		if wr := <-ch; wr.Err() != grpc.ErrClientConnClosing {
-			t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, wr.Err())
+
+		if wr := <-ch; !isCanceled(wr.Err()) {
+			t.Fatalf("expected context canceled, got %v", wr.Err())
 		}
 	}()
 
@@ -699,8 +699,8 @@ func TestWatchAfterClose(t *testing.T) {
 	donec := make(chan struct{})
 	go func() {
 		cli.Watch(context.TODO(), "foo")
-		if err := cli.Close(); err != nil && err != grpc.ErrClientConnClosing {
-			t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, err)
+		if err := cli.Close(); err != nil && err != context.Canceled {
+			t.Fatalf("expected %v, got %v", context.Canceled, err)
 		}
 		close(donec)
 	}()