Browse Source

clientv3: Use passthrough resolver for direct endpoint dialing

Joe Betz 7 years ago
parent
commit
67bcf28c4e
4 changed files with 69 additions and 56 deletions
  1. 1 0
      .words
  2. 25 42
      clientv3/balancer/resolver/endpoint/endpoint.go
  3. 43 13
      clientv3/client.go
  4. 0 1
      clientv3/integration/dial_test.go

+ 1 - 0
.words

@@ -30,6 +30,7 @@ gRPC
 goroutine
 goroutines
 healthcheck
+hostname
 iff
 inflight
 keepalive

+ 25 - 42
clientv3/balancer/resolver/endpoint/endpoint.go

@@ -99,15 +99,6 @@ func Target(id, endpoint string) string {
 	return fmt.Sprintf("%s://%s/%s", scheme, id, endpoint)
 }
 
-// DirectTarget constructs a direct resolver target to a single endpoint.
-// TODO: It should be possible to use the 'passthrough' resolver instead
-// of a custom resolver for this use case, but TLS connections fail for
-// a reason we haven't been able to determine.
-func DirectTarget(endpoint string) string {
-	_, host, scheme := ParseEndpoint(endpoint)
-	return Target(fmt.Sprintf("direct:%s", scheme), host)
-}
-
 // IsTarget checks if a given target string in an endpoint resolver target.
 func IsTarget(target string) bool {
 	return strings.HasPrefix(target, "endpoint://")
@@ -123,11 +114,6 @@ func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts res
 		return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to")
 	}
 	id := target.Authority
-
-	if isDirectEndpoint(target) {
-		return buildDirectEndpointResolver(target, cc, opts)
-	}
-
 	es, err := b.getResolverGroup(id)
 	if err != nil {
 		return nil, fmt.Errorf("failed to build resolver: %v", err)
@@ -140,25 +126,6 @@ func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts res
 	return r, nil
 }
 
-func isDirectEndpoint(target resolver.Target) bool {
-	return strings.HasPrefix(target.Authority, "direct:")
-}
-
-func buildDirectEndpointResolver(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
-	parts := strings.SplitN(target.Authority, ":", 2)
-	if len(parts) != 2 || parts[0] != "direct" {
-		return nil, fmt.Errorf("'endpoint' resolver authority must be of form 'direct:<scheme>', but got %s", target.Authority)
-	}
-	scheme := parts[1]
-	ep := scheme + "://" + target.Endpoint
-	r := &DirectResolver{
-		endpoint: ep,
-		cc:       cc,
-	}
-	r.cc.NewAddress(epsToAddrs(ep))
-	return r, nil
-}
-
 func (b *builder) newResolverGroup(id string) (*ResolverGroup, error) {
 	b.mu.RLock()
 	_, ok := b.resolverGroups[id]
@@ -220,15 +187,6 @@ func (r *Resolver) Close() {
 	es.removeResolver(r)
 }
 
-// DirectResolver provides a resolver for a single etcd endpoint.
-type DirectResolver struct {
-	endpoint string
-	cc       resolver.ClientConn
-}
-
-func (*DirectResolver) ResolveNow(o resolver.ResolveNowOption) {}
-func (*DirectResolver) Close()                                 {}
-
 // ParseEndpoint endpoint parses an endpoint of the form
 // (http|https)://<host>*|(unix|unixs)://<path>)
 // and returns a protocol ('tcp' or 'unix'),
@@ -255,3 +213,28 @@ func ParseEndpoint(endpoint string) (proto string, host string, scheme string) {
 	}
 	return proto, host, scheme
 }
+
+// ParseTarget parses a endpoint://<id>/<endpoint> string and returns the parsed id and endpoint.
+// If the target is malformed, an error is returned.
+func ParseTarget(target string) (string, string, error) {
+	noPrefix := strings.TrimPrefix(target, targetPrefix)
+	if noPrefix == target {
+		return "", "", fmt.Errorf("malformed target, %s prefix is required: %s", targetPrefix, target)
+	}
+	parts := strings.SplitN(noPrefix, "/", 2)
+	if len(parts) != 2 {
+		return "", "", fmt.Errorf("malformed target, expected %s://<id>/<endpoint>, but got %s", scheme, target)
+	}
+	return parts[0], parts[1], nil
+}
+
+// ParseHostPort splits a "<host>:<port>" string into the host and port parts.
+// The port part is optional.
+func ParseHostPort(hostPort string) (host string, port string) {
+	parts := strings.SplitN(hostPort, ":", 2)
+	host = parts[0]
+	if len(parts) > 1 {
+		port = parts[1]
+	}
+	return host, port
+}

+ 43 - 13
clientv3/client.go

@@ -230,7 +230,7 @@ func (c *Client) processCreds(scheme string) (creds *credentials.TransportCreden
 }
 
 // dialSetupOpts gives the dial opts prior to any authentication.
-func (c *Client) dialSetupOpts(scheme string, dopts ...grpc.DialOption) (opts []grpc.DialOption, err error) {
+func (c *Client) dialSetupOpts(creds *credentials.TransportCredentials, dopts ...grpc.DialOption) (opts []grpc.DialOption, err error) {
 	if c.cfg.DialKeepAliveTime > 0 {
 		params := keepalive.ClientParameters{
 			Time:    c.cfg.DialKeepAliveTime,
@@ -253,10 +253,6 @@ func (c *Client) dialSetupOpts(scheme string, dopts ...grpc.DialOption) (opts []
 	}
 	opts = append(opts, grpc.WithDialer(f))
 
-	creds := c.creds
-	if len(scheme) != 0 {
-		creds = c.processCreds(scheme)
-	}
 	if creds != nil {
 		opts = append(opts, grpc.WithTransportCredentials(*creds))
 	} else {
@@ -280,8 +276,12 @@ func (c *Client) dialSetupOpts(scheme string, dopts ...grpc.DialOption) (opts []
 
 // Dial connects to a single endpoint using the client's config.
 func (c *Client) Dial(ep string) (*grpc.ClientConn, error) {
-	_, _, scheme := endpoint.ParseEndpoint(ep)
-	return c.dial(endpoint.DirectTarget(ep), scheme)
+	creds := c.directDialCreds(ep)
+	// Use the grpc passthrough resolver to directly dial a single endpoint.
+	// This resolver passes through the 'unix' and 'unixs' endpoints schemes used
+	// by etcd without modification, allowing us to directly dial endpoints and
+	// using the same dial functions that we use for load balancer dialing.
+	return c.dial(fmt.Sprintf("passthrough:///%s", ep), creds)
 }
 
 func (c *Client) getToken(ctx context.Context) error {
@@ -292,9 +292,10 @@ func (c *Client) getToken(ctx context.Context) error {
 		ep := c.cfg.Endpoints[i]
 		// use dial options without dopts to avoid reusing the client balancer
 		var dOpts []grpc.DialOption
-		_, host, scheme := endpoint.ParseEndpoint(ep)
+		_, host, _ := endpoint.ParseEndpoint(ep)
 		target := c.resolverGroup.Target(host)
-		dOpts, err = c.dialSetupOpts(scheme, c.cfg.DialOptions...)
+		creds := c.dialWithBalancerCreds(ep)
+		dOpts, err = c.dialSetupOpts(creds, c.cfg.DialOptions...)
 		if err != nil {
 			err = fmt.Errorf("failed to configure auth dialer: %v", err)
 			continue
@@ -325,14 +326,15 @@ func (c *Client) getToken(ctx context.Context) error {
 // dialWithBalancer dials the client's current load balanced resolver group.  The scheme of the host
 // of the provided endpoint determines the scheme used for all endpoints of the client connection.
 func (c *Client) dialWithBalancer(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
-	_, host, scheme := endpoint.ParseEndpoint(ep)
+	_, host, _ := endpoint.ParseEndpoint(ep)
 	target := c.resolverGroup.Target(host)
-	return c.dial(target, scheme, dopts...)
+	creds := c.dialWithBalancerCreds(ep)
+	return c.dial(target, creds, dopts...)
 }
 
 // dial configures and dials any grpc balancer target.
-func (c *Client) dial(target string, scheme string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
-	opts, err := c.dialSetupOpts(scheme, dopts...)
+func (c *Client) dial(target string, creds *credentials.TransportCredentials, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
+	opts, err := c.dialSetupOpts(creds, dopts...)
 	if err != nil {
 		return nil, fmt.Errorf("failed to configure dialer: %v", err)
 	}
@@ -378,6 +380,34 @@ func (c *Client) dial(target string, scheme string, dopts ...grpc.DialOption) (*
 	return conn, nil
 }
 
+func (c *Client) directDialCreds(ep string) *credentials.TransportCredentials {
+	_, hostPort, scheme := endpoint.ParseEndpoint(ep)
+	creds := c.creds
+	if len(scheme) != 0 {
+		creds = c.processCreds(scheme)
+		if creds != nil {
+			c := *creds
+			clone := c.Clone()
+			// Set the server name must to the endpoint hostname without port since grpc
+			// otherwise attempts to check if x509 cert is valid for the full endpoint
+			// including the scheme and port, which fails.
+			host, _ := endpoint.ParseHostPort(hostPort)
+			clone.OverrideServerName(host)
+			creds = &clone
+		}
+	}
+	return creds
+}
+
+func (c *Client) dialWithBalancerCreds(ep string) *credentials.TransportCredentials {
+	_, _, scheme := endpoint.ParseEndpoint(ep)
+	creds := c.creds
+	if len(scheme) != 0 {
+		creds = c.processCreds(scheme)
+	}
+	return creds
+}
+
 // WithRequireLeader requires client requests to only succeed
 // when the cluster has a leader.
 func WithRequireLeader(ctx context.Context) context.Context {

+ 0 - 1
clientv3/integration/dial_test.go

@@ -87,7 +87,6 @@ func TestDialTLSNoConfig(t *testing.T) {
 	if !isClientTimeout(err) {
 		t.Fatalf("expected dial timeout error, got %v", err)
 	}
-
 }
 
 // TestDialSetEndpointsBeforeFail ensures SetEndpoints can replace unavailable