Browse Source

clientv3: add 'SetEndpoints' method

Gyu-Ho Lee 9 years ago
parent
commit
b9d18d4ac9
2 changed files with 61 additions and 14 deletions
  1. 35 0
      clientv3/balancer.go
  2. 26 14
      clientv3/client.go

+ 35 - 0
clientv3/balancer.go

@@ -42,6 +42,11 @@ type simpleBalancer struct {
 	// upc closes when upEps transitions from empty to non-zero or the balancer closes.
 	upc chan struct{}
 
+	// grpc issues TLS cert checks using the string passed into dial so
+	// that string must be the host. To recover the full scheme://host URL,
+	// have a map from hosts to the original endpoint.
+	host2ep map[string]string
+
 	// pinAddr is the currently pinned address; set to the empty string on
 	// intialization and shutdown.
 	pinAddr string
@@ -62,6 +67,7 @@ func newSimpleBalancer(eps []string) *simpleBalancer {
 		readyc:   make(chan struct{}),
 		upEps:    make(map[string]struct{}),
 		upc:      make(chan struct{}),
+		host2ep:  getHost2ep(eps),
 	}
 	return sb
 }
@@ -74,6 +80,35 @@ func (b *simpleBalancer) ConnectNotify() <-chan struct{} {
 	return b.upc
 }
 
+func (b *simpleBalancer) getEndpoint(host string) string {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	return b.host2ep[host]
+}
+
+func getHost2ep(eps []string) map[string]string {
+	hm := make(map[string]string, len(eps))
+	for i := range eps {
+		_, host, _ := parseEndpoint(eps[i])
+		hm[host] = eps[i]
+	}
+	return hm
+}
+
+func (b *simpleBalancer) updateAddrs(eps []string) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.host2ep = getHost2ep(eps)
+
+	addrs := make([]grpc.Address, 0, len(eps))
+	for i := range eps {
+		addrs = append(addrs, grpc.Address{Addr: getHost(eps[i])})
+	}
+	b.addrs = addrs
+	b.notifyCh <- addrs
+}
+
 func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
 	b.mu.Lock()
 	defer b.mu.Unlock()

+ 26 - 14
clientv3/client.go

@@ -99,6 +99,12 @@ func (c *Client) Ctx() context.Context { return c.ctx }
 // Endpoints lists the registered endpoints for the client.
 func (c *Client) Endpoints() []string { return c.cfg.Endpoints }
 
+// SetEndpoints updates client's endpoints.
+func (c *Client) SetEndpoints(eps ...string) {
+	c.cfg.Endpoints = eps
+	c.balancer.updateAddrs(eps)
+}
+
 type authTokenCredential struct {
 	token string
 }
@@ -113,19 +119,31 @@ func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...str
 	}, nil
 }
 
-func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *credentials.TransportCredentials) {
+func parseEndpoint(endpoint string) (proto string, host string, scheme bool) {
 	proto = "tcp"
 	host = endpoint
-	creds = c.creds
 	url, uerr := url.Parse(endpoint)
 	if uerr != nil || !strings.Contains(endpoint, "://") {
 		return
 	}
+	scheme = true
+
 	// strip scheme:// prefix since grpc dials by host
 	host = url.Host
 	switch url.Scheme {
+	case "http", "https":
 	case "unix":
 		proto = "unix"
+	default:
+		proto, host = "", ""
+	}
+	return
+}
+
+func (c *Client) processCreds(protocol string) (creds *credentials.TransportCredentials) {
+	creds = c.creds
+	switch protocol {
+	case "unix":
 	case "http":
 		creds = nil
 	case "https":
@@ -136,7 +154,7 @@ func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *
 		emptyCreds := credentials.NewTLS(tlsconfig)
 		creds = &emptyCreds
 	default:
-		return "", "", nil
+		creds = nil
 	}
 	return
 }
@@ -148,17 +166,8 @@ func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts
 	}
 	opts = append(opts, dopts...)
 
-	// grpc issues TLS cert checks using the string passed into dial so
-	// that string must be the host. To recover the full scheme://host URL,
-	// have a map from hosts to the original endpoint.
-	host2ep := make(map[string]string)
-	for i := range c.cfg.Endpoints {
-		_, host, _ := c.dialTarget(c.cfg.Endpoints[i])
-		host2ep[host] = c.cfg.Endpoints[i]
-	}
-
 	f := func(host string, t time.Duration) (net.Conn, error) {
-		proto, host, _ := c.dialTarget(host2ep[host])
+		proto, host, _ := parseEndpoint(c.balancer.getEndpoint(host))
 		if proto == "" {
 			return nil, fmt.Errorf("unknown scheme for %q", host)
 		}
@@ -171,7 +180,10 @@ func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts
 	}
 	opts = append(opts, grpc.WithDialer(f))
 
-	_, _, creds := c.dialTarget(endpoint)
+	creds := c.creds
+	if proto, _, scheme := parseEndpoint(endpoint); scheme {
+		creds = c.processCreds(proto)
+	}
 	if creds != nil {
 		opts = append(opts, grpc.WithTransportCredentials(*creds))
 	} else {