Browse Source

clientv3: use grpc balancer

Anthony Romano 9 years ago
parent
commit
4a13c9f9b3
2 changed files with 121 additions and 29 deletions
  1. 64 0
      clientv3/balancer.go
  2. 57 29
      clientv3/client.go

+ 64 - 0
clientv3/balancer.go

@@ -0,0 +1,64 @@
+// Copyright 2016 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package clientv3
+
+import (
+	"net/url"
+	"strings"
+	"sync/atomic"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+)
+
+// simpleBalancer does the bare minimum to expose multiple eps
+// to the grpc reconnection code path
+type simpleBalancer struct {
+	// eps are the client's endpoints stripped of any URL scheme
+	eps     []string
+	ch      chan []grpc.Address
+	numGets uint32
+}
+
+func newSimpleBalancer(eps []string) grpc.Balancer {
+	ch := make(chan []grpc.Address, 1)
+	addrs := make([]grpc.Address, len(eps))
+	for i := range eps {
+		addrs[i].Addr = getHost(eps[i])
+	}
+	ch <- addrs
+	return &simpleBalancer{eps: eps, ch: ch}
+}
+
+func (b *simpleBalancer) Start(target string) error        { return nil }
+func (b *simpleBalancer) Up(addr grpc.Address) func(error) { return func(error) {} }
+func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
+	v := atomic.AddUint32(&b.numGets, 1)
+	ep := b.eps[v%uint32(len(b.eps))]
+	return grpc.Address{Addr: getHost(ep)}, func() {}, nil
+}
+func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.ch }
+func (b *simpleBalancer) Close() error {
+	close(b.ch)
+	return nil
+}
+
+func getHost(ep string) string {
+	url, uerr := url.Parse(ep)
+	if uerr != nil || !strings.Contains(ep, "://") {
+		return ep
+	}
+	return url.Host
+}

+ 57 - 29
clientv3/client.go

@@ -46,9 +46,9 @@ type Client struct {
 	Auth
 	Auth
 	Maintenance
 	Maintenance
 
 
-	conn   *grpc.ClientConn
-	cfg    Config
-	creds  *credentials.TransportAuthenticator
+	conn  *grpc.ClientConn
+	cfg   Config
+	creds *credentials.TransportAuthenticator
 
 
 	ctx    context.Context
 	ctx    context.Context
 	cancel context.CancelFunc
 	cancel context.CancelFunc
@@ -110,43 +110,70 @@ func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...str
 	}, nil
 	}, nil
 }
 }
 
 
-// Dial establishes a connection for a given endpoint using the client's config
+func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *credentials.TransportAuthenticator) {
+	proto = "tcp"
+	host = endpoint
+	creds = c.creds
+	url, uerr := url.Parse(endpoint)
+	if uerr != nil || !strings.Contains(endpoint, "://") {
+		return
+	}
+	// strip scheme:// prefix since grpc dials by host
+	host = url.Host
+	switch url.Scheme {
+	case "unix":
+		proto = "unix"
+	case "http":
+		creds = nil
+	case "https":
+		if creds != nil {
+			break
+		}
+		tlsconfig := &tls.Config{}
+		emptyCreds := credentials.NewTLS(tlsconfig)
+		creds = &emptyCreds
+	default:
+		return "", "", nil
+	}
+	return
+}
+
+// Dial connects to a single endpoint using the client's config.
 func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
 func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
+	return c.dial(endpoint)
+}
+
+func (c *Client) dial(endpoint string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
 	opts := []grpc.DialOption{
 	opts := []grpc.DialOption{
 		grpc.WithBlock(),
 		grpc.WithBlock(),
 		grpc.WithTimeout(c.cfg.DialTimeout),
 		grpc.WithTimeout(c.cfg.DialTimeout),
 	}
 	}
+	opts = append(opts, dopts...)
+
+	// grpc issues TLS cert checks using the string passed into dial so
+	// that string must be the host. To recover the full scheme://host URL,
+	// have a map from hosts to the original endpoint.
+	host2ep := make(map[string]string)
+	for i := range c.cfg.Endpoints {
+		_, host, _ := c.dialTarget(c.cfg.Endpoints[i])
+		host2ep[host] = c.cfg.Endpoints[i]
+	}
 
 
-	proto := "tcp"
-	creds := c.creds
-	if url, uerr := url.Parse(endpoint); uerr == nil && strings.Contains(endpoint, "://") {
-		switch url.Scheme {
-		case "unix":
-			proto = "unix"
-		case "http":
-			creds = nil
-		case "https":
-			if creds == nil {
-				tlsconfig := &tls.Config{InsecureSkipVerify: true}
-				emptyCreds := credentials.NewTLS(tlsconfig)
-				creds = &emptyCreds
-			}
-		default:
-			return nil, fmt.Errorf("unknown scheme %q for %q", url.Scheme, endpoint)
+	f := func(host string, t time.Duration) (net.Conn, error) {
+		proto, host, _ := c.dialTarget(host2ep[host])
+		if proto == "" {
+			return nil, fmt.Errorf("unknown scheme for %q", endpoint)
 		}
 		}
-		// strip scheme:// prefix since grpc dials by host
-		endpoint = url.Host
-	}
-	f := func(a string, t time.Duration) (net.Conn, error) {
 		select {
 		select {
 		case <-c.ctx.Done():
 		case <-c.ctx.Done():
 			return nil, c.ctx.Err()
 			return nil, c.ctx.Err()
 		default:
 		default:
 		}
 		}
-		return net.DialTimeout(proto, a, t)
+		return net.DialTimeout(proto, host, t)
 	}
 	}
 	opts = append(opts, grpc.WithDialer(f))
 	opts = append(opts, grpc.WithDialer(f))
 
 
+	_, host, creds := c.dialTarget(endpoint)
 	if creds != nil {
 	if creds != nil {
 		opts = append(opts, grpc.WithTransportCredentials(*creds))
 		opts = append(opts, grpc.WithTransportCredentials(*creds))
 	} else {
 	} else {
@@ -154,7 +181,7 @@ func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
 	}
 	}
 
 
 	if c.Username != "" && c.Password != "" {
 	if c.Username != "" && c.Password != "" {
-		auth, err := newAuthenticator(endpoint, opts)
+		auth, err := newAuthenticator(host, opts)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -168,7 +195,7 @@ func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
 		opts = append(opts, grpc.WithPerRPCCredentials(authTokenCredential{token: resp.Token}))
 		opts = append(opts, grpc.WithPerRPCCredentials(authTokenCredential{token: resp.Token}))
 	}
 	}
 
 
-	conn, err := grpc.Dial(endpoint, opts...)
+	conn, err := grpc.Dial(host, opts...)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -205,8 +232,9 @@ func newClient(cfg *Config) (*Client, error) {
 		client.Username = cfg.Username
 		client.Username = cfg.Username
 		client.Password = cfg.Password
 		client.Password = cfg.Password
 	}
 	}
-	// TODO: use grpc balancer
-	conn, err := client.Dial(cfg.Endpoints[0])
+
+	b := newSimpleBalancer(cfg.Endpoints)
+	conn, err := client.dial(cfg.Endpoints[0], grpc.WithBalancer(b))
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}