
clientv3: combine "healthBalancer" and "simpleBalancer"

Signed-off-by: Gyu-Ho Lee <gyuhox@gmail.com>
Gyu-Ho Lee, 8 years ago
commit 012b013538
4 changed files with 546 additions and 593 deletions
  1. clientv3/balancer.go  + 0 - 439
  2. clientv3/client.go  + 16 - 3
  3. clientv3/health_balancer.go  + 498 - 116
  4. clientv3/health_balancer_test.go  + 32 - 35

+ 0 - 439
clientv3/balancer.go

@@ -1,439 +0,0 @@
-// Copyright 2016 The etcd Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package clientv3
-
-import (
-	"context"
-	"net/url"
-	"strings"
-	"sync"
-
-	"google.golang.org/grpc"
-	"google.golang.org/grpc/codes"
-	"google.golang.org/grpc/status"
-)
-
-// ErrNoAddrAvilable is returned by Get() when the balancer does not have
-// any active connection to endpoints at the time.
-// This error is returned only when opts.BlockingWait is true.
-var ErrNoAddrAvilable = status.Error(codes.Unavailable, "there is no address available")
-
-type notifyMsg int
-
-const (
-	notifyReset notifyMsg = iota
-	notifyNext
-)
-
-// simpleBalancer does the bare minimum to expose multiple eps
-// to the grpc reconnection code path
-type simpleBalancer struct {
-	// addrs are the client's endpoint addresses for grpc
-	addrs []grpc.Address
-
-	// eps holds the raw endpoints from the client
-	eps []string
-
-	// notifyCh notifies grpc of the set of addresses for connecting
-	notifyCh chan []grpc.Address
-
-	// readyc closes once the first connection is up
-	readyc    chan struct{}
-	readyOnce sync.Once
-
-	// mu protects all fields below.
-	mu sync.RWMutex
-
-	// upc closes when pinAddr transitions from empty to non-empty or the balancer closes.
-	upc chan struct{}
-
-	// downc closes when grpc calls down() on pinAddr
-	downc chan struct{}
-
-	// stopc is closed to signal updateNotifyLoop should stop.
-	stopc chan struct{}
-
-	// donec closes when all goroutines are exited
-	donec chan struct{}
-
-	// updateAddrsC notifies updateNotifyLoop to update addrs.
-	updateAddrsC chan notifyMsg
-
-	// grpc issues TLS cert checks using the string passed into dial so
-	// that string must be the host. To recover the full scheme://host URL,
-	// have a map from hosts to the original endpoint.
-	hostPort2ep map[string]string
-
-	// pinAddr is the currently pinned address; set to the empty string on
-	// initialization and shutdown.
-	pinAddr string
-
-	closed bool
-}
-
-func newSimpleBalancer(eps []string) *simpleBalancer {
-	notifyCh := make(chan []grpc.Address)
-	addrs := eps2addrs(eps)
-	sb := &simpleBalancer{
-		addrs:        addrs,
-		eps:          eps,
-		notifyCh:     notifyCh,
-		readyc:       make(chan struct{}),
-		upc:          make(chan struct{}),
-		stopc:        make(chan struct{}),
-		downc:        make(chan struct{}),
-		donec:        make(chan struct{}),
-		updateAddrsC: make(chan notifyMsg),
-		hostPort2ep:  getHostPort2ep(eps),
-	}
-	close(sb.downc)
-	go sb.updateNotifyLoop()
-	return sb
-}
-
-func (b *simpleBalancer) Start(target string, config grpc.BalancerConfig) error { return nil }
-
-func (b *simpleBalancer) ConnectNotify() <-chan struct{} {
-	b.mu.Lock()
-	defer b.mu.Unlock()
-	return b.upc
-}
-
-func (b *simpleBalancer) ready() <-chan struct{} { return b.readyc }
-
-func (b *simpleBalancer) endpoint(hostPort string) string {
-	b.mu.Lock()
-	defer b.mu.Unlock()
-	return b.hostPort2ep[hostPort]
-}
-
-func (b *simpleBalancer) endpoints() []string {
-	b.mu.RLock()
-	defer b.mu.RUnlock()
-	return b.eps
-}
-
-func (b *simpleBalancer) pinned() string {
-	b.mu.RLock()
-	defer b.mu.RUnlock()
-	return b.pinAddr
-}
-
-func getHostPort2ep(eps []string) map[string]string {
-	hm := make(map[string]string, len(eps))
-	for i := range eps {
-		_, host, _ := parseEndpoint(eps[i])
-		hm[host] = eps[i]
-	}
-	return hm
-}
-
-func (b *simpleBalancer) updateAddrs(eps ...string) {
-	np := getHostPort2ep(eps)
-
-	b.mu.Lock()
-
-	match := len(np) == len(b.hostPort2ep)
-	for k, v := range np {
-		if b.hostPort2ep[k] != v {
-			match = false
-			break
-		}
-	}
-	if match {
-		// same endpoints, so no need to update address
-		b.mu.Unlock()
-		return
-	}
-
-	b.hostPort2ep = np
-	b.addrs, b.eps = eps2addrs(eps), eps
-
-	// updating notifyCh can trigger new connections,
-	// only update addrs if all connections are down
-	// or addrs does not include pinAddr.
-	update := !hasAddr(b.addrs, b.pinAddr)
-	b.mu.Unlock()
-
-	if update {
-		select {
-		case b.updateAddrsC <- notifyNext:
-		case <-b.stopc:
-		}
-	}
-}
-
-func (b *simpleBalancer) next() {
-	b.mu.RLock()
-	downc := b.downc
-	b.mu.RUnlock()
-	select {
-	case b.updateAddrsC <- notifyNext:
-	case <-b.stopc:
-	}
-	// wait until disconnect so new RPCs are not issued on old connection
-	select {
-	case <-downc:
-	case <-b.stopc:
-	}
-}
-
-func hasAddr(addrs []grpc.Address, targetAddr string) bool {
-	for _, addr := range addrs {
-		if targetAddr == addr.Addr {
-			return true
-		}
-	}
-	return false
-}
-
-func (b *simpleBalancer) updateNotifyLoop() {
-	defer close(b.donec)
-
-	for {
-		b.mu.RLock()
-		upc, downc, addr := b.upc, b.downc, b.pinAddr
-		b.mu.RUnlock()
-		// downc or upc should be closed
-		select {
-		case <-downc:
-			downc = nil
-		default:
-		}
-		select {
-		case <-upc:
-			upc = nil
-		default:
-		}
-		switch {
-		case downc == nil && upc == nil:
-			// stale
-			select {
-			case <-b.stopc:
-				return
-			default:
-			}
-		case downc == nil:
-			b.notifyAddrs(notifyReset)
-			select {
-			case <-upc:
-			case msg := <-b.updateAddrsC:
-				b.notifyAddrs(msg)
-			case <-b.stopc:
-				return
-			}
-		case upc == nil:
-			select {
-			// close connections that are not the pinned address
-			case b.notifyCh <- []grpc.Address{{Addr: addr}}:
-			case <-downc:
-			case <-b.stopc:
-				return
-			}
-			select {
-			case <-downc:
-				b.notifyAddrs(notifyReset)
-			case msg := <-b.updateAddrsC:
-				b.notifyAddrs(msg)
-			case <-b.stopc:
-				return
-			}
-		}
-	}
-}
-
-func (b *simpleBalancer) notifyAddrs(msg notifyMsg) {
-	if msg == notifyNext {
-		select {
-		case b.notifyCh <- []grpc.Address{}:
-		case <-b.stopc:
-			return
-		}
-	}
-	b.mu.RLock()
-	addrs := b.addrs
-	pinAddr := b.pinAddr
-	downc := b.downc
-	b.mu.RUnlock()
-
-	var waitDown bool
-	if pinAddr != "" {
-		waitDown = true
-		for _, a := range addrs {
-			if a.Addr == pinAddr {
-				waitDown = false
-			}
-		}
-	}
-
-	select {
-	case b.notifyCh <- addrs:
-		if waitDown {
-			select {
-			case <-downc:
-			case <-b.stopc:
-			}
-		}
-	case <-b.stopc:
-	}
-}
-
-func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
-	f, _ := b.up(addr)
-	return f
-}
-
-func (b *simpleBalancer) up(addr grpc.Address) (func(error), bool) {
-	b.mu.Lock()
-	defer b.mu.Unlock()
-
-	// gRPC might call Up after it called Close. We add this check
-	// to "fix" it up at application layer. Otherwise, will panic
-	// if b.upc is already closed.
-	if b.closed {
-		return func(err error) {}, false
-	}
-	// gRPC might call Up on a stale address.
-	// Prevent updating pinAddr with a stale address.
-	if !hasAddr(b.addrs, addr.Addr) {
-		return func(err error) {}, false
-	}
-	if b.pinAddr != "" {
-		if logger.V(4) {
-			logger.Infof("clientv3/balancer: %q is up but not pinned (already pinned %q)", addr.Addr, b.pinAddr)
-		}
-		return func(err error) {}, false
-	}
-	// notify waiting Get()s and pin first connected address
-	close(b.upc)
-	b.downc = make(chan struct{})
-	b.pinAddr = addr.Addr
-	if logger.V(4) {
-		logger.Infof("clientv3/balancer: pin %q", addr.Addr)
-	}
-	// notify client that a connection is up
-	b.readyOnce.Do(func() { close(b.readyc) })
-	return func(err error) {
-		b.mu.Lock()
-		b.upc = make(chan struct{})
-		close(b.downc)
-		b.pinAddr = ""
-		b.mu.Unlock()
-		if logger.V(4) {
-			logger.Infof("clientv3/balancer: unpin %q (%q)", addr.Addr, err.Error())
-		}
-	}, true
-}
-
-func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
-	var (
-		addr   string
-		closed bool
-	)
-
-	// If opts.BlockingWait is false (for fail-fast RPCs), it should return
-	// an address it has notified via Notify immediately instead of blocking.
-	if !opts.BlockingWait {
-		b.mu.RLock()
-		closed = b.closed
-		addr = b.pinAddr
-		b.mu.RUnlock()
-		if closed {
-			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
-		}
-		if addr == "" {
-			return grpc.Address{Addr: ""}, nil, ErrNoAddrAvilable
-		}
-		return grpc.Address{Addr: addr}, func() {}, nil
-	}
-
-	for {
-		b.mu.RLock()
-		ch := b.upc
-		b.mu.RUnlock()
-		select {
-		case <-ch:
-		case <-b.donec:
-			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
-		case <-ctx.Done():
-			return grpc.Address{Addr: ""}, nil, ctx.Err()
-		}
-		b.mu.RLock()
-		closed = b.closed
-		addr = b.pinAddr
-		b.mu.RUnlock()
-		// Close() which sets b.closed = true can be called before Get(), Get() must exit if balancer is closed.
-		if closed {
-			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
-		}
-		if addr != "" {
-			break
-		}
-	}
-	return grpc.Address{Addr: addr}, func() {}, nil
-}
-
-func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
-
-func (b *simpleBalancer) Close() error {
-	b.mu.Lock()
-	// In case gRPC calls close twice. TODO: remove the checking
-	// when we are sure that gRPC wont call close twice.
-	if b.closed {
-		b.mu.Unlock()
-		<-b.donec
-		return nil
-	}
-	b.closed = true
-	close(b.stopc)
-	b.pinAddr = ""
-
-	// In the case of following scenario:
-	//	1. upc is not closed; no pinned address
-	// 	2. client issues an RPC, calling invoke(), which calls Get(), enters for loop, blocks
-	// 	3. client.conn.Close() calls balancer.Close(); closed = true
-	// 	4. for loop in Get() never exits since ctx is the context passed in by the client and may not be canceled
-	// we must close upc so Get() exits from blocking on upc
-	select {
-	case <-b.upc:
-	default:
-		// terminate all waiting Get()s
-		close(b.upc)
-	}
-
-	b.mu.Unlock()
-
-	// wait for updateNotifyLoop to finish
-	<-b.donec
-	close(b.notifyCh)
-
-	return nil
-}
-
-func getHost(ep string) string {
-	url, uerr := url.Parse(ep)
-	if uerr != nil || !strings.Contains(ep, "://") {
-		return ep
-	}
-	return url.Host
-}
-
-func eps2addrs(eps []string) []grpc.Address {
-	addrs := make([]grpc.Address, len(eps))
-	for i := range eps {
-		addrs[i].Addr = getHost(eps[i])
-	}
-	return addrs
-}
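The deleted file's trailing helpers, getHost and eps2addrs, strip the URL scheme so gRPC dials (and TLS-verifies) the bare host:port; they reappear unchanged at the bottom of the new health_balancer.go. A standalone sketch of that mapping, with an illustrative main() that is not part of the commit:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// getHost mirrors the helper above: return the host:port of a scheme-prefixed
// endpoint, or the endpoint unchanged when it has no scheme.
func getHost(ep string) string {
	u, err := url.Parse(ep)
	if err != nil || !strings.Contains(ep, "://") {
		return ep
	}
	return u.Host
}

func main() {
	for _, ep := range []string{"https://10.0.1.5:2379", "localhost:2379"} {
		fmt.Printf("%s -> %s\n", ep, getHost(ep))
	}
	// Output:
	// https://10.0.1.5:2379 -> 10.0.1.5:2379
	// localhost:2379 -> localhost:2379
}
```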

+ 16 - 3
clientv3/client.go

@@ -121,6 +121,19 @@ func (c *Client) SetEndpoints(eps ...string) {
 	c.cfg.Endpoints = eps
 	c.mu.Unlock()
 	c.balancer.updateAddrs(eps...)
+
+	// updating notifyCh can trigger new connections,
+	// need update addrs if all connections are down
+	// or addrs does not include pinAddr.
+	c.balancer.mu.RLock()
+	update := !hasAddr(c.balancer.addrs, c.balancer.pinAddr)
+	c.balancer.mu.RUnlock()
+	if update {
+		select {
+		case c.balancer.updateAddrsC <- notifyNext:
+		case <-c.balancer.stopc:
+		}
+	}
 }
 
 // Sync synchronizes client's endpoints with the known endpoints from the etcd membership.
@@ -378,9 +391,9 @@ func newClient(cfg *Config) (*Client, error) {
 		client.Password = cfg.Password
 	}
 
-	sb := newSimpleBalancer(cfg.Endpoints)
-	hc := func(ep string) (bool, error) { return grpcHealthCheck(client, ep) }
-	client.balancer = newHealthBalancer(sb, cfg.DialTimeout, hc)
+	client.balancer = newHealthBalancer(cfg.Endpoints, cfg.DialTimeout, func(ep string) (bool, error) {
+		return grpcHealthCheck(client, ep)
+	})
 
 	// use Endpoints[0] so that for https:// without any tls config given, then
 	// grpc will assume the certificate server name is the endpoint host.
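The lines added to SetEndpoints only nudge the balancer's notify loop when the currently pinned address is missing from the new address set, so a healthy pinned connection is left alone. A minimal, self-contained sketch of that check; the balancer struct here and the buffered channel (used so the sketch runs without a notify loop) are simplified stand-ins, not the clientv3 types:

```go
package main

import "fmt"

type notifyMsg int

const notifyNext notifyMsg = 1

type balancer struct {
	addrs        []string
	pinAddr      string
	updateAddrsC chan notifyMsg
	stopc        chan struct{}
}

func hasAddr(addrs []string, target string) bool {
	for _, a := range addrs {
		if a == target {
			return true
		}
	}
	return false
}

// setEndpoints mirrors the diff above: store the new addresses, then ask the
// notify loop to pick a new endpoint only if the pinned one was removed.
func (b *balancer) setEndpoints(eps ...string) {
	b.addrs = eps
	if !hasAddr(b.addrs, b.pinAddr) {
		select {
		case b.updateAddrsC <- notifyNext:
		case <-b.stopc:
		}
	}
}

func main() {
	b := &balancer{
		pinAddr:      "localhost:2379",
		updateAddrsC: make(chan notifyMsg, 1),
		stopc:        make(chan struct{}),
	}
	b.setEndpoints("localhost:2379", "localhost:22379") // pinned address kept: no notification
	b.setEndpoints("localhost:32379")                   // pinned address dropped: notification queued
	fmt.Println("queued notifications:", len(b.updateAddrsC)) // 1
}
```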

+ 498 - 116
clientv3/health_balancer.go

@@ -16,6 +16,9 @@ package clientv3
 
 import (
 	"context"
+	"errors"
+	"net/url"
+	"strings"
 	"sync"
 	"time"
 
@@ -25,207 +28,553 @@ import (
 	"google.golang.org/grpc/status"
 )
 
-const minHealthRetryDuration = 3 * time.Second
-const unknownService = "unknown service grpc.health.v1.Health"
+const (
+	minHealthRetryDuration = 3 * time.Second
+	unknownService         = "unknown service grpc.health.v1.Health"
+)
+
+// ErrNoAddrAvilable is returned by Get() when the balancer does not have
+// any active connection to endpoints at the time.
+// This error is returned only when opts.BlockingWait is true.
+var ErrNoAddrAvilable = status.Error(codes.Unavailable, "there is no address available")
 
 type healthCheckFunc func(ep string) (bool, error)
 
-// healthBalancer wraps a balancer so that it uses health checking
-// to choose its endpoints.
+type notifyMsg int
+
+const (
+	notifyReset notifyMsg = iota
+	notifyNext
+)
+
+// healthBalancer does the bare minimum to expose multiple eps
+// to the grpc reconnection code path
 type healthBalancer struct {
-	*simpleBalancer
+	// addrs are the client's endpoint addresses for grpc
+	addrs []grpc.Address
+
+	// eps holds the raw endpoints from the client
+	eps []string
+
+	// notifyCh notifies grpc of the set of addresses for connecting
+	notifyCh chan []grpc.Address
+
+	// readyc closes once the first connection is up
+	readyc    chan struct{}
+	readyOnce sync.Once
 
 	// healthCheck checks an endpoint's health.
 	healthCheck        healthCheckFunc
 	healthCheckTimeout time.Duration
 
-	// mu protects addrs, eps, unhealthy map, and stopc.
+	unhealthyMu        sync.RWMutex
+	unhealthyHostPorts map[string]time.Time
+
+	// mu protects all fields below.
 	mu sync.RWMutex
 
-	// addrs stores all grpc addresses associated with the balancer.
-	addrs []grpc.Address
+	// upc closes when pinAddr transitions from empty to non-empty or the balancer closes.
+	upc chan struct{}
 
-	// eps stores all client endpoints
-	eps []string
-
-	// unhealthy tracks the last unhealthy time of endpoints.
-	unhealthy map[string]time.Time
+	// downc closes when grpc calls down() on pinAddr
+	downc chan struct{}
 
+	// stopc is closed to signal updateNotifyLoop should stop.
 	stopc    chan struct{}
 	stopOnce sync.Once
+	wg       sync.WaitGroup
+
+	// donec closes when all goroutines are exited
+	donec chan struct{}
 
+	// updateAddrsC notifies updateNotifyLoop to update addrs.
+	updateAddrsC chan notifyMsg
+
+	// grpc issues TLS cert checks using the string passed into dial so
+	// that string must be the host. To recover the full scheme://host URL,
+	// have a map from hosts to the original endpoint.
 	hostPort2ep map[string]string
 
-	wg sync.WaitGroup
+	// pinAddr is the currently pinned address; set to the empty string on
+	// initialization and shutdown.
+	pinAddr string
+
+	closed bool
 }
 
-func newHealthBalancer(b *simpleBalancer, timeout time.Duration, hc healthCheckFunc) *healthBalancer {
+func newHealthBalancer(eps []string, timeout time.Duration, hc healthCheckFunc) *healthBalancer {
+	notifyCh := make(chan []grpc.Address)
+	addrs := eps2addrs(eps)
 	hb := &healthBalancer{
-		simpleBalancer: b,
-		healthCheck:    hc,
-		eps:            b.endpoints(),
-		addrs:          eps2addrs(b.endpoints()),
-		hostPort2ep:    getHostPort2ep(b.endpoints()),
-		unhealthy:      make(map[string]time.Time),
-		stopc:          make(chan struct{}),
+		addrs:              addrs,
+		eps:                eps,
+		notifyCh:           notifyCh,
+		readyc:             make(chan struct{}),
+		healthCheck:        hc,
+		unhealthyHostPorts: make(map[string]time.Time),
+		upc:                make(chan struct{}),
+		stopc:              make(chan struct{}),
+		downc:              make(chan struct{}),
+		donec:              make(chan struct{}),
+		updateAddrsC:       make(chan notifyMsg),
+		hostPort2ep:        getHostPort2ep(eps),
 	}
 	if timeout < minHealthRetryDuration {
 		timeout = minHealthRetryDuration
 	}
 	hb.healthCheckTimeout = timeout
 
+	close(hb.downc)
+	go hb.updateNotifyLoop()
 	hb.wg.Add(1)
 	go func() {
 		defer hb.wg.Done()
-		hb.updateUnhealthy(timeout)
+		hb.updateUnhealthy()
 	}()
-
 	return hb
 }
 
-func (hb *healthBalancer) Up(addr grpc.Address) func(error) {
-	f, used := hb.up(addr)
-	if !used {
-		return f
+func (b *healthBalancer) Start(target string, config grpc.BalancerConfig) error { return nil }
+
+func (b *healthBalancer) ConnectNotify() <-chan struct{} {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	return b.upc
+}
+
+func (b *healthBalancer) ready() <-chan struct{} { return b.readyc }
+
+func (b *healthBalancer) endpoint(hostPort string) string {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+	return b.hostPort2ep[hostPort]
+}
+
+func (b *healthBalancer) pinned() string {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+	return b.pinAddr
+}
+
+func (b *healthBalancer) hostPortError(hostPort string, err error) {
+	if b.endpoint(hostPort) == "" {
+		if logger.V(4) {
+			logger.Infof("clientv3/balancer: %q is stale (skip marking as unhealthy on %q)", hostPort, err.Error())
+		}
+		return
 	}
-	return func(err error) {
-		// If connected to a black hole endpoint or a killed server, the gRPC ping
-		// timeout will induce a network I/O error, and retrying until success;
-		// finding healthy endpoint on retry could take several timeouts and redials.
-		// To avoid wasting retries, gray-list unhealthy endpoints.
-		hb.hostPortError(addr.Addr, err)
-		f(err)
+
+	b.unhealthyMu.Lock()
+	b.unhealthyHostPorts[hostPort] = time.Now()
+	b.unhealthyMu.Unlock()
+	if logger.V(4) {
+		logger.Infof("clientv3/balancer: %q is marked unhealthy (%q)", hostPort, err.Error())
 	}
 }
 
-func (hb *healthBalancer) up(addr grpc.Address) (func(error), bool) {
-	if !hb.mayPin(addr) {
-		return func(err error) {}, false
+func (b *healthBalancer) removeUnhealthy(hostPort, msg string) {
+	if b.endpoint(hostPort) == "" {
+		if logger.V(4) {
+			logger.Infof("clientv3/balancer: %q was not in unhealthy (%q)", hostPort, msg)
+		}
+		return
+	}
+
+	b.unhealthyMu.Lock()
+	delete(b.unhealthyHostPorts, hostPort)
+	b.unhealthyMu.Unlock()
+	if logger.V(4) {
+		logger.Infof("clientv3/balancer: %q is removed from unhealthy (%q)", hostPort, msg)
 	}
-	return hb.simpleBalancer.up(addr)
 }
 
-func (hb *healthBalancer) Close() error {
-	hb.stopOnce.Do(func() { close(hb.stopc) })
-	hb.wg.Wait()
-	return hb.simpleBalancer.Close()
+func (b *healthBalancer) countUnhealthy() (count int) {
+	b.unhealthyMu.RLock()
+	count = len(b.unhealthyHostPorts)
+	b.unhealthyMu.RUnlock()
+	return count
 }
 
-func (hb *healthBalancer) updateAddrs(eps ...string) {
-	addrs, hostPort2ep := eps2addrs(eps), getHostPort2ep(eps)
-	hb.mu.Lock()
-	hb.addrs, hb.eps, hb.hostPort2ep = addrs, eps, hostPort2ep
-	hb.unhealthy = make(map[string]time.Time)
-	hb.mu.Unlock()
-	hb.simpleBalancer.updateAddrs(eps...)
+func (b *healthBalancer) isUnhealthy(hostPort string) (unhealthy bool) {
+	b.unhealthyMu.RLock()
+	_, unhealthy = b.unhealthyHostPorts[hostPort]
+	b.unhealthyMu.RUnlock()
+	return unhealthy
 }
 
-func (hb *healthBalancer) endpoint(host string) string {
-	hb.mu.RLock()
-	defer hb.mu.RUnlock()
-	return hb.hostPort2ep[host]
+func (b *healthBalancer) cleanupUnhealthy() {
+	b.unhealthyMu.Lock()
+	for k, v := range b.unhealthyHostPorts {
+		if time.Since(v) > b.healthCheckTimeout {
+			delete(b.unhealthyHostPorts, k)
+			if logger.V(4) {
+				logger.Infof("clientv3/balancer: removed %q from unhealthy after %v", k, b.healthCheckTimeout)
+			}
+		}
+	}
+	b.unhealthyMu.Unlock()
 }
 
-func (hb *healthBalancer) endpoints() []string {
-	hb.mu.RLock()
-	defer hb.mu.RUnlock()
-	return hb.eps
+func (b *healthBalancer) liveAddrs() ([]grpc.Address, map[string]struct{}) {
+	unhealthyCnt := b.countUnhealthy()
+
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+
+	hbAddrs := b.addrs
+	if len(b.addrs) == 1 || unhealthyCnt == 0 || unhealthyCnt == len(b.addrs) {
+		liveHostPorts := make(map[string]struct{}, len(b.hostPort2ep))
+		for k := range b.hostPort2ep {
+			liveHostPorts[k] = struct{}{}
+		}
+		return hbAddrs, liveHostPorts
+	}
+
+	addrs := make([]grpc.Address, 0, len(b.addrs)-unhealthyCnt)
+	liveHostPorts := make(map[string]struct{}, len(addrs))
+	for _, addr := range b.addrs {
+		if !b.isUnhealthy(addr.Addr) {
+			addrs = append(addrs, addr)
+			liveHostPorts[addr.Addr] = struct{}{}
+		}
+	}
+	return addrs, liveHostPorts
 }
 
-func (hb *healthBalancer) updateUnhealthy(timeout time.Duration) {
+func (b *healthBalancer) updateUnhealthy() {
 	for {
 		select {
-		case <-time.After(timeout):
-			hb.mu.Lock()
-			for k, v := range hb.unhealthy {
-				if time.Since(v) > timeout {
-					delete(hb.unhealthy, k)
-					if logger.V(4) {
-						logger.Infof("clientv3/health-balancer: removes %q from unhealthy after %v", k, timeout)
-					}
+		case <-time.After(b.healthCheckTimeout):
+			b.cleanupUnhealthy()
+			pinned := b.pinned()
+			if pinned == "" || b.isUnhealthy(pinned) {
+				select {
+				case b.updateAddrsC <- notifyNext:
+				case <-b.stopc:
+					return
 				}
 			}
-			hb.mu.Unlock()
-			eps := []string{}
-			for _, addr := range hb.liveAddrs() {
-				eps = append(eps, hb.endpoint(addr.Addr))
-			}
-			hb.simpleBalancer.updateAddrs(eps...)
-		case <-hb.stopc:
+		case <-b.stopc:
 			return
 		}
 	}
 }
 
-func (hb *healthBalancer) liveAddrs() []grpc.Address {
-	hb.mu.RLock()
-	defer hb.mu.RUnlock()
-	hbAddrs := hb.addrs
-	if len(hb.addrs) == 1 || len(hb.unhealthy) == 0 || len(hb.unhealthy) == len(hb.addrs) {
-		return hbAddrs
+func (b *healthBalancer) updateAddrs(eps ...string) {
+	np := getHostPort2ep(eps)
+
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	match := len(np) == len(b.hostPort2ep)
+	if match {
+		for k, v := range np {
+			if b.hostPort2ep[k] != v {
+				match = false
+				break
+			}
+		}
 	}
-	addrs := make([]grpc.Address, 0, len(hb.addrs)-len(hb.unhealthy))
-	for _, addr := range hb.addrs {
-		if _, unhealthy := hb.unhealthy[addr.Addr]; !unhealthy {
-			addrs = append(addrs, addr)
+	if match {
+		// same endpoints, so no need to update address
+		return
+	}
+
+	b.hostPort2ep = np
+	b.addrs, b.eps = eps2addrs(eps), eps
+
+	b.unhealthyMu.Lock()
+	b.unhealthyHostPorts = make(map[string]time.Time)
+	b.unhealthyMu.Unlock()
+}
+
+func (b *healthBalancer) next() {
+	b.mu.RLock()
+	downc := b.downc
+	b.mu.RUnlock()
+	select {
+	case b.updateAddrsC <- notifyNext:
+	case <-b.stopc:
+	}
+	// wait until disconnect so new RPCs are not issued on old connection
+	select {
+	case <-downc:
+	case <-b.stopc:
+	}
+}
+
+func (b *healthBalancer) updateNotifyLoop() {
+	defer close(b.donec)
+
+	for {
+		b.mu.RLock()
+		upc, downc, addr := b.upc, b.downc, b.pinAddr
+		b.mu.RUnlock()
+		// downc or upc should be closed
+		select {
+		case <-downc:
+			downc = nil
+		default:
+		}
+		select {
+		case <-upc:
+			upc = nil
+		default:
+		}
+		switch {
+		case downc == nil && upc == nil:
+			// stale
+			select {
+			case <-b.stopc:
+				return
+			default:
+			}
+		case downc == nil:
+			b.notifyAddrs(notifyReset)
+			select {
+			case <-upc:
+			case msg := <-b.updateAddrsC:
+				b.notifyAddrs(msg)
+			case <-b.stopc:
+				return
+			}
+		case upc == nil:
+			select {
+			// close connections that are not the pinned address
+			case b.notifyCh <- []grpc.Address{{Addr: addr}}:
+			case <-downc:
+			case <-b.stopc:
+				return
+			}
+			select {
+			case <-downc:
+				b.notifyAddrs(notifyReset)
+			case msg := <-b.updateAddrsC:
+				b.notifyAddrs(msg)
+			case <-b.stopc:
+				return
+			}
 		}
 	}
-	return addrs
 }
 
-func (hb *healthBalancer) hostPortError(hostPort string, err error) {
-	hb.mu.Lock()
-	if _, ok := hb.hostPort2ep[hostPort]; ok {
-		hb.unhealthy[hostPort] = time.Now()
+func (b *healthBalancer) notifyAddrs(msg notifyMsg) {
+	if msg == notifyNext {
+		select {
+		case b.notifyCh <- []grpc.Address{}:
+		case <-b.stopc:
+			return
+		}
+	}
+	b.mu.RLock()
+	addrs := b.addrs
+	pinAddr := b.pinAddr
+	downc := b.downc
+	b.mu.RUnlock()
+
+	var waitDown bool
+	if pinAddr != "" {
+		waitDown = true
+		for _, a := range addrs {
+			if a.Addr == pinAddr {
+				waitDown = false
+			}
+		}
+	}
+
+	select {
+	case b.notifyCh <- addrs:
+		if waitDown {
+			select {
+			case <-downc:
+			case <-b.stopc:
+			}
+		}
+	case <-b.stopc:
+	}
+}
+
+func (b *healthBalancer) Up(addr grpc.Address) func(error) {
+	if !b.mayPin(addr) {
+		return func(err error) {}
+	}
+
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	// gRPC might call Up after it called Close. We add this check
+	// to "fix" it up at application layer. Otherwise, will panic
+	// if b.upc is already closed.
+	if b.closed {
+		return func(err error) {}
+	}
+
+	// gRPC might call Up on a stale address.
+	// Prevent updating pinAddr with a stale address.
+	if !hasAddr(b.addrs, addr.Addr) {
+		return func(err error) {}
+	}
+
+	if b.pinAddr != "" {
+		if logger.V(4) {
+			logger.Infof("clientv3/balancer: %q is up but not pinned (already pinned %q)", addr.Addr, b.pinAddr)
+		}
+		return func(err error) {}
+	}
+
+	// notify waiting Get()s and pin first connected address
+	close(b.upc)
+	b.downc = make(chan struct{})
+	b.pinAddr = addr.Addr
+	if logger.V(4) {
+		logger.Infof("clientv3/balancer: pin %q", addr.Addr)
+	}
+
+	// notify client that a connection is up
+	b.readyOnce.Do(func() { close(b.readyc) })
+
+	return func(err error) {
+		// If connected to a black hole endpoint or a killed server, the gRPC ping
+		// timeout will induce a network I/O error, and retrying until success;
+		// finding healthy endpoint on retry could take several timeouts and redials.
+		// To avoid wasting retries, gray-list unhealthy endpoints.
+		b.hostPortError(addr.Addr, err)
+
+		b.mu.Lock()
+		b.upc = make(chan struct{})
+		close(b.downc)
+		b.pinAddr = ""
+		b.mu.Unlock()
 		if logger.V(4) {
-			logger.Infof("clientv3/health-balancer: marking %q as unhealthy (%q)", hostPort, err.Error())
+			logger.Infof("clientv3/balancer: unpin %q (%q)", addr.Addr, err.Error())
 		}
 	}
-	hb.mu.Unlock()
 }
 
-func (hb *healthBalancer) mayPin(addr grpc.Address) bool {
-	hb.mu.RLock()
-	if _, ok := hb.hostPort2ep[addr.Addr]; !ok { // stale host:port
-		hb.mu.RUnlock()
+func (b *healthBalancer) mayPin(addr grpc.Address) bool {
+	if b.endpoint(addr.Addr) == "" { // stale host:port
 		return false
 	}
-	skip := len(hb.addrs) == 1 || len(hb.unhealthy) == 0 || len(hb.addrs) == len(hb.unhealthy)
-	failedTime, bad := hb.unhealthy[addr.Addr]
-	dur := hb.healthCheckTimeout
-	hb.mu.RUnlock()
+
+	b.unhealthyMu.RLock()
+	unhealthyCnt := len(b.unhealthyHostPorts)
+	failedTime, bad := b.unhealthyHostPorts[addr.Addr]
+	b.unhealthyMu.RUnlock()
+
+	b.mu.RLock()
+	skip := len(b.addrs) == 1 || unhealthyCnt == 0 || len(b.addrs) == unhealthyCnt
+	b.mu.RUnlock()
 	if skip || !bad {
 		return true
 	}
+
 	// prevent isolated member's endpoint from being infinitely retried, as follows:
 	//   1. keepalive pings detects GoAway with http2.ErrCodeEnhanceYourCalm
 	//   2. balancer 'Up' unpins with grpc: failed with network I/O error
 	//   3. grpc-healthcheck still SERVING, thus retry to pin
 	// instead, return before grpc-healthcheck if failed within healthcheck timeout
-	if elapsed := time.Since(failedTime); elapsed < dur {
+	if elapsed := time.Since(failedTime); elapsed < b.healthCheckTimeout {
 		if logger.V(4) {
-			logger.Infof("clientv3/health-balancer: %q is up but not pinned (failed %v ago, require minimum %v after failure)", addr.Addr, elapsed, dur)
+			logger.Infof("clientv3/balancer: %q is up but not pinned (failed %v ago, require minimum %v after failure)", addr.Addr, elapsed, b.healthCheckTimeout)
 		}
 		return false
 	}
-	if ok, _ := hb.healthCheck(addr.Addr); ok {
-		hb.mu.Lock()
-		delete(hb.unhealthy, addr.Addr)
-		hb.mu.Unlock()
-		if logger.V(4) {
-			logger.Infof("clientv3/health-balancer: %q is healthy (health check success)", addr.Addr)
-		}
+
+	if ok, _ := b.healthCheck(addr.Addr); ok {
+		b.removeUnhealthy(addr.Addr, "health check success")
 		return true
 	}
-	hb.mu.Lock()
-	hb.unhealthy[addr.Addr] = time.Now()
-	hb.mu.Unlock()
-	if logger.V(4) {
-		logger.Infof("clientv3/health-balancer: %q becomes unhealthy (health check failed)", addr.Addr)
-	}
+
+	b.hostPortError(addr.Addr, errors.New("health check failed"))
 	return false
 }
 
+func (b *healthBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
+	var (
+		addr   string
+		closed bool
+	)
+
+	// If opts.BlockingWait is false (for fail-fast RPCs), it should return
+	// an address it has notified via Notify immediately instead of blocking.
+	if !opts.BlockingWait {
+		b.mu.RLock()
+		closed = b.closed
+		addr = b.pinAddr
+		b.mu.RUnlock()
+		if closed {
+			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
+		}
+		if addr == "" {
+			return grpc.Address{Addr: ""}, nil, ErrNoAddrAvilable
+		}
+		return grpc.Address{Addr: addr}, func() {}, nil
+	}
+
+	for {
+		b.mu.RLock()
+		ch := b.upc
+		b.mu.RUnlock()
+		select {
+		case <-ch:
+		case <-b.donec:
+			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
+		case <-ctx.Done():
+			return grpc.Address{Addr: ""}, nil, ctx.Err()
+		}
+		b.mu.RLock()
+		closed = b.closed
+		addr = b.pinAddr
+		b.mu.RUnlock()
+		// Close() which sets b.closed = true can be called before Get(), Get() must exit if balancer is closed.
+		if closed {
+			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
+		}
+		if addr != "" {
+			break
+		}
+	}
+	return grpc.Address{Addr: addr}, func() {}, nil
+}
+
+func (b *healthBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
+
+func (b *healthBalancer) Close() error {
+	b.mu.Lock()
+	// In case gRPC calls close twice. TODO: remove the checking
+	// when we are sure that gRPC wont call close twice.
+	if b.closed {
+		b.mu.Unlock()
+		<-b.donec
+		return nil
+	}
+	b.closed = true
+	b.stopOnce.Do(func() { close(b.stopc) })
+	b.pinAddr = ""
+
+	// In the case of following scenario:
+	//	1. upc is not closed; no pinned address
+	// 	2. client issues an RPC, calling invoke(), which calls Get(), enters for loop, blocks
+	// 	3. client.conn.Close() calls balancer.Close(); closed = true
+	// 	4. for loop in Get() never exits since ctx is the context passed in by the client and may not be canceled
+	// we must close upc so Get() exits from blocking on upc
+	select {
+	case <-b.upc:
+	default:
+		// terminate all waiting Get()s
+		close(b.upc)
+	}
+
+	b.mu.Unlock()
+	b.wg.Wait()
+
+	// wait for updateNotifyLoop to finish
+	<-b.donec
+	close(b.notifyCh)
+
+	return nil
+}
+
 func grpcHealthCheck(client *Client, ep string) (bool, error) {
 	conn, err := client.dial(ep)
 	if err != nil {
@@ -238,8 +587,7 @@ func grpcHealthCheck(client *Client, ep string) (bool, error) {
 	cancel()
 	if err != nil {
 		if s, ok := status.FromError(err); ok && s.Code() == codes.Unavailable {
-			if s.Message() == unknownService {
-				// etcd < v3.3.0
+			if s.Message() == unknownService { // etcd < v3.3.0
 				return true, nil
 			}
 		}
@@ -247,3 +595,37 @@ func grpcHealthCheck(client *Client, ep string) (bool, error) {
 	}
 	return resp.Status == healthpb.HealthCheckResponse_SERVING, nil
 }
+
+func hasAddr(addrs []grpc.Address, targetAddr string) bool {
+	for _, addr := range addrs {
+		if targetAddr == addr.Addr {
+			return true
+		}
+	}
+	return false
+}
+
+func getHost(ep string) string {
+	url, uerr := url.Parse(ep)
+	if uerr != nil || !strings.Contains(ep, "://") {
+		return ep
+	}
+	return url.Host
+}
+
+func eps2addrs(eps []string) []grpc.Address {
+	addrs := make([]grpc.Address, len(eps))
+	for i := range eps {
+		addrs[i].Addr = getHost(eps[i])
+	}
+	return addrs
+}
+
+func getHostPort2ep(eps []string) map[string]string {
+	hm := make(map[string]string, len(eps))
+	for i := range eps {
+		_, host, _ := parseEndpoint(eps[i])
+		hm[host] = eps[i]
+	}
+	return hm
+}
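Health tracking in the combined balancer is a gray-list: hostPortError timestamps a failing host:port in unhealthyHostPorts, cleanupUnhealthy prunes entries older than healthCheckTimeout, and mayPin refuses to re-pin an endpoint that is still listed. A self-contained sketch of that expiry scheme; the grayList type is illustrative, not part of the commit:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// grayList timestamps failing endpoints and forgets them after a timeout,
// roughly what unhealthyHostPorts plus cleanupUnhealthy do above.
type grayList struct {
	mu      sync.Mutex
	timeout time.Duration
	entries map[string]time.Time
}

func newGrayList(timeout time.Duration) *grayList {
	return &grayList{timeout: timeout, entries: make(map[string]time.Time)}
}

func (g *grayList) markUnhealthy(hostPort string) {
	g.mu.Lock()
	g.entries[hostPort] = time.Now()
	g.mu.Unlock()
}

// cleanup drops endpoints whose last failure is older than the timeout,
// making them eligible for pinning again.
func (g *grayList) cleanup() {
	g.mu.Lock()
	for hp, failedAt := range g.entries {
		if time.Since(failedAt) > g.timeout {
			delete(g.entries, hp)
		}
	}
	g.mu.Unlock()
}

func main() {
	g := newGrayList(50 * time.Millisecond)
	g.markUnhealthy("localhost:2379")

	g.cleanup()
	fmt.Println("gray-listed:", len(g.entries)) // 1: failure too recent to expire

	time.Sleep(100 * time.Millisecond)
	g.cleanup()
	fmt.Println("gray-listed:", len(g.entries)) // 0: timeout elapsed
}
```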

+ 32 - 35
clientv3/balancer_test.go → clientv3/health_balancer_test.go

@@ -1,4 +1,4 @@
-// Copyright 2016 The etcd Authors
+// Copyright 2017 The etcd Authors
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -28,29 +28,27 @@ import (
 	"google.golang.org/grpc"
 )
 
-var (
-	endpoints = []string{"localhost:2379", "localhost:22379", "localhost:32379"}
-)
+var endpoints = []string{"localhost:2379", "localhost:22379", "localhost:32379"}
 
 func TestBalancerGetUnblocking(t *testing.T) {
-	sb := newSimpleBalancer(endpoints)
-	defer sb.Close()
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
-		t.Errorf("Initialize newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	hb := newHealthBalancer(endpoints, minHealthRetryDuration, func(string) (bool, error) { return true, nil })
+	defer hb.Close()
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initialize newHealthBalancer should have triggered Notify() chan, but it didn't")
 	}
 	unblockingOpts := grpc.BalancerGetOptions{BlockingWait: false}
 
-	_, _, err := sb.Get(context.Background(), unblockingOpts)
+	_, _, err := hb.Get(context.Background(), unblockingOpts)
 	if err != ErrNoAddrAvilable {
 		t.Errorf("Get() with no up endpoints should return ErrNoAddrAvailable, got: %v", err)
 	}
 
-	down1 := sb.Up(grpc.Address{Addr: endpoints[1]})
-	if addrs := <-sb.Notify(); len(addrs) != 1 {
+	down1 := hb.Up(grpc.Address{Addr: endpoints[1]})
+	if addrs := <-hb.Notify(); len(addrs) != 1 {
 		t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
 	}
-	down2 := sb.Up(grpc.Address{Addr: endpoints[2]})
-	addrFirst, putFun, err := sb.Get(context.Background(), unblockingOpts)
+	down2 := hb.Up(grpc.Address{Addr: endpoints[2]})
+	addrFirst, putFun, err := hb.Get(context.Background(), unblockingOpts)
 	if err != nil {
 		t.Errorf("Get() with up endpoints should success, got %v", err)
 	}
@@ -60,32 +58,32 @@ func TestBalancerGetUnblocking(t *testing.T) {
 	if putFun == nil {
 		t.Errorf("Get() returned unexpected nil put function")
 	}
-	addrSecond, _, _ := sb.Get(context.Background(), unblockingOpts)
+	addrSecond, _, _ := hb.Get(context.Background(), unblockingOpts)
 	if addrFirst.Addr != addrSecond.Addr {
 		t.Errorf("Get() didn't return the same address as previous call, got %v and %v", addrFirst, addrSecond)
 	}
 
 	down1(errors.New("error"))
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
 		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
 	}
 	down2(errors.New("error"))
-	_, _, err = sb.Get(context.Background(), unblockingOpts)
+	_, _, err = hb.Get(context.Background(), unblockingOpts)
 	if err != ErrNoAddrAvilable {
 		t.Errorf("Get() with no up endpoints should return ErrNoAddrAvailable, got: %v", err)
 	}
 }
 
 func TestBalancerGetBlocking(t *testing.T) {
-	sb := newSimpleBalancer(endpoints)
-	defer sb.Close()
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
-		t.Errorf("Initialize newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	hb := newHealthBalancer(endpoints, minHealthRetryDuration, func(string) (bool, error) { return true, nil })
+	defer hb.Close()
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initialize newHealthBalancer should have triggered Notify() chan, but it didn't")
 	}
 	blockingOpts := grpc.BalancerGetOptions{BlockingWait: true}
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
-	_, _, err := sb.Get(ctx, blockingOpts)
+	_, _, err := hb.Get(ctx, blockingOpts)
 	cancel()
 	if err != context.DeadlineExceeded {
 		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
@@ -94,15 +92,15 @@ func TestBalancerGetBlocking(t *testing.T) {
 	downC := make(chan func(error), 1)
 
 	go func() {
-		// ensure sb.Up() will be called after sb.Get() to see if Up() releases blocking Get()
+		// ensure hb.Up() will be called after hb.Get() to see if Up() releases blocking Get()
 		time.Sleep(time.Millisecond * 100)
-		f := sb.Up(grpc.Address{Addr: endpoints[1]})
-		if addrs := <-sb.Notify(); len(addrs) != 1 {
+		f := hb.Up(grpc.Address{Addr: endpoints[1]})
+		if addrs := <-hb.Notify(); len(addrs) != 1 {
 			t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
 		}
 		downC <- f
 	}()
-	addrFirst, putFun, err := sb.Get(context.Background(), blockingOpts)
+	addrFirst, putFun, err := hb.Get(context.Background(), blockingOpts)
 	if err != nil {
 		t.Errorf("Get() with up endpoints should success, got %v", err)
 	}
@@ -114,19 +112,19 @@ func TestBalancerGetBlocking(t *testing.T) {
 	}
 	down1 := <-downC
 
-	down2 := sb.Up(grpc.Address{Addr: endpoints[2]})
-	addrSecond, _, _ := sb.Get(context.Background(), blockingOpts)
+	down2 := hb.Up(grpc.Address{Addr: endpoints[2]})
+	addrSecond, _, _ := hb.Get(context.Background(), blockingOpts)
 	if addrFirst.Addr != addrSecond.Addr {
 		t.Errorf("Get() didn't return the same address as previous call, got %v and %v", addrFirst, addrSecond)
 	}
 
 	down1(errors.New("error"))
-	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+	if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
 		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
 	}
 	down2(errors.New("error"))
 	ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond*100)
-	_, _, err = sb.Get(ctx, blockingOpts)
+	_, _, err = hb.Get(ctx, blockingOpts)
 	cancel()
 	if err != context.DeadlineExceeded {
 		t.Errorf("Get() with no up endpoints should timeout, got %v", err)
@@ -168,9 +166,8 @@ func TestHealthBalancerGraylist(t *testing.T) {
 		}()
 	}
 
-	sb := newSimpleBalancer(eps)
 	tf := func(s string) (bool, error) { return false, nil }
-	hb := newHealthBalancer(sb, 5*time.Second, tf)
+	hb := newHealthBalancer(eps, 5*time.Second, tf)
 
 	conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(hb))
 	testutil.AssertNil(t, err)
@@ -203,13 +200,13 @@ func TestBalancerDoNotBlockOnClose(t *testing.T) {
 	defer kcl.close()
 
 	for i := 0; i < 5; i++ {
-		sb := newSimpleBalancer(kcl.endpoints())
-		conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(sb))
+		hb := newHealthBalancer(kcl.endpoints(), minHealthRetryDuration, func(string) (bool, error) { return true, nil })
+		conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(hb))
 		if err != nil {
 			t.Fatal(err)
 		}
 		kvc := pb.NewKVClient(conn)
-		<-sb.readyc
+		<-hb.readyc
 
 		var wg sync.WaitGroup
 		wg.Add(100)
@@ -225,7 +222,7 @@ func TestBalancerDoNotBlockOnClose(t *testing.T) {
 		bclosec, cclosec := make(chan struct{}), make(chan struct{})
 		go func() {
 			defer close(bclosec)
-			sb.Close()
+			hb.Close()
 		}()
 		go func() {
 			defer close(cclosec)
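For reference, the renamed tests build the combined balancer directly and hand it to gRPC. A condensed sketch of that wiring, assuming it lives inside package clientv3 (newHealthBalancer and minHealthRetryDuration are unexported) and using the deprecated grpc.WithBalancer option exactly as the tests above do:

```go
package clientv3

import (
	"testing"

	"google.golang.org/grpc"
)

// TestHealthBalancerWiring is a condensed, illustrative variant of the tests
// above: construct the combined balancer with a stub health check and plug it
// into a gRPC connection via WithBalancer.
func TestHealthBalancerWiring(t *testing.T) {
	hb := newHealthBalancer(
		[]string{"localhost:2379", "localhost:22379", "localhost:32379"},
		minHealthRetryDuration,
		func(ep string) (bool, error) { return true, nil }, // stub: every endpoint healthy
	)
	defer hb.Close()

	conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(hb))
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
}
```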