
clientv3: balancer uses one connection at a time

FIX #7080
fanmin shi, 9 years ago
commit df55438a60
2 changed files with 49 additions and 26 deletions:
  clientv3/balancer.go       (+31 -26)
  clientv3/balancer_test.go  (+18 -0)
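
For context, simpleBalancer implements grpc-go's Balancer interface (since deprecated in later grpc-go releases). Notify() is the lever this commit pulls: the gRPC client keeps connections open only to the addresses most recently sent on that channel, so pushing a single-element slice after the first Up() makes it tear down every other connection. A sketch of the interface as it stood in grpc-go at the time, with paraphrased comments; treat it as a reference sketch, not the exact upstream source:

	// The grpc-go Balancer interface that simpleBalancer implements.
	type Balancer interface {
		// Start does initialization work before the balancer is used.
		Start(target string, config BalancerConfig) error
		// Up informs the balancer that a connection to addr is ready; the
		// returned down func is called when that connection is lost.
		Up(addr Address) (down func(error))
		// Get returns the address for the next RPC; with BlockingWait set
		// it blocks until an address becomes available.
		Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
		// Notify streams the full set of addresses the client should be
		// connected to; connections to addresses absent from an update
		// are torn down.
		Notify() <-chan []Address
		// Close shuts down the balancer.
		Close() error
	}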

clientv3/balancer.go (+31 -26)

@@ -43,8 +43,7 @@ type simpleBalancer struct {
 
 	// mu protects upEps, pinAddr, and connectingAddr
 	mu sync.RWMutex
-	// upEps holds the current endpoints that have an active connection
-	upEps map[string]struct{}
+
 	// upc closes when upEps transitions from empty to non-zero or the balancer closes.
 	upc chan struct{}
 
@@ -71,7 +70,6 @@ func newSimpleBalancer(eps []string) *simpleBalancer {
 		addrs:    addrs,
 		notifyCh: notifyCh,
 		readyc:   make(chan struct{}),
-		upEps:    make(map[string]struct{}),
 		upc:      make(chan struct{}),
 		host2ep:  getHost2ep(eps),
 	}
@@ -140,48 +138,45 @@ func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
 		return func(err error) {}
 	}
 
-	if len(b.upEps) == 0 {
+	if b.pinAddr == "" {
 		// notify waiting Get()s and pin first connected address
 		close(b.upc)
 		b.pinAddr = addr.Addr
+		// notify client that a connection is up
+		b.readyOnce.Do(func() { close(b.readyc) })
+		// have gRPC close any open connections that are not to pinAddr;
+		// this ensures only one connection is open per client
+		b.notifyCh <- []grpc.Address{addr}
 	}
-	b.upEps[addr.Addr] = struct{}{}
-
-	// notify client that a connection is up
-	b.readyOnce.Do(func() { close(b.readyc) })
 
 	return func(err error) {
 		b.mu.Lock()
-		delete(b.upEps, addr.Addr)
-		if len(b.upEps) == 0 && b.pinAddr != "" {
+		if b.pinAddr == addr.Addr {
 			b.upc = make(chan struct{})
-		} else if b.pinAddr == addr.Addr {
-			// choose new random up endpoint
-			for k := range b.upEps {
-				b.pinAddr = k
-				break
-			}
+			b.pinAddr = ""
+			b.notifyCh <- b.addrs
 		}
 		b.mu.Unlock()
 	}
 }
 
 func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
-	var addr string
+	var (
+		addr   string
+		closed bool
+	)
 
 	// If opts.BlockingWait is false (for fail-fast RPCs), it should return
 	// an address it has notified via Notify immediately instead of blocking.
 	if !opts.BlockingWait {
 		b.mu.RLock()
-		closed := b.closed
+		closed = b.closed
 		addr = b.pinAddr
-		upEps := len(b.upEps)
 		b.mu.RUnlock()
 		if closed {
 			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
 		}
-
-		if upEps == 0 {
+		if addr == "" {
 			return grpc.Address{Addr: ""}, nil, ErrNoAddrAvilable
 		}
 		return grpc.Address{Addr: addr}, func() {}, nil
@@ -197,13 +192,14 @@ func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions)
 			return grpc.Address{Addr: ""}, nil, ctx.Err()
 		}
 		b.mu.RLock()
+		closed = b.closed
 		addr = b.pinAddr
-		upEps := len(b.upEps)
 		b.mu.RUnlock()
-		if addr == "" {
+		// Close(), which sets b.closed = true, can be called before Get(); Get() must exit if the balancer is closed.
+		if closed {
 			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
 		}
-		if upEps > 0 {
+		if addr != "" {
 			break
 		}
 	}
@@ -222,9 +218,18 @@ func (b *simpleBalancer) Close() error {
 	}
 	b.closed = true
 	close(b.notifyCh)
-	// terminate all waiting Get()s
 	b.pinAddr = ""
-	if len(b.upEps) == 0 {
+
+	// In the case of the following scenario:
+	//	1. upc is not closed; no pinned address
+	//	2. client issues an RPC, calling invoke(), which calls Get(), enters the for loop, and blocks
+	//	3. clientconn.Close() calls balancer.Close(); closed = true
+	//	4. the for loop in Get() never exits, since ctx is the context passed in by the client and may never be canceled
+	// we must close upc so that Get() stops blocking on upc
+	select {
+	case <-b.upc:
+	default:
+		// terminate all waiting Get()s
 		close(b.upc)
 	}
 	return nil
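
The rewritten Close() guards against the race spelled out in the comment above: a Get() blocked on upc with no pinned address would hang forever once the balancer is closed. Below is a minimal, self-contained sketch of the same idiom; the toy type and its names are hypothetical stand-ins, not the etcd code. The select-with-default guard also avoids the panic from closing an already-closed channel.

	package main

	import (
		"errors"
		"fmt"
		"sync"
		"time"
	)

	// toy is a hypothetical stand-in for simpleBalancer: Get blocks on upc
	// until the channel is closed, either by pinning an address or by Close.
	type toy struct {
		mu     sync.Mutex
		upc    chan struct{}
		closed bool
	}

	func (t *toy) Get() error {
		<-t.upc // step 2 of the scenario: block here with no pinned address
		t.mu.Lock()
		defer t.mu.Unlock()
		if t.closed {
			return errors.New("balancer closed")
		}
		return nil
	}

	func (t *toy) Close() {
		t.mu.Lock()
		defer t.mu.Unlock()
		t.closed = true
		// close upc only if it is still open; closing a closed channel panics
		select {
		case <-t.upc:
		default:
			close(t.upc) // releases every Get blocked above
		}
	}

	func main() {
		b := &toy{upc: make(chan struct{})}
		done := make(chan error)
		go func() { done <- b.Get() }() // blocks: nothing is pinned yet
		time.Sleep(10 * time.Millisecond)
		b.Close()           // without close(upc) here, Get would hang forever
		fmt.Println(<-done) // prints "balancer closed"
	}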

clientv3/balancer_test.go (+18 -0)

@@ -29,6 +29,9 @@ var (
 
 func TestBalancerGetUnblocking(t *testing.T) {
 	sb := newSimpleBalancer(endpoints)
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initialize newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	}
 	unblockingOpts := grpc.BalancerGetOptions{BlockingWait: false}
 
 	_, _, err := sb.Get(context.Background(), unblockingOpts)
@@ -37,6 +40,9 @@ func TestBalancerGetUnblocking(t *testing.T) {
 	}
 
 	down1 := sb.Up(grpc.Address{Addr: endpoints[1]})
+	if addrs := <-sb.Notify(); len(addrs) != 1 {
+		t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
+	}
 	down2 := sb.Up(grpc.Address{Addr: endpoints[2]})
 	addrFirst, putFun, err := sb.Get(context.Background(), unblockingOpts)
 	if err != nil {
@@ -54,6 +60,9 @@ func TestBalancerGetUnblocking(t *testing.T) {
 	}
 
 	down1(errors.New("error"))
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
+	}
 	down2(errors.New("error"))
 	_, _, err = sb.Get(context.Background(), unblockingOpts)
 	if err != ErrNoAddrAvilable {
@@ -63,6 +72,9 @@ func TestBalancerGetUnblocking(t *testing.T) {
 
 func TestBalancerGetBlocking(t *testing.T) {
 	sb := newSimpleBalancer(endpoints)
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initialize newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	}
 	blockingOpts := grpc.BalancerGetOptions{BlockingWait: true}
 
 	ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*100)
@@ -77,6 +89,9 @@ func TestBalancerGetBlocking(t *testing.T) {
 		// ensure sb.Up() will be called after sb.Get() to see if Up() releases blocking Get()
 		time.Sleep(time.Millisecond * 100)
 		downC <- sb.Up(grpc.Address{Addr: endpoints[1]})
+		if addrs := <-sb.Notify(); len(addrs) != 1 {
+			t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
+		}
 	}()
 	addrFirst, putFun, err := sb.Get(context.Background(), blockingOpts)
 	if err != nil {
@@ -97,6 +112,9 @@ func TestBalancerGetBlocking(t *testing.T) {
 	}
 
 	down1(errors.New("error"))
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("closing the only connection should triggered balancer to send the all endpoints via Notify chan so that we can establish a connection")
+	}
 	down2(errors.New("error"))
 	ctx, _ = context.WithTimeout(context.Background(), time.Millisecond*100)
 	_, _, err = sb.Get(ctx, blockingOpts)
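
A note on the added <-sb.Notify() reads: they are load-bearing. The notify channel already carries the initial endpoint list when the balancer is constructed, and it appears to buffer only a single pending update, so each test must drain one update before triggering the next transition or the sender would block. A sketch of the expected traffic under that single-slot assumption, with plain strings standing in for []grpc.Address:

	package main

	import "fmt"

	func main() {
		// notify mimics the balancer's notify channel under the assumption
		// of a one-element buffer; plain strings stand in for grpc.Address.
		notify := make(chan []string, 1)

		notify <- []string{"a", "b", "c"} // construction: dial all endpoints
		fmt.Println(<-notify)             // test drains the initial update

		notify <- []string{"a"} // first Up("a"): keep only the pinned address
		fmt.Println(<-notify)   // drain before the next transition

		notify <- []string{"a", "b", "c"} // down() on the pinned address: re-dial all
		fmt.Println(<-notify)
	}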