Browse Source

clientv3: Fix maintenance APIs to directly dial grpc endpoints correctly.

Joe Betz 7 years ago
parent
commit
b3b06a862a

+ 42 - 14
clientv3/balancer/resolver/endpoint/endpoint.go

@@ -99,6 +99,15 @@ func Target(id, endpoint string) string {
 	return fmt.Sprintf("%s://%s/%s", scheme, id, endpoint)
 }
 
+// DirectTarget constructs a direct resolver target to a single endpoint.
+// TODO: It should be possible to use the 'passthrough' resolver instead
+// of a custom resolver for this use case, but TLS connections fail for
+// a reason we haven't been able to determine.
+func DirectTarget(endpoint string) string {
+	_, host, scheme := ParseEndpoint(endpoint)
+	return Target(fmt.Sprintf("direct:%s", scheme), host)
+}
+
 // IsTarget checks if a given target string is an endpoint resolver target.
 func IsTarget(target string) bool {
 	return strings.HasPrefix(target, "endpoint://")
@@ -114,6 +123,11 @@ func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts res
 		return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to")
 	}
 	id := target.Authority
+
+	if isDirectEndpoint(target) {
+		return buildDirectEndpointResolver(target, cc, opts)
+	}
+
 	es, err := b.getResolverGroup(id)
 	if err != nil {
 		return nil, fmt.Errorf("failed to build resolver: %v", err)
@@ -126,6 +140,25 @@ func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts res
 	return r, nil
 }
 
+func isDirectEndpoint(target resolver.Target) bool {
+	return strings.HasPrefix(target.Authority, "direct:")
+}
+
+func buildDirectEndpointResolver(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
+	parts := strings.SplitN(target.Authority, ":", 2)
+	if len(parts) != 2 || parts[0] != "direct" {
+		return nil, fmt.Errorf("'endpoint' resolver authority must be of form 'direct:<scheme>', but got %s", target.Authority)
+	}
+	scheme := parts[1]
+	ep := scheme + "://" + target.Endpoint
+	r := &DirectResolver{
+		endpoint: ep,
+		cc:       cc,
+	}
+	r.cc.NewAddress(epsToAddrs(ep))
+	return r, nil
+}
+
 func (b *builder) newResolverGroup(id string) (*ResolverGroup, error) {
 	b.mu.RLock()
 	_, ok := b.resolverGroups[id]
@@ -187,6 +220,15 @@ func (r *Resolver) Close() {
 	es.removeResolver(r)
 }
 
+// DirectResolver provides a resolver for a single etcd endpoint.
+type DirectResolver struct {
+	endpoint string
+	cc       resolver.ClientConn
+}
+
+func (*DirectResolver) ResolveNow(o resolver.ResolveNowOption) {}
+func (*DirectResolver) Close()                                 {}
+
 // ParseEndpoint parses an endpoint of the form
 // ((http|https)://<host>*|(unix|unixs)://<path>)
 // and returns a protocol ('tcp' or 'unix'),
@@ -213,17 +255,3 @@ func ParseEndpoint(endpoint string) (proto string, host string, scheme string) {
 	}
 	return proto, host, scheme
 }
-
-// ParseTarget parses a endpoint://<id>/<endpoint> string and returns the parsed id and endpoint.
-// If the target is malformed, an error is returned.
-func ParseTarget(target string) (string, string, error) {
-	noPrefix := strings.TrimPrefix(target, targetPrefix)
-	if noPrefix == target {
-		return "", "", fmt.Errorf("malformed target, %s prefix is required: %s", targetPrefix, target)
-	}
-	parts := strings.SplitN(noPrefix, "/", 2)
-	if len(parts) != 2 {
-		return "", "", fmt.Errorf("malformed target, expected %s://<id>/<endpoint>, but got %s", scheme, target)
-	}
-	return parts[0], parts[1], nil
-}

+ 19 - 26
clientv3/client.go

@@ -229,13 +229,8 @@ func (c *Client) processCreds(scheme string) (creds *credentials.TransportCreden
 	return creds
 }
 
-// dialSetupOpts gives the dial opts prior to any authentication
-func (c *Client) dialSetupOpts(target string, dopts ...grpc.DialOption) (opts []grpc.DialOption, err error) {
-	_, ep, err := endpoint.ParseTarget(target)
-	if err != nil {
-		return nil, fmt.Errorf("unable to parse target: %v", err)
-	}
-
+// dialSetupOpts gives the dial opts prior to any authentication.
+func (c *Client) dialSetupOpts(scheme string, dopts ...grpc.DialOption) (opts []grpc.DialOption, err error) {
 	if c.cfg.DialKeepAliveTime > 0 {
 		params := keepalive.ClientParameters{
 			Time:    c.cfg.DialKeepAliveTime,
@@ -245,16 +240,9 @@ func (c *Client) dialSetupOpts(target string, dopts ...grpc.DialOption) (opts []
 	}
 	opts = append(opts, dopts...)
 
+	// Provide a net dialer that supports cancelation and timeout.
 	f := func(dialEp string, t time.Duration) (net.Conn, error) {
 		proto, host, _ := endpoint.ParseEndpoint(dialEp)
-		if host == "" && ep != "" {
-			// dialing an endpoint not in the balancer; use
-			// endpoint passed into dial
-			proto, host, _ = endpoint.ParseEndpoint(ep)
-		}
-		if proto == "" {
-			return nil, fmt.Errorf("unknown scheme for %q", host)
-		}
 		select {
 		case <-c.ctx.Done():
 			return nil, c.ctx.Err()
@@ -266,7 +254,7 @@ func (c *Client) dialSetupOpts(target string, dopts ...grpc.DialOption) (opts []
 	opts = append(opts, grpc.WithDialer(f))
 
 	creds := c.creds
-	if _, _, scheme := endpoint.ParseEndpoint(ep); len(scheme) != 0 {
+	if len(scheme) != 0 {
 		creds = c.processCreds(scheme)
 	}
 	if creds != nil {
@@ -291,8 +279,9 @@ func (c *Client) dialSetupOpts(target string, dopts ...grpc.DialOption) (opts []
 }
 
 // Dial connects to a single endpoint using the client's config.
-func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
-	return c.dial(endpoint)
+func (c *Client) Dial(ep string) (*grpc.ClientConn, error) {
+	_, _, scheme := endpoint.ParseEndpoint(ep)
+	return c.dial(endpoint.DirectTarget(ep), scheme)
 }
 
 func (c *Client) getToken(ctx context.Context) error {
@@ -303,9 +292,9 @@ func (c *Client) getToken(ctx context.Context) error {
 		ep := c.cfg.Endpoints[i]
 		// use dial options without dopts to avoid reusing the client balancer
 		var dOpts []grpc.DialOption
-		_, host, _ := endpoint.ParseEndpoint(ep)
+		_, host, scheme := endpoint.ParseEndpoint(ep)
 		target := c.resolverGroup.Target(host)
-		dOpts, err = c.dialSetupOpts(target, c.cfg.DialOptions...)
+		dOpts, err = c.dialSetupOpts(scheme, c.cfg.DialOptions...)
 		if err != nil {
 			err = fmt.Errorf("failed to configure auth dialer: %v", err)
 			continue
@@ -333,13 +322,17 @@ func (c *Client) getToken(ctx context.Context) error {
 	return err
 }
 
-func (c *Client) dial(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
-	// We pass a target to DialContext of the form: endpoint://<clusterName>/<host-part> that
-	// does not include scheme (http/https/unix/unixs) or path parts.
-	_, host, _ := endpoint.ParseEndpoint(ep)
+// dialWithBalancer dials the client's current load-balanced resolver group. The scheme of the
+// provided endpoint determines the scheme used for all endpoints of the client connection.
+func (c *Client) dialWithBalancer(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
+	_, host, scheme := endpoint.ParseEndpoint(ep)
 	target := c.resolverGroup.Target(host)
+	return c.dial(target, scheme, dopts...)
+}
 
-	opts, err := c.dialSetupOpts(target, dopts...)
+// dial configures and dials any grpc balancer target.
+func (c *Client) dial(target string, scheme string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
+	opts, err := c.dialSetupOpts(scheme, dopts...)
 	if err != nil {
 		return nil, fmt.Errorf("failed to configure dialer: %v", err)
 	}
@@ -467,7 +460,7 @@ func newClient(cfg *Config) (*Client, error) {
 
 	// Use a provided endpoint target so that, for https:// without any TLS config given,
 	// grpc will assume the certificate server name is the endpoint host.
-	conn, err := client.dial(dialEndpoint, grpc.WithBalancerName(roundRobinBalancerName))
+	conn, err := client.dialWithBalancer(dialEndpoint, grpc.WithBalancerName(roundRobinBalancerName))
 	if err != nil {
 		client.cancel()
 		client.resolverGroup.Close()

+ 46 - 0
clientv3/integration/maintenance_test.go

@@ -25,7 +25,9 @@ import (
 	"time"
 
 	"go.uber.org/zap"
+	"google.golang.org/grpc"
 
+	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	"github.com/coreos/etcd/integration"
 	"github.com/coreos/etcd/lease"
@@ -193,3 +195,47 @@ func TestMaintenanceSnapshotErrorInflight(t *testing.T) {
 		t.Errorf("expected client timeout, got %v", err)
 	}
 }
+
+func TestMaintenanceStatus(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	clus.WaitLeader(t)
+
+	eps := make([]string, 3)
+	for i := 0; i < 3; i++ {
+		eps[i] = clus.Members[i].GRPCAddr()
+	}
+
+	cli, err := clientv3.New(clientv3.Config{Endpoints: eps, DialOptions: []grpc.DialOption{grpc.WithBlock()}})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+
+	prevID, leaderFound := uint64(0), false
+	for i := 0; i < 3; i++ {
+		resp, err := cli.Status(context.TODO(), eps[i])
+		if err != nil {
+			t.Fatal(err)
+		}
+		if prevID == 0 {
+			prevID, leaderFound = resp.Header.MemberId, resp.Header.MemberId == resp.Leader
+			continue
+		}
+		if prevID == resp.Header.MemberId {
+			t.Errorf("#%d: status returned duplicate member ID with %016x", i, prevID)
+		}
+		if leaderFound && resp.Header.MemberId == resp.Leader {
+			t.Errorf("#%d: leader already found, but found another %016x", i, resp.Header.MemberId)
+		}
+		if !leaderFound {
+			leaderFound = resp.Header.MemberId == resp.Leader
+		}
+	}
+	if !leaderFound {
+		t.Fatal("no leader found")
+	}
+}

+ 1 - 1
clientv3/maintenance.go

@@ -76,7 +76,7 @@ type maintenance struct {
 func NewMaintenance(c *Client) Maintenance {
 	api := &maintenance{
 		dial: func(endpoint string) (pb.MaintenanceClient, func(), error) {
-			conn, err := c.dial(endpoint)
+			conn, err := c.Dial(endpoint)
 			if err != nil {
 				return nil, nil, fmt.Errorf("failed to dial endpoint %s with maintenance client: %v", endpoint, err)
 			}

+ 4 - 2
tests/e2e/ctl_v3_move_leader_test.go

@@ -72,10 +72,12 @@ func testCtlV3MoveLeader(t *testing.T, cfg etcdProcessClusterConfig) {
 		if err != nil {
 			t.Fatal(err)
 		}
-		resp, err := cli.Status(context.Background(), ep)
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		resp, err := cli.Status(ctx, ep)
 		if err != nil {
-			t.Fatal(err)
+			t.Fatalf("failed to get status from endpoint %s: %v", ep, err)
 		}
+		cancel()
 		cli.Close()
 
 		if resp.Header.GetMemberId() == resp.Leader {