
clientv3: update eps if pinAddr is not included in updateAddrs

FIXES #7392
Author: fanmin shi
Commit: a23609efe6

2 changed files with 74 additions and 20 deletions:
  1. clientv3/balancer.go               +53 -20
  2. clientv3/integration/dial_test.go  +21 -0
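
Context for the diff below: before this commit, updateAddrs pushed a new address list to gRPC only when no address was pinned (pinAddr == ""), so a SetEndpoints call whose new list no longer contained the pinned endpoint never triggered a reconnect and the client stayed stuck on that member (#7392). The fix signals updateNotifyLoop whenever the pinned address is missing from the new list, and additionally ignores Up() callbacks for addresses that are no longer in the list. From the caller's side the expected behavior is roughly the following minimal sketch (endpoints, timeouts, and the import path are illustrative assumptions, not part of the commit):

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/coreos/etcd/clientv3"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"10.0.0.1:2379"}, // member that later becomes unreachable
		DialTimeout: 5 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	// Switch to members that do not include the originally pinned endpoint.
	// With this fix the balancer sees that the pinned address is gone and
	// hands the new list to gRPC so the client reconnects.
	cli.SetEndpoints("10.0.0.2:2379", "10.0.0.3:2379")

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if _, err := cli.Get(ctx, "foo"); err != nil {
		log.Fatal(err) // before the fix this could hang on the dead member until ctx expired
	}
}
```

This is essentially the scenario that TestSwitchSetEndpoints, added in the second file of this commit, drives against a real partitioned cluster.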

clientv3/balancer.go  +53 -20

@@ -56,6 +56,9 @@ type simpleBalancer struct {
 	// donec closes when all goroutines are exited
 	donec chan struct{}
 
+	// updateAddrsC notifies updateNotifyLoop to update addrs.
+	updateAddrsC chan struct{}
+
 	// grpc issues TLS cert checks using the string passed into dial so
 	// that string must be the host. To recover the full scheme://host URL,
 	// have a map from hosts to the original endpoint.
@@ -76,14 +79,15 @@ func newSimpleBalancer(eps []string) *simpleBalancer {
 	}
 	notifyCh <- addrs
 	sb := &simpleBalancer{
-		addrs:    addrs,
-		notifyCh: notifyCh,
-		readyc:   make(chan struct{}),
-		upc:      make(chan struct{}),
-		stopc:    make(chan struct{}),
-		downc:    make(chan struct{}),
-		donec:    make(chan struct{}),
-		host2ep:  getHost2ep(eps),
+		addrs:        addrs,
+		notifyCh:     notifyCh,
+		readyc:       make(chan struct{}),
+		upc:          make(chan struct{}),
+		stopc:        make(chan struct{}),
+		downc:        make(chan struct{}),
+		donec:        make(chan struct{}),
+		updateAddrsC: make(chan struct{}, 1),
+		host2ep:      getHost2ep(eps),
 	}
 	go sb.updateNotifyLoop()
 	return sb
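
Why updateAddrsC gets a one-slot buffer: SetEndpoints can hand the "addresses changed" signal to updateNotifyLoop without waiting for the loop to reach a receive point, and the loop re-reads b.addrs under the lock when it handles the signal, so the notification always carries the most recently published list. A standalone sketch of that pattern, with illustrative names and endpoints (this is not the etcd code itself):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// signaller mimics the simpleBalancer pattern: writers publish the latest
// address list and signal a channel with capacity 1; the reader re-reads
// the list when it handles the signal, so it always sees the newest value.
type signaller struct {
	mu      sync.RWMutex
	addrs   []string
	updateC chan struct{} // like updateAddrsC: make(chan struct{}, 1)
	stopc   chan struct{}
}

func (s *signaller) update(addrs []string) {
	s.mu.Lock()
	s.addrs = addrs
	s.mu.Unlock()
	select {
	case s.updateC <- struct{}{}: // at most one signal is ever queued
	case <-s.stopc: // never block once the signaller is closed
	}
}

func (s *signaller) loop() {
	for {
		select {
		case <-s.updateC:
			s.mu.RLock()
			fmt.Println("notify gRPC with:", s.addrs)
			s.mu.RUnlock()
		case <-s.stopc:
			return
		}
	}
}

func main() {
	s := &signaller{updateC: make(chan struct{}, 1), stopc: make(chan struct{})}
	go s.loop()
	s.update([]string{"10.0.0.2:2379"})
	s.update([]string{"10.0.0.2:2379", "10.0.0.3:2379"})
	time.Sleep(50 * time.Millisecond) // let the loop drain (demo only)
	close(s.stopc)
}
```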
@@ -116,7 +120,6 @@ func (b *simpleBalancer) updateAddrs(eps []string) {
 	np := getHost2ep(eps)
 
 	b.mu.Lock()
-	defer b.mu.Unlock()
 
 	match := len(np) == len(b.host2ep)
 	for k, v := range np {
@@ -127,6 +130,7 @@ func (b *simpleBalancer) updateAddrs(eps []string) {
 	}
 	if match {
 		// same endpoints, so no need to update address
+		b.mu.Unlock()
 		return
 	}
 
@@ -137,13 +141,30 @@ func (b *simpleBalancer) updateAddrs(eps []string) {
 		addrs = append(addrs, grpc.Address{Addr: getHost(eps[i])})
 	}
 	b.addrs = addrs
+
 	// updating notifyCh can trigger new connections,
-	// but balancer only expects new connections if all connections are down
-	if b.pinAddr == "" {
-		b.notifyCh <- addrs
+	// only update addrs if all connections are down
+	// or addrs does not include pinAddr.
+	update := !hasAddr(addrs, b.pinAddr)
+	b.mu.Unlock()
+
+	if update {
+		select {
+		case b.updateAddrsC <- struct{}{}:
+		case <-b.stopc:
+		}
 	}
 }
 
+func hasAddr(addrs []grpc.Address, targetAddr string) bool {
+	for _, addr := range addrs {
+		if targetAddr == addr.Addr {
+			return true
+		}
+	}
+	return false
+}
+
 func (b *simpleBalancer) updateNotifyLoop() {
 	defer close(b.donec)
 
@@ -170,21 +191,28 @@ func (b *simpleBalancer) updateNotifyLoop() {
 			case <-b.stopc:
 				return
 			}
+		case <-b.updateAddrsC:
+			b.notifyAddrs()
+			continue
 		}
 		select {
 		case <-downc:
-			b.mu.RLock()
-			addrs := b.addrs
-			b.mu.RUnlock()
-			select {
-			case b.notifyCh <- addrs:
-			case <-b.stopc:
-				return
-			}
+			b.notifyAddrs()
+		case <-b.updateAddrsC:
+			b.notifyAddrs()
 		case <-b.stopc:
 			return
 		}
+	}
+}
 
+func (b *simpleBalancer) notifyAddrs() {
+	b.mu.RLock()
+	addrs := b.addrs
+	b.mu.RUnlock()
+	select {
+	case b.notifyCh <- addrs:
+	case <-b.stopc:
 	}
 }
 
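The new notifyAddrs helper centralizes what the downc branch used to do inline. Note the select against stopc: if the client is closing and nothing will ever drain notifyCh again, the send must not block forever. A tiny standalone sketch of that idiom (channel sizes and addresses are illustrative):

```go
package main

import "fmt"

// notify mimics notifyAddrs(): the send races against a stop channel so a
// closing balancer never blocks on an undrained notify channel.
func notify(notifyCh chan []string, stopc chan struct{}, addrs []string) {
	select {
	case notifyCh <- addrs:
		fmt.Println("delivered:", addrs)
	case <-stopc:
		fmt.Println("stopped, dropping:", addrs)
	}
}

func main() {
	notifyCh := make(chan []string, 1) // buffered, like simpleBalancer's notifyCh
	stopc := make(chan struct{})

	notify(notifyCh, stopc, []string{"10.0.0.2:2379"}) // buffer empty: delivered

	close(stopc) // balancer is shutting down, notifyCh is never drained
	notify(notifyCh, stopc, []string{"10.0.0.3:2379"}) // buffer full: dropped
}
```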
@@ -198,6 +226,11 @@ func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
 	if b.closed {
 		return func(err error) {}
 	}
+	// gRPC might call Up on a stale address.
+	// Prevent updating pinAddr with a stale address.
+	if !hasAddr(b.addrs, addr.Addr) {
+		return func(err error) {}
+	}
 
 	if b.pinAddr == "" {
 		// notify waiting Get()s and pin first connected address
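
The hasAddr helper introduced above is used in two places: updateAddrs only signals the notify loop when the pinned address is missing from the new endpoint list, and Up() ignores late callbacks for addresses that have already been dropped, so a connection to a removed or partitioned member cannot become the pinned address again. A self-contained sketch of both decisions, using a local stand-in for grpc.Address and illustrative endpoints:

```go
package main

import "fmt"

// address stands in for grpc.Address so the sketch does not depend on a
// particular grpc version.
type address struct{ Addr string }

// hasAddr is the same linear scan the patch adds to balancer.go.
func hasAddr(addrs []address, target string) bool {
	for _, a := range addrs {
		if a.Addr == target {
			return true
		}
	}
	return false
}

func main() {
	pinned := "10.0.0.1:2379" // address of the currently pinned connection

	// updateAddrs decision: SetEndpoints dropped the pinned member, so the
	// new list must be pushed to gRPC (update == true).
	newAddrs := []address{{Addr: "10.0.0.2:2379"}, {Addr: "10.0.0.3:2379"}}
	update := !hasAddr(newAddrs, pinned)
	fmt.Println("signal updateAddrsC:", update) // true

	// Up() guard: a late Up for the old endpoint arrives after the switch;
	// it is not in newAddrs, so it must not become the pinned address.
	stale := address{Addr: pinned}
	if !hasAddr(newAddrs, stale.Addr) {
		fmt.Println("ignore stale Up for", stale.Addr)
	}
}
```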

clientv3/integration/dial_test.go  +21 -0

@@ -71,6 +71,27 @@ func testDialSetEndpoints(t *testing.T, setBefore bool) {
 	cancel()
 }
 
+// TestSwitchSetEndpoints ensures SetEndpoints can switch the client to a new
+// endpoint list that does not include the original endpoint.
+func TestSwitchSetEndpoints(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	// get the endpoints of the non-partitioned members
+	eps := []string{clus.Members[1].GRPCAddr(), clus.Members[2].GRPCAddr()}
+
+	cli := clus.Client(0)
+	clus.Members[0].InjectPartition(t, clus.Members[1:])
+
+	cli.SetEndpoints(eps...)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	if _, err := cli.Get(ctx, "foo"); err != nil {
+		t.Fatal(err)
+	}
+}
+
 func TestRejectOldCluster(t *testing.T) {
 	defer testutil.AfterTest(t)
 	// 2 endpoints to test multi-endpoint Status