
Merge pull request #5845 from heyitsanthony/clientv3-ignore-dead-eps

clientv3: respect up/down notifications from grpc
Anthony Romano 9 years ago
Parent
Commit
8d7703528a
30 changed files with 1547 additions and 370 deletions
  1. 11 11
      clientv3/auth.go
  2. 99 16
      clientv3/balancer.go
  3. 37 13
      clientv3/client.go
  4. 3 0
      clientv3/client_test.go
  5. 4 4
      clientv3/cluster.go
  6. 73 0
      clientv3/integration/kv_test.go
  7. 3 3
      clientv3/kv.go
  8. 3 3
      clientv3/lease.go
  9. 243 0
      clientv3/retry.go
  10. 1 2
      clientv3/txn.go
  11. 19 6
      cmd/vendor/google.golang.org/grpc/call.go
  12. 205 97
      cmd/vendor/google.golang.org/grpc/clientconn.go
  13. 19 34
      cmd/vendor/google.golang.org/grpc/credentials/credentials.go
  14. 76 0
      cmd/vendor/google.golang.org/grpc/credentials/credentials_util_go17.go
  15. 74 0
      cmd/vendor/google.golang.org/grpc/credentials/credentials_util_pre_go17.go
  16. 10 4
      cmd/vendor/google.golang.org/grpc/metadata/metadata.go
  17. 13 5
      cmd/vendor/google.golang.org/grpc/rpc_util.go
  18. 76 19
      cmd/vendor/google.golang.org/grpc/server.go
  19. 71 40
      cmd/vendor/google.golang.org/grpc/stream.go
  20. 5 0
      cmd/vendor/google.golang.org/grpc/transport/control.go
  21. 46 0
      cmd/vendor/google.golang.org/grpc/transport/go16.go
  22. 46 0
      cmd/vendor/google.golang.org/grpc/transport/go17.go
  23. 6 2
      cmd/vendor/google.golang.org/grpc/transport/handler_server.go
  24. 141 63
      cmd/vendor/google.golang.org/grpc/transport/http2_client.go
  25. 51 25
      cmd/vendor/google.golang.org/grpc/transport/http2_server.go
  26. 79 4
      cmd/vendor/google.golang.org/grpc/transport/http_util.go
  27. 51 0
      cmd/vendor/google.golang.org/grpc/transport/pre_go16.go
  28. 78 15
      cmd/vendor/google.golang.org/grpc/transport/transport.go
  29. 3 3
      glide.lock
  30. 1 1
      glide.yaml

+ 11 - 11
clientv3/auth.go

@@ -115,32 +115,32 @@ func NewAuth(c *Client) Auth {
 }
 
 func (auth *auth) AuthEnable(ctx context.Context) (*AuthEnableResponse, error) {
-	resp, err := auth.remote.AuthEnable(ctx, &pb.AuthEnableRequest{}, grpc.FailFast(false))
+	resp, err := auth.remote.AuthEnable(ctx, &pb.AuthEnableRequest{})
 	return (*AuthEnableResponse)(resp), toErr(ctx, err)
 }
 
 func (auth *auth) AuthDisable(ctx context.Context) (*AuthDisableResponse, error) {
-	resp, err := auth.remote.AuthDisable(ctx, &pb.AuthDisableRequest{}, grpc.FailFast(false))
+	resp, err := auth.remote.AuthDisable(ctx, &pb.AuthDisableRequest{})
 	return (*AuthDisableResponse)(resp), toErr(ctx, err)
 }
 
 func (auth *auth) UserAdd(ctx context.Context, name string, password string) (*AuthUserAddResponse, error) {
-	resp, err := auth.remote.UserAdd(ctx, &pb.AuthUserAddRequest{Name: name, Password: password}, grpc.FailFast(false))
+	resp, err := auth.remote.UserAdd(ctx, &pb.AuthUserAddRequest{Name: name, Password: password})
 	return (*AuthUserAddResponse)(resp), toErr(ctx, err)
 }
 
 func (auth *auth) UserDelete(ctx context.Context, name string) (*AuthUserDeleteResponse, error) {
-	resp, err := auth.remote.UserDelete(ctx, &pb.AuthUserDeleteRequest{Name: name}, grpc.FailFast(false))
+	resp, err := auth.remote.UserDelete(ctx, &pb.AuthUserDeleteRequest{Name: name})
 	return (*AuthUserDeleteResponse)(resp), toErr(ctx, err)
 }
 
 func (auth *auth) UserChangePassword(ctx context.Context, name string, password string) (*AuthUserChangePasswordResponse, error) {
-	resp, err := auth.remote.UserChangePassword(ctx, &pb.AuthUserChangePasswordRequest{Name: name, Password: password}, grpc.FailFast(false))
+	resp, err := auth.remote.UserChangePassword(ctx, &pb.AuthUserChangePasswordRequest{Name: name, Password: password})
 	return (*AuthUserChangePasswordResponse)(resp), toErr(ctx, err)
 }
 
 func (auth *auth) UserGrantRole(ctx context.Context, user string, role string) (*AuthUserGrantRoleResponse, error) {
-	resp, err := auth.remote.UserGrantRole(ctx, &pb.AuthUserGrantRoleRequest{User: user, Role: role}, grpc.FailFast(false))
+	resp, err := auth.remote.UserGrantRole(ctx, &pb.AuthUserGrantRoleRequest{User: user, Role: role})
 	return (*AuthUserGrantRoleResponse)(resp), toErr(ctx, err)
 }
 
@@ -155,12 +155,12 @@ func (auth *auth) UserList(ctx context.Context) (*AuthUserListResponse, error) {
 }
 
 func (auth *auth) UserRevokeRole(ctx context.Context, name string, role string) (*AuthUserRevokeRoleResponse, error) {
-	resp, err := auth.remote.UserRevokeRole(ctx, &pb.AuthUserRevokeRoleRequest{Name: name, Role: role}, grpc.FailFast(false))
+	resp, err := auth.remote.UserRevokeRole(ctx, &pb.AuthUserRevokeRoleRequest{Name: name, Role: role})
 	return (*AuthUserRevokeRoleResponse)(resp), toErr(ctx, err)
 }
 
 func (auth *auth) RoleAdd(ctx context.Context, name string) (*AuthRoleAddResponse, error) {
-	resp, err := auth.remote.RoleAdd(ctx, &pb.AuthRoleAddRequest{Name: name}, grpc.FailFast(false))
+	resp, err := auth.remote.RoleAdd(ctx, &pb.AuthRoleAddRequest{Name: name})
 	return (*AuthRoleAddResponse)(resp), toErr(ctx, err)
 }
 
@@ -170,7 +170,7 @@ func (auth *auth) RoleGrantPermission(ctx context.Context, name string, key, ran
 		RangeEnd: []byte(rangeEnd),
 		PermType: authpb.Permission_Type(permType),
 	}
-	resp, err := auth.remote.RoleGrantPermission(ctx, &pb.AuthRoleGrantPermissionRequest{Name: name, Perm: perm}, grpc.FailFast(false))
+	resp, err := auth.remote.RoleGrantPermission(ctx, &pb.AuthRoleGrantPermissionRequest{Name: name, Perm: perm})
 	return (*AuthRoleGrantPermissionResponse)(resp), toErr(ctx, err)
 }
 
@@ -185,12 +185,12 @@ func (auth *auth) RoleList(ctx context.Context) (*AuthRoleListResponse, error) {
 }
 
 func (auth *auth) RoleRevokePermission(ctx context.Context, role string, key, rangeEnd string) (*AuthRoleRevokePermissionResponse, error) {
-	resp, err := auth.remote.RoleRevokePermission(ctx, &pb.AuthRoleRevokePermissionRequest{Role: role, Key: key, RangeEnd: rangeEnd}, grpc.FailFast(false))
+	resp, err := auth.remote.RoleRevokePermission(ctx, &pb.AuthRoleRevokePermissionRequest{Role: role, Key: key, RangeEnd: rangeEnd})
 	return (*AuthRoleRevokePermissionResponse)(resp), toErr(ctx, err)
 }
 
 func (auth *auth) RoleDelete(ctx context.Context, role string) (*AuthRoleDeleteResponse, error) {
-	resp, err := auth.remote.RoleDelete(ctx, &pb.AuthRoleDeleteRequest{Role: role}, grpc.FailFast(false))
+	resp, err := auth.remote.RoleDelete(ctx, &pb.AuthRoleDeleteRequest{Role: role})
 	return (*AuthRoleDeleteResponse)(resp), toErr(ctx, err)
 }
 

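Note: dropping grpc.FailFast(false) at every call site is safe because the remote stub becomes a retrying wrapper (clientv3/retry.go, below) that embeds the generated client and overrides only the methods it retries. A minimal sketch of that embedding pattern with hypothetical names, not the committed code:

package main

import (
	"context"
	"fmt"
)

// Greeter stands in for a generated gRPC client interface.
type Greeter interface {
	Hello(ctx context.Context, name string) (string, error)
}

type baseGreeter struct{}

func (baseGreeter) Hello(_ context.Context, name string) (string, error) {
	return "hello " + name, nil
}

// retryGreeter embeds the underlying client; any method it does not
// override falls through to the embedded value, so the wrapper still
// satisfies Greeter without restating the full interface.
type retryGreeter struct {
	Greeter
	retries int
}

func (g *retryGreeter) Hello(ctx context.Context, name string) (out string, err error) {
	for i := 0; i <= g.retries; i++ {
		if out, err = g.Greeter.Hello(ctx, name); err == nil {
			return out, nil
		}
	}
	return out, err
}

func main() {
	var g Greeter = &retryGreeter{Greeter: baseGreeter{}, retries: 2}
	out, _ := g.Hello(context.Background(), "etcd")
	fmt.Println(out) // hello etcd
}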
+ 99 - 16
clientv3/balancer.go

@@ -17,7 +17,7 @@ package clientv3
 import (
 	"net/url"
 	"strings"
-	"sync/atomic"
+	"sync"
 
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
@@ -26,32 +26,115 @@ import (
 // simpleBalancer does the bare minimum to expose multiple eps
 // to the grpc reconnection code path
 type simpleBalancer struct {
-	// eps are the client's endpoints stripped of any URL scheme
-	eps     []string
-	ch      chan []grpc.Address
-	numGets uint32
+	// addrs are the client's endpoints for grpc
+	addrs []grpc.Address
+	// notifyCh notifies grpc of the set of addresses for connecting
+	notifyCh chan []grpc.Address
+
+	// readyc closes once the first connection is up
+	readyc    chan struct{}
+	readyOnce sync.Once
+
+	// mu protects upEps, pinAddr, and connectingAddr
+	mu sync.RWMutex
+	// upEps holds the current endpoints that have an active connection
+	upEps map[string]struct{}
+	// upc closes when upEps transitions from empty to non-zero or the balancer closes.
+	upc chan struct{}
+
+	// pinAddr is the currently pinned address; set to the empty string on
+	// initialization and shutdown.
+	pinAddr string
 }
 
-func newSimpleBalancer(eps []string) grpc.Balancer {
-	ch := make(chan []grpc.Address, 1)
+func newSimpleBalancer(eps []string) *simpleBalancer {
+	notifyCh := make(chan []grpc.Address, 1)
 	addrs := make([]grpc.Address, len(eps))
 	for i := range eps {
 		addrs[i].Addr = getHost(eps[i])
 	}
-	ch <- addrs
-	return &simpleBalancer{eps: eps, ch: ch}
+	notifyCh <- addrs
+	sb := &simpleBalancer{
+		addrs:    addrs,
+		notifyCh: notifyCh,
+		readyc:   make(chan struct{}),
+		upEps:    make(map[string]struct{}),
+		upc:      make(chan struct{}),
+	}
+	return sb
+}
+
+func (b *simpleBalancer) Start(target string) error { return nil }
+
+func (b *simpleBalancer) ConnectNotify() <-chan struct{} {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	return b.upc
+}
+
+func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
+	b.mu.Lock()
+	if len(b.upEps) == 0 {
+		// notify waiting Get()s and pin first connected address
+		close(b.upc)
+		b.pinAddr = addr.Addr
+	}
+	b.upEps[addr.Addr] = struct{}{}
+	b.mu.Unlock()
+	// notify client that a connection is up
+	b.readyOnce.Do(func() { close(b.readyc) })
+	return func(err error) {
+		b.mu.Lock()
+		delete(b.upEps, addr.Addr)
+		if len(b.upEps) == 0 && b.pinAddr != "" {
+			b.upc = make(chan struct{})
+		} else if b.pinAddr == addr.Addr {
+			// choose new random up endpoint
+			for k := range b.upEps {
+				b.pinAddr = k
+				break
+			}
+		}
+		b.mu.Unlock()
+	}
 }
 
-func (b *simpleBalancer) Start(target string) error        { return nil }
-func (b *simpleBalancer) Up(addr grpc.Address) func(error) { return func(error) {} }
 func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
-	v := atomic.AddUint32(&b.numGets, 1)
-	ep := b.eps[v%uint32(len(b.eps))]
-	return grpc.Address{Addr: getHost(ep)}, func() {}, nil
+	var addr string
+	for {
+		b.mu.RLock()
+		ch := b.upc
+		b.mu.RUnlock()
+		select {
+		case <-ch:
+		case <-ctx.Done():
+			return grpc.Address{Addr: ""}, nil, ctx.Err()
+		}
+		b.mu.RLock()
+		addr = b.pinAddr
+		upEps := len(b.upEps)
+		b.mu.RUnlock()
+		if addr == "" {
+			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
+		}
+		if upEps > 0 {
+			break
+		}
+	}
+	return grpc.Address{Addr: addr}, func() {}, nil
 }
-func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.ch }
+
+func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
+
 func (b *simpleBalancer) Close() error {
-	close(b.ch)
+	b.mu.Lock()
+	close(b.notifyCh)
+	// terminate all waiting Get()s
+	b.pinAddr = ""
+	if len(b.upEps) == 0 {
+		close(b.upc)
+	}
+	b.mu.Unlock()
 	return nil
 }
 

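Note: the rewritten balancer pins the first endpoint gRPC reports Up and only repins when that endpoint goes down. A self-contained sketch of the same bookkeeping with the grpc.Balancer plumbing stripped away (hypothetical names, single-goroutine demo):

package main

import (
	"fmt"
	"sync"
)

// pinner mirrors the simpleBalancer state: the set of live endpoints,
// the pinned address, and a channel that closes once at least one
// endpoint is up.
type pinner struct {
	mu  sync.Mutex
	up  map[string]struct{}
	pin string
	upc chan struct{}
}

func newPinner() *pinner {
	return &pinner{up: make(map[string]struct{}), upc: make(chan struct{})}
}

// Up records addr as live and returns the matching "down" callback,
// the same shape grpc.Balancer.Up uses.
func (p *pinner) Up(addr string) func() {
	p.mu.Lock()
	if len(p.up) == 0 {
		close(p.upc) // unblock anyone waiting for a connection
		p.pin = addr
	}
	p.up[addr] = struct{}{}
	p.mu.Unlock()
	return func() {
		p.mu.Lock()
		defer p.mu.Unlock()
		delete(p.up, addr)
		if len(p.up) == 0 {
			p.upc = make(chan struct{}) // future waiters block again
			return
		}
		if p.pin == addr {
			for k := range p.up { // repin to any surviving endpoint
				p.pin = k
				break
			}
		}
	}
}

func main() {
	p := newPinner()
	down1 := p.Up("127.0.0.1:2379")
	p.Up("127.0.0.1:22379")
	fmt.Println("pinned:", p.pin) // first endpoint wins
	down1()
	fmt.Println("pinned:", p.pin) // repinned to the survivor
}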
+ 37 - 13
clientv3/client.go

@@ -46,9 +46,11 @@ type Client struct {
 	Auth
 	Maintenance
 
-	conn  *grpc.ClientConn
-	cfg   Config
-	creds *credentials.TransportCredentials
+	conn         *grpc.ClientConn
+	cfg          Config
+	creds        *credentials.TransportCredentials
+	balancer     *simpleBalancer
+	retryWrapper retryRpcFunc
 
 	ctx    context.Context
 	cancel context.CancelFunc
@@ -141,10 +143,7 @@ func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *
 // dialSetupOpts gives the dial opts prior to any authentication
 func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts []grpc.DialOption) {
 	if c.cfg.DialTimeout > 0 {
-		opts = []grpc.DialOption{
-			grpc.WithTimeout(c.cfg.DialTimeout),
-			grpc.WithBlock(),
-		}
+		opts = []grpc.DialOption{grpc.WithTimeout(c.cfg.DialTimeout)}
 	}
 	opts = append(opts, dopts...)
 
@@ -242,12 +241,30 @@ func newClient(cfg *Config) (*Client, error) {
 		client.Password = cfg.Password
 	}
 
-	b := newSimpleBalancer(cfg.Endpoints)
-	conn, err := client.dial(cfg.Endpoints[0], grpc.WithBalancer(b))
+	client.balancer = newSimpleBalancer(cfg.Endpoints)
+	conn, err := client.dial(cfg.Endpoints[0], grpc.WithBalancer(client.balancer))
 	if err != nil {
 		return nil, err
 	}
 	client.conn = conn
+	client.retryWrapper = client.newRetryWrapper()
+
+	// wait for a connection
+	if cfg.DialTimeout > 0 {
+		hasConn := false
+		waitc := time.After(cfg.DialTimeout)
+		select {
+		case <-client.balancer.readyc:
+			hasConn = true
+		case <-ctx.Done():
+		case <-waitc:
+		}
+		if !hasConn {
+			client.cancel()
+			conn.Close()
+			return nil, grpc.ErrClientConnTimeout
+		}
+	}
 
 	client.Cluster = NewCluster(client)
 	client.KV = NewKV(client)
@@ -282,8 +299,12 @@ func isHaltErr(ctx context.Context, err error) bool {
 		return eErr != rpctypes.ErrStopped && eErr != rpctypes.ErrNoLeader
 	}
 	// treat etcdserver errors not recognized by the client as halting
-	return strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()) ||
-		strings.Contains(err.Error(), "etcdserver:")
+	return isConnClosing(err) || strings.Contains(err.Error(), "etcdserver:")
+}
+
+// isConnClosing returns true if the error matches a grpc client closing error
+func isConnClosing(err error) bool {
+	return strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error())
 }
 
 func toErr(ctx context.Context, err error) error {
@@ -291,9 +312,12 @@ func toErr(ctx context.Context, err error) error {
 		return nil
 	}
 	err = rpctypes.Error(err)
-	if ctx.Err() != nil && strings.Contains(err.Error(), "context") {
+	switch {
+	case ctx.Err() != nil && strings.Contains(err.Error(), "context"):
 		err = ctx.Err()
-	} else if strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()) {
+	case strings.Contains(err.Error(), ErrNoAvailableEndpoints.Error()):
+		err = ErrNoAvailableEndpoints
+	case strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()):
 		err = grpc.ErrClientConnClosing
 	}
 	return err

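Note: with grpc.WithBlock() gone, dialing returns immediately and newClient waits on the balancer's readyc itself, so one dead endpoint at the front of the list no longer stalls the whole dial. A sketch of that three-way wait, assuming a readyc-style channel:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

var errDialTimeout = errors.New("dial timed out before any endpoint came up")

// waitReady blocks until readyc closes (first connection up), the
// context is canceled, or the timeout elapses, the same select
// newClient performs after a non-blocking dial.
func waitReady(ctx context.Context, readyc <-chan struct{}, timeout time.Duration) error {
	select {
	case <-readyc:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(timeout):
		return errDialTimeout
	}
}

func main() {
	readyc := make(chan struct{})
	go func() { // pretend a second endpoint comes up shortly after dialing
		time.Sleep(10 * time.Millisecond)
		close(readyc)
	}()
	fmt.Println(waitReady(context.Background(), readyc, time.Second)) // <nil>
}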
+ 3 - 0
clientv3/client_test.go

@@ -20,11 +20,14 @@ import (
 	"time"
 
 	"github.com/coreos/etcd/etcdserver"
+	"github.com/coreos/etcd/pkg/testutil"
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 )
 
 func TestDialTimeout(t *testing.T) {
+	defer testutil.AfterTest(t)
+
 	donec := make(chan error)
 	go func() {
 		// without timeout, grpc keeps redialing if connection refused

+ 4 - 4
clientv3/cluster.go

@@ -47,12 +47,12 @@ type cluster struct {
 }
 
 func NewCluster(c *Client) Cluster {
-	return &cluster{remote: pb.NewClusterClient(c.conn)}
+	return &cluster{remote: RetryClusterClient(c)}
 }
 
 func (c *cluster) MemberAdd(ctx context.Context, peerAddrs []string) (*MemberAddResponse, error) {
 	r := &pb.MemberAddRequest{PeerURLs: peerAddrs}
-	resp, err := c.remote.MemberAdd(ctx, r, grpc.FailFast(false))
+	resp, err := c.remote.MemberAdd(ctx, r)
 	if err == nil {
 		return (*MemberAddResponse)(resp), nil
 	}
@@ -64,7 +64,7 @@ func (c *cluster) MemberAdd(ctx context.Context, peerAddrs []string) (*MemberAdd
 
 func (c *cluster) MemberRemove(ctx context.Context, id uint64) (*MemberRemoveResponse, error) {
 	r := &pb.MemberRemoveRequest{ID: id}
-	resp, err := c.remote.MemberRemove(ctx, r, grpc.FailFast(false))
+	resp, err := c.remote.MemberRemove(ctx, r)
 	if err == nil {
 		return (*MemberRemoveResponse)(resp), nil
 	}
@@ -78,7 +78,7 @@ func (c *cluster) MemberUpdate(ctx context.Context, id uint64, peerAddrs []strin
 	// it is safe to retry on update.
 	for {
 		r := &pb.MemberUpdateRequest{ID: id, PeerURLs: peerAddrs}
-		resp, err := c.remote.MemberUpdate(ctx, r, grpc.FailFast(false))
+		resp, err := c.remote.MemberUpdate(ctx, r)
 		if err == nil {
 			return (*MemberUpdateResponse)(resp), nil
 		}

+ 73 - 0
clientv3/integration/kv_test.go

@@ -16,6 +16,7 @@ package integration
 
 import (
 	"bytes"
+	"math/rand"
 	"reflect"
 	"strings"
 	"testing"
@@ -662,3 +663,75 @@ func TestKVPutStoppedServerAndClose(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+// TestKVPutOneEndpointDown ensures a client can connect and get if one endpoint is down
+func TestKVPutOneEndpointDown(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	// get endpoint list
+	eps := make([]string, 3)
+	for i := range eps {
+		eps[i] = clus.Members[i].GRPCAddr()
+	}
+
+	// make a dead node
+	clus.Members[rand.Intn(len(eps))].Stop(t)
+
+	// try to connect with dead node in the endpoint list
+	cfg := clientv3.Config{Endpoints: eps, DialTimeout: 1 * time.Second}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+	ctx, cancel := context.WithTimeout(context.TODO(), 3*time.Second)
+	if _, err := cli.Get(ctx, "abc", clientv3.WithSerializable()); err != nil {
+		t.Fatal(err)
+	}
+	cancel()
+}
+
+// TestKVGetResetLoneEndpoint ensures that if an endpoint resets and all other
+// endpoints are down, then it will reconnect.
+func TestKVGetResetLoneEndpoint(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 2})
+	defer clus.Terminate(t)
+
+	// get endpoint list
+	eps := make([]string, 2)
+	for i := range eps {
+		eps[i] = clus.Members[i].GRPCAddr()
+	}
+
+	cfg := clientv3.Config{Endpoints: eps, DialTimeout: 500 * time.Millisecond}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+
+	// disconnect everything
+	clus.Members[0].Stop(t)
+	clus.Members[1].Stop(t)
+
+	// have Get try to reconnect
+	donec := make(chan struct{})
+	go func() {
+		ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
+		if _, err := cli.Get(ctx, "abc", clientv3.WithSerializable()); err != nil {
+			t.Fatal(err)
+		}
+		cancel()
+		close(donec)
+	}()
+	time.Sleep(500 * time.Millisecond)
+	clus.Members[0].Restart(t)
+	select {
+	case <-time.After(10 * time.Second):
+		t.Fatalf("timed out waiting for Get")
+	case <-donec:
+	}
+}

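Note: as committed, TestKVGetResetLoneEndpoint calls t.Fatal from a spawned goroutine, which the testing package does not support (Fatal must run on the test's own goroutine). A hedged sketch of the usual channel-based alternative, with a stand-in get function:

package integration

import (
	"context"
	"testing"
	"time"
)

// get stands in for the client call under test.
func get(ctx context.Context) error { return nil }

// TestGoroutineFailure shows one way to surface an error from a helper
// goroutine: the worker sends its error over a channel, and the test
// goroutine is the only one that calls t.Fatal.
func TestGoroutineFailure(t *testing.T) {
	donec := make(chan error, 1)
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		donec <- get(ctx)
	}()
	select {
	case err := <-donec:
		if err != nil {
			t.Fatalf("get failed: %v", err)
		}
	case <-time.After(10 * time.Second):
		t.Fatal("timed out waiting for get")
	}
}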
+ 3 - 3
clientv3/kv.go

@@ -82,7 +82,7 @@ type kv struct {
 }
 
 func NewKV(c *Client) KV {
-	return &kv{remote: pb.NewKVClient(c.conn)}
+	return &kv{remote: RetryKVClient(c)}
 }
 
 func NewKVFromKVClient(remote pb.KVClient) KV {
@@ -162,14 +162,14 @@ func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) {
 	case tPut:
 		var resp *pb.PutResponse
 		r := &pb.PutRequest{Key: op.key, Value: op.val, Lease: int64(op.leaseID), PrevKv: op.prevKV}
-		resp, err = kv.remote.Put(ctx, r, grpc.FailFast(false))
+		resp, err = kv.remote.Put(ctx, r)
 		if err == nil {
 			return OpResponse{put: (*PutResponse)(resp)}, nil
 		}
 	case tDeleteRange:
 		var resp *pb.DeleteRangeResponse
 		r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end, PrevKv: op.prevKV}
-		resp, err = kv.remote.DeleteRange(ctx, r, grpc.FailFast(false))
+		resp, err = kv.remote.DeleteRange(ctx, r)
 		if err == nil {
 			return OpResponse{del: (*DeleteResponse)(resp)}, nil
 		}

+ 3 - 3
clientv3/lease.go

@@ -110,7 +110,7 @@ func NewLease(c *Client) Lease {
 	l := &lessor{
 		donec:                 make(chan struct{}),
 		keepAlives:            make(map[LeaseID]*keepAlive),
-		remote:                pb.NewLeaseClient(c.conn),
+		remote:                RetryLeaseClient(c),
 		firstKeepAliveTimeout: c.cfg.DialTimeout + time.Second,
 	}
 	if l.firstKeepAliveTimeout == time.Second {
@@ -130,7 +130,7 @@ func (l *lessor) Grant(ctx context.Context, ttl int64) (*LeaseGrantResponse, err
 
 	for {
 		r := &pb.LeaseGrantRequest{TTL: ttl}
-		resp, err := l.remote.LeaseGrant(cctx, r, grpc.FailFast(false))
+		resp, err := l.remote.LeaseGrant(cctx, r)
 		if err == nil {
 			gresp := &LeaseGrantResponse{
 				ResponseHeader: resp.GetHeader(),
@@ -156,7 +156,7 @@ func (l *lessor) Revoke(ctx context.Context, id LeaseID) (*LeaseRevokeResponse,
 
 	for {
 		r := &pb.LeaseRevokeRequest{ID: int64(id)}
-		resp, err := l.remote.LeaseRevoke(cctx, r, grpc.FailFast(false))
+		resp, err := l.remote.LeaseRevoke(cctx, r)
 
 		if err == nil {
 			return (*LeaseRevokeResponse)(resp), nil

+ 243 - 0
clientv3/retry.go

@@ -0,0 +1,243 @@
+// Copyright 2016 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package clientv3
+
+import (
+	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+)
+
+type rpcFunc func(ctx context.Context) error
+type retryRpcFunc func(context.Context, rpcFunc)
+
+func (c *Client) newRetryWrapper() retryRpcFunc {
+	return func(rpcCtx context.Context, f rpcFunc) {
+		for {
+			err := f(rpcCtx)
+			// ignore grpc conn closing on fail-fast calls; they are transient errors
+			if err == nil || !isConnClosing(err) {
+				return
+			}
+			select {
+			case <-c.balancer.ConnectNotify():
+			case <-rpcCtx.Done():
+			case <-c.ctx.Done():
+				return
+			}
+		}
+	}
+}
+
+type retryKVClient struct {
+	pb.KVClient
+	retryf retryRpcFunc
+}
+
+// RetryKVClient implements a KVClient that uses the client's FailFast retry policy.
+func RetryKVClient(c *Client) pb.KVClient {
+	return &retryKVClient{pb.NewKVClient(c.conn), c.retryWrapper}
+}
+
+func (rkv *retryKVClient) Put(ctx context.Context, in *pb.PutRequest, opts ...grpc.CallOption) (resp *pb.PutResponse, err error) {
+	rkv.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rkv.KVClient.Put(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rkv *retryKVClient) DeleteRange(ctx context.Context, in *pb.DeleteRangeRequest, opts ...grpc.CallOption) (resp *pb.DeleteRangeResponse, err error) {
+	rkv.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rkv.KVClient.DeleteRange(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rkv *retryKVClient) Txn(ctx context.Context, in *pb.TxnRequest, opts ...grpc.CallOption) (resp *pb.TxnResponse, err error) {
+	rkv.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rkv.KVClient.Txn(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rkv *retryKVClient) Compact(ctx context.Context, in *pb.CompactionRequest, opts ...grpc.CallOption) (resp *pb.CompactionResponse, err error) {
+	rkv.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rkv.KVClient.Compact(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+type retryLeaseClient struct {
+	pb.LeaseClient
+	retryf retryRpcFunc
+}
+
+// RetryLeaseClient implements a LeaseClient that uses the client's FailFast retry policy.
+func RetryLeaseClient(c *Client) pb.LeaseClient {
+	return &retryLeaseClient{pb.NewLeaseClient(c.conn), c.retryWrapper}
+}
+
+func (rlc *retryLeaseClient) LeaseGrant(ctx context.Context, in *pb.LeaseGrantRequest, opts ...grpc.CallOption) (resp *pb.LeaseGrantResponse, err error) {
+	rlc.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rlc.LeaseClient.LeaseGrant(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+
+}
+
+func (rlc *retryLeaseClient) LeaseRevoke(ctx context.Context, in *pb.LeaseRevokeRequest, opts ...grpc.CallOption) (resp *pb.LeaseRevokeResponse, err error) {
+	rlc.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rlc.LeaseClient.LeaseRevoke(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+type retryClusterClient struct {
+	pb.ClusterClient
+	retryf retryRpcFunc
+}
+
+// RetryClusterClient implements a ClusterClient that uses the client's FailFast retry policy.
+func RetryClusterClient(c *Client) pb.ClusterClient {
+	return &retryClusterClient{pb.NewClusterClient(c.conn), c.retryWrapper}
+}
+
+func (rcc *retryClusterClient) MemberAdd(ctx context.Context, in *pb.MemberAddRequest, opts ...grpc.CallOption) (resp *pb.MemberAddResponse, err error) {
+	rcc.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rcc.ClusterClient.MemberAdd(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rcc *retryClusterClient) MemberRemove(ctx context.Context, in *pb.MemberRemoveRequest, opts ...grpc.CallOption) (resp *pb.MemberRemoveResponse, err error) {
+	rcc.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rcc.ClusterClient.MemberRemove(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rcc *retryClusterClient) MemberUpdate(ctx context.Context, in *pb.MemberUpdateRequest, opts ...grpc.CallOption) (resp *pb.MemberUpdateResponse, err error) {
+	rcc.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rcc.ClusterClient.MemberUpdate(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+type retryAuthClient struct {
+	pb.AuthClient
+	retryf retryRpcFunc
+}
+
+// RetryAuthClient implements an AuthClient that uses the client's FailFast retry policy.
+func RetryAuthClient(c *Client) pb.AuthClient {
+	return &retryAuthClient{pb.NewAuthClient(c.conn), c.retryWrapper}
+}
+
+func (rac *retryAuthClient) AuthEnable(ctx context.Context, in *pb.AuthEnableRequest, opts ...grpc.CallOption) (resp *pb.AuthEnableResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.AuthEnable(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) AuthDisable(ctx context.Context, in *pb.AuthDisableRequest, opts ...grpc.CallOption) (resp *pb.AuthDisableResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.AuthDisable(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) UserAdd(ctx context.Context, in *pb.AuthUserAddRequest, opts ...grpc.CallOption) (resp *pb.AuthUserAddResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.UserAdd(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) UserDelete(ctx context.Context, in *pb.AuthUserDeleteRequest, opts ...grpc.CallOption) (resp *pb.AuthUserDeleteResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.UserDelete(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) UserChangePassword(ctx context.Context, in *pb.AuthUserChangePasswordRequest, opts ...grpc.CallOption) (resp *pb.AuthUserChangePasswordResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.UserChangePassword(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) UserGrantRole(ctx context.Context, in *pb.AuthUserGrantRoleRequest, opts ...grpc.CallOption) (resp *pb.AuthUserGrantRoleResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.UserGrantRole(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) UserRevokeRole(ctx context.Context, in *pb.AuthUserRevokeRoleRequest, opts ...grpc.CallOption) (resp *pb.AuthUserRevokeRoleResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.UserRevokeRole(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) RoleAdd(ctx context.Context, in *pb.AuthRoleAddRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleAddResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.RoleAdd(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) RoleDelete(ctx context.Context, in *pb.AuthRoleDeleteRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleDeleteResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.RoleDelete(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) RoleGrantPermission(ctx context.Context, in *pb.AuthRoleGrantPermissionRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleGrantPermissionResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.RoleGrantPermission(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+func (rac *retryAuthClient) RoleRevokePermission(ctx context.Context, in *pb.AuthRoleRevokePermissionRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleRevokePermissionResponse, err error) {
+	rac.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rac.AuthClient.RoleRevokePermission(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}

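Note: newRetryWrapper is the heart of this change: re-issue an RPC that failed because the connection was closing, but only after the balancer signals a fresh connection. A self-contained sketch of the same loop outside gRPC (hypothetical names):

package main

import (
	"context"
	"errors"
	"fmt"
	"strings"
)

var errConnClosing = errors.New("grpc: the client connection is closing")

// retryWrapper mirrors newRetryWrapper: re-run f whenever it fails with
// a conn-closing error, pausing until a connection notification arrives
// or the client's own context ends.
func retryWrapper(clientCtx context.Context, connectNotify func() <-chan struct{}) func(context.Context, func(context.Context) error) {
	return func(rpcCtx context.Context, f func(context.Context) error) {
		for {
			err := f(rpcCtx)
			if err == nil || !strings.Contains(err.Error(), errConnClosing.Error()) {
				return
			}
			select {
			case <-connectNotify(): // a connection came (back) up; retry
			case <-rpcCtx.Done():
			case <-clientCtx.Done():
				return
			}
		}
	}
}

func main() {
	up := make(chan struct{})
	close(up) // a connection is "already up" for the demo
	retryf := retryWrapper(context.Background(), func() <-chan struct{} { return up })

	attempts := 0
	var lastErr error
	retryf(context.Background(), func(ctx context.Context) error {
		attempts++
		if attempts == 1 {
			lastErr = errConnClosing // first try hits a dying connection
			return lastErr
		}
		lastErr = nil
		return nil
	})
	fmt.Println(attempts, lastErr) // 2 <nil>
}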
+ 1 - 2
clientv3/txn.go

@@ -19,7 +19,6 @@ import (
 
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"golang.org/x/net/context"
-	"google.golang.org/grpc"
 )
 
 // Txn is the interface that wraps mini-transactions.
@@ -153,7 +152,7 @@ func (txn *txn) Commit() (*TxnResponse, error) {
 
 func (txn *txn) commit() (*TxnResponse, error) {
 	r := &pb.TxnRequest{Compare: txn.cmps, Success: txn.sus, Failure: txn.fas}
-	resp, err := txn.kv.remote.Txn(txn.ctx, r, grpc.FailFast(false))
+	resp, err := txn.kv.remote.Txn(txn.ctx, r)
 	if err != nil {
 		return nil, err
 	}

+ 19 - 6
cmd/vendor/google.golang.org/grpc/call.go

@@ -36,6 +36,7 @@ package grpc
 import (
 	"bytes"
 	"io"
+	"math"
 	"time"
 
 	"golang.org/x/net/context"
@@ -51,13 +52,20 @@ import (
 func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
 	// Try to acquire header metadata from the server if there is any.
 	var err error
+	defer func() {
+		if err != nil {
+			if _, ok := err.(transport.ConnectionError); !ok {
+				t.CloseStream(stream, err)
+			}
+		}
+	}()
 	c.headerMD, err = stream.Header()
 	if err != nil {
 		return err
 	}
 	p := &parser{r: stream}
 	for {
-		if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
+		if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
 			if err == io.EOF {
 				break
 			}
@@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
 	}
 	defer func() {
 		if err != nil {
+			// If err is a connection error, t will be closed; no need to close the stream here.
 			if _, ok := err.(transport.ConnectionError); !ok {
 				t.CloseStream(stream, err)
 			}
@@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
 		return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
 	}
 	err = t.Write(stream, outBuf, opts)
-	if err != nil {
+	// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
+	// does not exist), so t.Write could get io.EOF from wait(...). Leave the following
+	// recvResponse to get the final status.
+	if err != nil && err != io.EOF {
 		return nil, err
 	}
 	// Sent successfully.
@@ -176,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 				put()
 				put = nil
 			}
-			if _, ok := err.(transport.ConnectionError); ok {
+			// Retry a non-failfast RPC when
+			// i) there is a connection error; or
+			// ii) the server started to drain before this RPC was initiated.
+			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
 				if c.failFast {
 					return toRPCErr(err)
 				}
@@ -184,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 			}
 			return toRPCErr(err)
 		}
-		// Receive the response
 		err = recvResponse(cc.dopts, t, &c, stream, reply)
 		if err != nil {
 			if put != nil {
 				put()
 				put = nil
 			}
-			if _, ok := err.(transport.ConnectionError); ok {
+			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
 				if c.failFast {
 					return toRPCErr(err)
 				}
 				continue
 			}
-			t.CloseStream(stream, err)
 			return toRPCErr(err)
 		}
 		if c.traceInfo.tr != nil {

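Note: Invoke now retries a non-failfast RPC both on a transport.ConnectionError and when the server starts draining (transport.ErrStreamDrain). A sketch of that decision with stand-in error types:

package main

import (
	"errors"
	"fmt"
)

// connError stands in for transport.ConnectionError.
type connError struct{ msg string }

func (e connError) Error() string { return e.msg }

// errStreamDrain stands in for transport.ErrStreamDrain.
var errStreamDrain = errors.New("the server stops accepting new RPCs")

// shouldRetry mirrors Invoke's check: retry only if the RPC is not
// fail-fast and the failure is a connection error or a drain signal.
func shouldRetry(err error, failFast bool) bool {
	if failFast {
		return false
	}
	_, isConn := err.(connError)
	return isConn || err == errStreamDrain
}

func main() {
	fmt.Println(shouldRetry(connError{"conn reset"}, false)) // true
	fmt.Println(shouldRetry(errStreamDrain, false))          // true
	fmt.Println(shouldRetry(errors.New("app error"), false)) // false
	fmt.Println(shouldRetry(errStreamDrain, true))           // false: fail-fast
}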
+ 205 - 97
cmd/vendor/google.golang.org/grpc/clientconn.go

@@ -43,7 +43,6 @@ import (
 
 	"golang.org/x/net/context"
 	"golang.org/x/net/trace"
-	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/grpclog"
 	"google.golang.org/grpc/transport"
@@ -68,7 +67,7 @@ var (
 	// errCredentialsConflict indicates that grpc.WithTransportCredentials()
 	// and grpc.WithInsecure() are both called for a connection.
 	errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
-	// errNetworkIP indicates that the connection is down due to some network I/O error.
+	// errNetworkIO indicates that the connection is down due to some network I/O error.
 	errNetworkIO = errors.New("grpc: failed with network I/O error")
 	// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
 	errConnDrain = errors.New("grpc: the connection is drained")
@@ -196,9 +195,14 @@ func WithTimeout(d time.Duration) DialOption {
 }
 
 // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
-func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
+func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
 	return func(o *dialOptions) {
-		o.copts.Dialer = f
+		o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
+			if deadline, ok := ctx.Deadline(); ok {
+				return f(addr, deadline.Sub(time.Now()))
+			}
+			return f(addr, 0)
+		}
 	}
 }
 
@@ -211,10 +215,12 @@ func WithUserAgent(s string) DialOption {
 
 // Dial creates a client connection to the given target.
 func Dial(target string, opts ...DialOption) (*ClientConn, error) {
+	ctx := context.Background()
 	cc := &ClientConn{
 		target: target,
 		conns:  make(map[Address]*addrConn),
 	}
+	cc.ctx, cc.cancel = context.WithCancel(ctx)
 	for _, opt := range opts {
 		opt(&cc.dopts)
 	}
@@ -226,31 +232,33 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 	if cc.dopts.bs == nil {
 		cc.dopts.bs = DefaultBackoffConfig
 	}
-	if cc.dopts.balancer == nil {
-		cc.dopts.balancer = RoundRobin(nil)
-	}
 
-	if err := cc.dopts.balancer.Start(target); err != nil {
-		return nil, err
-	}
 	var (
 		ok    bool
 		addrs []Address
 	)
-	ch := cc.dopts.balancer.Notify()
-	if ch == nil {
-		// There is no name resolver installed.
+	if cc.dopts.balancer == nil {
+		// Connect to target directly if balancer is nil.
 		addrs = append(addrs, Address{Addr: target})
 	} else {
-		addrs, ok = <-ch
-		if !ok || len(addrs) == 0 {
-			return nil, errNoAddr
+		if err := cc.dopts.balancer.Start(target); err != nil {
+			return nil, err
+		}
+		ch := cc.dopts.balancer.Notify()
+		if ch == nil {
+			// There is no name resolver installed.
+			addrs = append(addrs, Address{Addr: target})
+		} else {
+			addrs, ok = <-ch
+			if !ok || len(addrs) == 0 {
+				return nil, errNoAddr
+			}
 		}
 	}
 	waitC := make(chan error, 1)
 	go func() {
 		for _, a := range addrs {
-			if err := cc.newAddrConn(a, false); err != nil {
+			if err := cc.resetAddrConn(a, false, nil); err != nil {
 				waitC <- err
 				return
 			}
@@ -267,10 +275,15 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 			cc.Close()
 			return nil, err
 		}
+	case <-cc.ctx.Done():
+		cc.Close()
+		return nil, cc.ctx.Err()
 	case <-timeoutCh:
 		cc.Close()
 		return nil, ErrClientConnTimeout
 	}
+	// If balancer is nil or balancer.Notify() is nil, ok will be false here.
+	// The lbWatcher goroutine will not be created.
 	if ok {
 		go cc.lbWatcher()
 	}
@@ -317,6 +330,9 @@ func (s ConnectivityState) String() string {
 
 // ClientConn represents a client connection to an RPC server.
 type ClientConn struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+
 	target    string
 	authority string
 	dopts     dialOptions
@@ -347,11 +363,12 @@ func (cc *ClientConn) lbWatcher() {
 			}
 			if !keep {
 				del = append(del, c)
+				delete(cc.conns, c.addr)
 			}
 		}
 		cc.mu.Unlock()
 		for _, a := range add {
-			cc.newAddrConn(a, true)
+			cc.resetAddrConn(a, true, nil)
 		}
 		for _, c := range del {
 			c.tearDown(errConnDrain)
@@ -359,13 +376,17 @@ func (cc *ClientConn) lbWatcher() {
 	}
 }
 
-func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
+// resetAddrConn creates an addrConn for addr and adds it to cc.conns.
+// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
+// If tearDownErr is nil, errConnDrain will be used instead.
+func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
 	ac := &addrConn{
-		cc:           cc,
-		addr:         addr,
-		dopts:        cc.dopts,
-		shutdownChan: make(chan struct{}),
+		cc:    cc,
+		addr:  addr,
+		dopts: cc.dopts,
 	}
+	ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
+	ac.stateCV = sync.NewCond(&ac.mu)
 	if EnableTracing {
 		ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
 	}
@@ -383,26 +404,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
 			}
 		}
 	}
-	// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
-	ac.cc.mu.Lock()
-	if ac.cc.conns == nil {
-		ac.cc.mu.Unlock()
+	// Track ac in cc. This needs to be done before any getTransport(...) is called.
+	cc.mu.Lock()
+	if cc.conns == nil {
+		cc.mu.Unlock()
 		return ErrClientConnClosing
 	}
-	stale := ac.cc.conns[ac.addr]
-	ac.cc.conns[ac.addr] = ac
-	ac.cc.mu.Unlock()
+	stale := cc.conns[ac.addr]
+	cc.conns[ac.addr] = ac
+	cc.mu.Unlock()
 	if stale != nil {
 		// There is an addrConn alive on ac.addr already. This could be due to
-		// i) stale's Close is undergoing;
-		// ii) a buggy Balancer notifies duplicated Addresses.
-		stale.tearDown(errConnDrain)
+		// 1) a buggy Balancer notifies duplicated Addresses;
+		// 2) goaway was received, a new ac will replace the old ac.
+		//    The old ac should be deleted from cc.conns, but the
+		//    underlying transport should drain rather than close.
+		if tearDownErr == nil {
+			// tearDownErr is nil if resetAddrConn is called by
+			// 1) Dial
+			// 2) lbWatcher
+			// In both cases, the stale ac should drain, not close.
+			stale.tearDown(errConnDrain)
+		} else {
+			stale.tearDown(tearDownErr)
+		}
 	}
-	ac.stateCV = sync.NewCond(&ac.mu)
 	// skipWait may overwrite the decision in ac.dopts.block.
 	if ac.dopts.block && !skipWait {
 		if err := ac.resetTransport(false); err != nil {
-			ac.tearDown(err)
+			if err != errConnClosing {
+				// Tear down ac and delete it from cc.conns.
+				cc.mu.Lock()
+				delete(cc.conns, ac.addr)
+				cc.mu.Unlock()
+				ac.tearDown(err)
+			}
+			if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
+				return e.Origin()
+			}
 			return err
 		}
 		// Start to monitor the error status of transport.
@@ -412,7 +451,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
 		go func() {
 			if err := ac.resetTransport(false); err != nil {
 				grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
-				ac.tearDown(err)
+				if err != errConnClosing {
+					// Keep this ac in cc.conns, to get the reason it's torn down.
+					ac.tearDown(err)
+				}
 				return
 			}
 			ac.transportMonitor()
@@ -422,24 +464,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
 }
 
 func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
-	addr, put, err := cc.dopts.balancer.Get(ctx, opts)
-	if err != nil {
-		return nil, nil, toRPCErr(err)
-	}
-	cc.mu.RLock()
-	if cc.conns == nil {
+	var (
+		ac  *addrConn
+		ok  bool
+		put func()
+	)
+	if cc.dopts.balancer == nil {
+		// If balancer is nil, there should be only one addrConn available.
+		cc.mu.RLock()
+		for _, ac = range cc.conns {
+			// Break after the first iteration to get the first addrConn.
+			ok = true
+			break
+		}
+		cc.mu.RUnlock()
+	} else {
+		var (
+			addr Address
+			err  error
+		)
+		addr, put, err = cc.dopts.balancer.Get(ctx, opts)
+		if err != nil {
+			return nil, nil, toRPCErr(err)
+		}
+		cc.mu.RLock()
+		if cc.conns == nil {
+			cc.mu.RUnlock()
+			return nil, nil, toRPCErr(ErrClientConnClosing)
+		}
+		ac, ok = cc.conns[addr]
 		cc.mu.RUnlock()
-		return nil, nil, toRPCErr(ErrClientConnClosing)
 	}
-	ac, ok := cc.conns[addr]
-	cc.mu.RUnlock()
 	if !ok {
 		if put != nil {
 			put()
 		}
-		return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
+		return nil, nil, errConnClosing
 	}
-	t, err := ac.wait(ctx, !opts.BlockingWait)
+	// ac.wait should block on transient failure only if balancer is nil and RPC is non-failfast.
+	//  - If RPC is failfast, ac.wait should not block.
+	//  - If balancer is not nil, ac.wait should return errConnClosing on transient failure
+	//    so that non-failfast RPCs will try to get a new transport instead of waiting on ac.
+	t, err := ac.wait(ctx, cc.dopts.balancer == nil && opts.BlockingWait)
 	if err != nil {
 		if put != nil {
 			put()
@@ -451,6 +517,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
 
 // Close tears down the ClientConn and all underlying connections.
 func (cc *ClientConn) Close() error {
+	cc.cancel()
+
 	cc.mu.Lock()
 	if cc.conns == nil {
 		cc.mu.Unlock()
@@ -459,7 +527,9 @@ func (cc *ClientConn) Close() error {
 	conns := cc.conns
 	cc.conns = nil
 	cc.mu.Unlock()
-	cc.dopts.balancer.Close()
+	if cc.dopts.balancer != nil {
+		cc.dopts.balancer.Close()
+	}
 	for _, ac := range conns {
 		ac.tearDown(ErrClientConnClosing)
 	}
@@ -468,11 +538,13 @@ func (cc *ClientConn) Close() error {
 
 // addrConn is a network connection to a given address.
 type addrConn struct {
-	cc           *ClientConn
-	addr         Address
-	dopts        dialOptions
-	shutdownChan chan struct{}
-	events       trace.EventLog
+	ctx    context.Context
+	cancel context.CancelFunc
+
+	cc     *ClientConn
+	addr   Address
+	dopts  dialOptions
+	events trace.EventLog
 
 	mu      sync.Mutex
 	state   ConnectivityState
@@ -482,6 +554,9 @@ type addrConn struct {
 	// due to timeout.
 	ready     chan struct{}
 	transport transport.ClientTransport
+
+	// The reason this addrConn is torn down.
+	tearDownErr error
 }
 
 // printf records an event in ac's event log, unless ac has been closed.
@@ -537,8 +612,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti
 }
 
 func (ac *addrConn) resetTransport(closeTransport bool) error {
-	var retries int
-	for {
+	for retries := 0; ; retries++ {
 		ac.mu.Lock()
 		ac.printf("connecting")
 		if ac.state == Shutdown {
@@ -558,13 +632,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 			t.Close()
 		}
 		sleepTime := ac.dopts.bs.backoff(retries)
-		ac.dopts.copts.Timeout = sleepTime
-		if sleepTime < minConnectTimeout {
-			ac.dopts.copts.Timeout = minConnectTimeout
+		timeout := minConnectTimeout
+		if timeout < sleepTime {
+			timeout = sleepTime
 		}
+		ctx, cancel := context.WithTimeout(ac.ctx, timeout)
 		connectTime := time.Now()
-		newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
+		newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
 		if err != nil {
+			cancel()
+
+			if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
+				return err
+			}
+			grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
 			ac.mu.Lock()
 			if ac.state == Shutdown {
 				// ac.tearDown(...) has been invoked.
@@ -579,17 +660,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 				ac.ready = nil
 			}
 			ac.mu.Unlock()
-			sleepTime -= time.Since(connectTime)
-			if sleepTime < 0 {
-				sleepTime = 0
-			}
 			closeTransport = false
 			select {
-			case <-time.After(sleepTime):
-			case <-ac.shutdownChan:
+			case <-time.After(sleepTime - time.Since(connectTime)):
+			case <-ac.ctx.Done():
+				return ac.ctx.Err()
 			}
-			retries++
-			grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
 			continue
 		}
 		ac.mu.Lock()
@@ -607,7 +683,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 			close(ac.ready)
 			ac.ready = nil
 		}
-		ac.down = ac.cc.dopts.balancer.Up(ac.addr)
+		if ac.cc.dopts.balancer != nil {
+			ac.down = ac.cc.dopts.balancer.Up(ac.addr)
+		}
 		ac.mu.Unlock()
 		return nil
 	}
@@ -621,14 +699,42 @@ func (ac *addrConn) transportMonitor() {
 		t := ac.transport
 		ac.mu.Unlock()
 		select {
-		// shutdownChan is needed to detect the teardown when
+		// This is needed to detect the teardown when
 		// the addrConn is idle (i.e., no RPC in flight).
-		case <-ac.shutdownChan:
+		case <-ac.ctx.Done():
+			select {
+			case <-t.Error():
+				t.Close()
+			default:
+			}
+			return
+		case <-t.GoAway():
+			// If GoAway happens without any network I/O error, ac is closed without shutting down the
+			// underlying transport (the transport will be closed when all the pending RPCs have finished or
+			// failed).
+			// If GoAway and some network I/O error happen concurrently, ac and its underlying transport
+			// are closed.
+			// In both cases, a new ac is created.
+			select {
+			case <-t.Error():
+				ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
+			default:
+				ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
+			}
 			return
 		case <-t.Error():
+			select {
+			case <-ac.ctx.Done():
+				t.Close()
+				return
+			case <-t.GoAway():
+				ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
+				return
+			default:
+			}
 			ac.mu.Lock()
 			if ac.state == Shutdown {
-				// ac.tearDown(...) has been invoked.
+				// ac has been shutdown.
 				ac.mu.Unlock()
 				return
 			}
@@ -640,6 +746,10 @@ func (ac *addrConn) transportMonitor() {
 				ac.printf("transport exiting: %v", err)
 				ac.mu.Unlock()
 				grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
+				if err != errConnClosing {
+					// Keep this ac in cc.conns, to get the reason it's torn down.
+					ac.tearDown(err)
+				}
 				return
 			}
 		}
@@ -647,21 +757,22 @@ func (ac *addrConn) transportMonitor() {
 }
 
 // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
-// iv) transport is in TransientFailure and the RPC is fail-fast.
-func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) {
+// iv) transport is in TransientFailure and blocking is false.
+func (ac *addrConn) wait(ctx context.Context, blocking bool) (transport.ClientTransport, error) {
 	for {
 		ac.mu.Lock()
 		switch {
 		case ac.state == Shutdown:
+			err := ac.tearDownErr
 			ac.mu.Unlock()
-			return nil, errConnClosing
+			return nil, err
 		case ac.state == Ready:
 			ct := ac.transport
 			ac.mu.Unlock()
 			return ct, nil
-		case ac.state == TransientFailure && failFast:
+		case ac.state == TransientFailure && !blocking:
 			ac.mu.Unlock()
-			return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure")
+			return nil, errConnClosing
 		default:
 			ready := ac.ready
 			if ready == nil {
@@ -683,24 +794,28 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr
 // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
 // some edge cases (e.g., the caller opens and closes many addrConn's in a
 // tight loop).
+// tearDown doesn't remove ac from ac.cc.conns.
 func (ac *addrConn) tearDown(err error) {
+	ac.cancel()
+
 	ac.mu.Lock()
-	defer func() {
-		ac.mu.Unlock()
-		ac.cc.mu.Lock()
-		if ac.cc.conns != nil {
-			delete(ac.cc.conns, ac.addr)
-		}
-		ac.cc.mu.Unlock()
-	}()
-	if ac.state == Shutdown {
-		return
-	}
-	ac.state = Shutdown
+	defer ac.mu.Unlock()
 	if ac.down != nil {
 		ac.down(downErrorf(false, false, "%v", err))
 		ac.down = nil
 	}
+	if err == errConnDrain && ac.transport != nil {
+		// GracefulClose(...) may be executed multiple times when
+		// i) receiving multiple GoAway frames from the server; or
+		// ii) there are concurrent name resolver/Balancer triggered
+		// address removal and GoAway.
+		ac.transport.GracefulClose()
+	}
+	if ac.state == Shutdown {
+		return
+	}
+	ac.state = Shutdown
+	ac.tearDownErr = err
 	ac.stateCV.Broadcast()
 	if ac.events != nil {
 		ac.events.Finish()
@@ -710,15 +825,8 @@ func (ac *addrConn) tearDown(err error) {
 		close(ac.ready)
 		ac.ready = nil
 	}
-	if ac.transport != nil {
-		if err == errConnDrain {
-			ac.transport.GracefulClose()
-		} else {
-			ac.transport.Close()
-		}
-	}
-	if ac.shutdownChan != nil {
-		close(ac.shutdownChan)
+	if ac.transport != nil && err != errConnDrain {
+		ac.transport.Close()
 	}
 	return
 }

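Note: transportMonitor now distinguishes a clean GOAWAY (drain the old transport, replace the addrConn with errConnDrain) from a GOAWAY racing a network error (errNetworkIO, close outright). A toy sketch of that nested select, with plain channels standing in for t.GoAway() and t.Error():

package main

import "fmt"

// reaction mirrors the monitor's choice of teardown reason. Like the
// monitor, it blocks until the transport reports something.
func reaction(goaway, ioErr <-chan struct{}) string {
	select {
	case <-goaway:
		select {
		case <-ioErr:
			return "reset conn: network I/O error" // errNetworkIO path
		default:
			return "reset conn: drain" // errConnDrain path
		}
	case <-ioErr:
		return "reconnect or tear down"
	}
}

func main() {
	goaway := make(chan struct{})
	ioErr := make(chan struct{})
	close(goaway) // server sent GOAWAY, no I/O error yet
	fmt.Println(reaction(goaway, ioErr)) // reset conn: drain
}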
+ 19 - 34
cmd/vendor/google.golang.org/grpc/credentials/credentials.go

@@ -44,7 +44,6 @@ import (
 	"io/ioutil"
 	"net"
 	"strings"
-	"time"
 
 	"golang.org/x/net/context"
 )
@@ -93,11 +92,12 @@ type TransportCredentials interface {
 	// ClientHandshake does the authentication handshake specified by the corresponding
 	// authentication protocol on rawConn for clients. It returns the authenticated
 	// connection and the corresponding auth information about the connection.
-	ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error)
+	// Implementations must use the provided context to implement timely cancellation.
+	ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
 	// ServerHandshake does the authentication handshake for servers. It returns
 	// the authenticated connection and the corresponding auth information about
 	// the connection.
-	ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
+	ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
 	// Info provides the ProtocolInfo of this TransportCredentials.
 	Info() ProtocolInfo
 }
@@ -136,42 +136,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
 	return true
 }
 
-type timeoutError struct{}
-
-func (timeoutError) Error() string   { return "credentials: Dial timed out" }
-func (timeoutError) Timeout() bool   { return true }
-func (timeoutError) Temporary() bool { return true }
-
-func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
-	// borrow some code from tls.DialWithDialer
-	var errChannel chan error
-	if timeout != 0 {
-		errChannel = make(chan error, 2)
-		time.AfterFunc(timeout, func() {
-			errChannel <- timeoutError{}
-		})
-	}
+func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
 	// use local cfg to avoid clobbering ServerName if using multiple endpoints
-	cfg := *c.config
-	if c.config.ServerName == "" {
+	cfg := cloneTLSConfig(c.config)
+	if cfg.ServerName == "" {
 		colonPos := strings.LastIndex(addr, ":")
 		if colonPos == -1 {
 			colonPos = len(addr)
 		}
 		cfg.ServerName = addr[:colonPos]
 	}
-	conn := tls.Client(rawConn, &cfg)
-	if timeout == 0 {
-		err = conn.Handshake()
-	} else {
-		go func() {
-			errChannel <- conn.Handshake()
-		}()
-		err = <-errChannel
-	}
-	if err != nil {
-		rawConn.Close()
-		return nil, nil, err
+	conn := tls.Client(rawConn, cfg)
+	errChannel := make(chan error, 1)
+	go func() {
+		errChannel <- conn.Handshake()
+	}()
+	select {
+	case err := <-errChannel:
+		if err != nil {
+			return nil, nil, err
+		}
+	case <-ctx.Done():
+		return nil, nil, ctx.Err()
 	}
 	// TODO(zhaoq): Omit the auth info for client now. It is more for
 	// information than anything else.
@@ -181,7 +167,6 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
 func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
 	conn := tls.Server(rawConn, c.config)
 	if err := conn.Handshake(); err != nil {
-		rawConn.Close()
 		return nil, nil, err
 	}
 	return conn, TLSInfo{conn.ConnectionState()}, nil
@@ -189,7 +174,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
 
 // NewTLS uses c to construct a TransportCredentials based on TLS.
 func NewTLS(c *tls.Config) TransportCredentials {
-	tc := &tlsCreds{c}
+	tc := &tlsCreds{cloneTLSConfig(c)}
 	tc.config.NextProtos = alpnProtoStr
 	return tc
 }

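Note: the handshake now honors a caller-supplied context rather than a fixed timeout: the blocking Handshake runs in a goroutine and a select races it against ctx.Done(). The same pattern works around any blocking call; a self-contained sketch with hypothetical names:

package main

import (
	"context"
	"fmt"
	"time"
)

// handshakeWithContext races a blocking handshake function against
// ctx, the same shape as the new tlsCreds.ClientHandshake. The
// buffered channel matters: if ctx wins, the goroutine can still send
// its result and exit instead of leaking.
func handshakeWithContext(ctx context.Context, handshake func() error) error {
	errc := make(chan error, 1)
	go func() { errc <- handshake() }()
	select {
	case err := <-errc:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	slow := func() error { time.Sleep(time.Second); return nil }
	fmt.Println(handshakeWithContext(ctx, slow)) // context deadline exceeded
}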
+ 76 - 0
cmd/vendor/google.golang.org/grpc/credentials/credentials_util_go17.go

@@ -0,0 +1,76 @@
+// +build go1.7
+
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package credentials
+
+import (
+	"crypto/tls"
+)
+
+// cloneTLSConfig returns a shallow clone of the exported
+// fields of cfg, ignoring the unexported sync.Once, which
+// contains a mutex and must not be copied.
+//
+// If cfg is nil, a new zero tls.Config is returned.
+//
+// TODO replace this function with official clone function.
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+	if cfg == nil {
+		return &tls.Config{}
+	}
+	return &tls.Config{
+		Rand:                        cfg.Rand,
+		Time:                        cfg.Time,
+		Certificates:                cfg.Certificates,
+		NameToCertificate:           cfg.NameToCertificate,
+		GetCertificate:              cfg.GetCertificate,
+		RootCAs:                     cfg.RootCAs,
+		NextProtos:                  cfg.NextProtos,
+		ServerName:                  cfg.ServerName,
+		ClientAuth:                  cfg.ClientAuth,
+		ClientCAs:                   cfg.ClientCAs,
+		InsecureSkipVerify:          cfg.InsecureSkipVerify,
+		CipherSuites:                cfg.CipherSuites,
+		PreferServerCipherSuites:    cfg.PreferServerCipherSuites,
+		SessionTicketsDisabled:      cfg.SessionTicketsDisabled,
+		SessionTicketKey:            cfg.SessionTicketKey,
+		ClientSessionCache:          cfg.ClientSessionCache,
+		MinVersion:                  cfg.MinVersion,
+		MaxVersion:                  cfg.MaxVersion,
+		CurvePreferences:            cfg.CurvePreferences,
+		DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
+		Renegotiation:               cfg.Renegotiation,
+	}
+}

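Note: the two build-tagged copies of cloneTLSConfig exist because Go 1.7 added tls.Config fields absent in 1.6; the TODO was later resolved upstream when Go 1.8 shipped an official tls.Config.Clone. A sketch using the official method (requires Go 1.8+):

package main

import (
	"crypto/tls"
	"fmt"
)

func main() {
	cfg := &tls.Config{ServerName: "etcd.example.com"}
	// Go 1.8+ only: tls.Config.Clone copies the exported fields safely,
	// making hand-rolled helpers like cloneTLSConfig unnecessary.
	clone := cfg.Clone()
	clone.ServerName = "other.example.com"
	fmt.Println(cfg.ServerName, clone.ServerName) // original stays untouched
}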
+ 74 - 0
cmd/vendor/google.golang.org/grpc/credentials/credentials_util_pre_go17.go

@@ -0,0 +1,74 @@
+// +build !go1.7
+
+/*
+ *
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package credentials
+
+import (
+	"crypto/tls"
+)
+
+// cloneTLSConfig returns a shallow clone of the exported
+// fields of cfg, ignoring the unexported sync.Once, which
+// contains a mutex and must not be copied.
+//
+// If cfg is nil, a new zero tls.Config is returned.
+//
+// TODO: replace this function with the official clone function.
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+	if cfg == nil {
+		return &tls.Config{}
+	}
+	return &tls.Config{
+		Rand:                     cfg.Rand,
+		Time:                     cfg.Time,
+		Certificates:             cfg.Certificates,
+		NameToCertificate:        cfg.NameToCertificate,
+		GetCertificate:           cfg.GetCertificate,
+		RootCAs:                  cfg.RootCAs,
+		NextProtos:               cfg.NextProtos,
+		ServerName:               cfg.ServerName,
+		ClientAuth:               cfg.ClientAuth,
+		ClientCAs:                cfg.ClientCAs,
+		InsecureSkipVerify:       cfg.InsecureSkipVerify,
+		CipherSuites:             cfg.CipherSuites,
+		PreferServerCipherSuites: cfg.PreferServerCipherSuites,
+		SessionTicketsDisabled:   cfg.SessionTicketsDisabled,
+		SessionTicketKey:         cfg.SessionTicketKey,
+		ClientSessionCache:       cfg.ClientSessionCache,
+		MinVersion:               cfg.MinVersion,
+		MaxVersion:               cfg.MaxVersion,
+		CurvePreferences:         cfg.CurvePreferences,
+	}
+}

+ 10 - 4
cmd/vendor/google.golang.org/grpc/metadata/metadata.go

@@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
 
 // DecodeKeyValue returns the original key and value corresponding to the
 // encoded data in k, v.
+// If k is a binary header, v may contain multiple comma-separated values;
+// each value is base64-decoded individually and the results are rejoined with commas.
 func DecodeKeyValue(k, v string) (string, string, error) {
 	if !strings.HasSuffix(k, binHdrSuffix) {
 		return k, v, nil
 	}
-	val, err := base64.StdEncoding.DecodeString(v)
-	if err != nil {
-		return "", "", err
+	vvs := strings.Split(v, ",")
+	for i, vv := range vvs {
+		val, err := base64.StdEncoding.DecodeString(vv)
+		if err != nil {
+			return "", "", err
+		}
+		vvs[i] = string(val)
 	}
-	return k, string(val), nil
+	return k, strings.Join(vvs, ","), nil
 }
 
 // MD is a mapping from metadata keys to values. Users should use the following
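
A standalone sketch of the comma-aware decoding introduced above: for a "-bin" key whose value carries several base64 segments joined by commas, each segment decodes independently and the results are rejoined. The values here are invented:

package main

import (
	"encoding/base64"
	"fmt"
	"strings"
)

func main() {
	// two base64 segments as they might appear in one "-bin" header value
	v := "Zm9v,YmFy" // "foo" and "bar", comma-joined
	parts := strings.Split(v, ",")
	for i, p := range parts {
		b, err := base64.StdEncoding.DecodeString(p)
		if err != nil {
			panic(err)
		}
		parts[i] = string(b)
	}
	fmt.Println(strings.Join(parts, ",")) // prints: foo,bar
}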

+ 13 - 5
cmd/vendor/google.golang.org/grpc/rpc_util.go

@@ -227,7 +227,7 @@ type parser struct {
 // No other error values or types must be returned, which also means
 // that the underlying io.Reader must not return an incompatible
 // error.
-func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
+func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
 	if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
 		return 0, nil, err
 	}
@@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
 	if length == 0 {
 		return pf, nil, nil
 	}
+	if length > uint32(maxMsgSize) {
+		return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
+	}
 	// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
 	// of making it for each message:
 	msg = make([]byte, int(length))
@@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
 	return nil
 }
 
-func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
-	pf, d, err := p.recvMsg()
+func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
+	pf, d, err := p.recvMsg(maxMsgSize)
 	if err != nil {
 		return err
 	}
@@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
 	if pf == compressionMade {
 		d, err = dc.Do(bytes.NewReader(d))
 		if err != nil {
-			return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
+			return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
 		}
 	}
+	if len(d) > maxMsgSize {
+			// TODO: Revisit the error code. Currently kept consistent with the Java
+			// implementation.
+		return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
+	}
 	if err := c.Unmarshal(d, m); err != nil {
-		return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
+		return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
 	}
 	return nil
 }
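
To make the new maxMsgSize guard concrete, here is a minimal sketch of gRPC's five-byte message framing (one compression-flag byte plus a 4-byte big-endian length) with the same early length check applied before any allocation; the frame bytes are fabricated:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	const maxMsgSize = 16
	// fabricated frame header: not compressed, declared length 1024
	header := []byte{0, 0, 0, 4, 0}
	pf := header[0]
	length := binary.BigEndian.Uint32(header[1:5])
	if length > uint32(maxMsgSize) {
		fmt.Printf("reject before allocating: pf=%d length=%d > %d\n", pf, length, maxMsgSize)
		return
	}
	// otherwise a length-sized buffer would be allocated and filled here
}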

+ 76 - 19
cmd/vendor/google.golang.org/grpc/server.go

@@ -89,9 +89,13 @@ type service struct {
 type Server struct {
 	opts options
 
-	mu     sync.Mutex // guards following
-	lis    map[net.Listener]bool
-	conns  map[io.Closer]bool
+	mu    sync.Mutex // guards following
+	lis   map[net.Listener]bool
+	conns map[io.Closer]bool
+	drain bool
+	// A CondVar to let GracefulStop() block until all the pending RPCs are finished
+	// and all the transports go away.
+	cv     *sync.Cond
 	m      map[string]*service // service name -> service info
 	events trace.EventLog
 }
@@ -101,12 +105,15 @@ type options struct {
 	codec                Codec
 	cp                   Compressor
 	dc                   Decompressor
+	maxMsgSize           int
 	unaryInt             UnaryServerInterceptor
 	streamInt            StreamServerInterceptor
 	maxConcurrentStreams uint32
 	useHandlerImpl       bool // use http.Handler-based server
 }
 
+var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
+
 // A ServerOption sets options.
 type ServerOption func(*options)
 
@@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
 	}
 }
 
-// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
+// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
 func RPCCompressor(cp Compressor) ServerOption {
 	return func(o *options) {
 		o.cp = cp
 	}
 }
 
-// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
+// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
 func RPCDecompressor(dc Decompressor) ServerOption {
 	return func(o *options) {
 		o.dc = dc
 	}
 }
 
+// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound messages.
+// If this is not set, gRPC uses the default 4MB.
+func MaxMsgSize(m int) ServerOption {
+	return func(o *options) {
+		o.maxMsgSize = m
+	}
+}
+
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
 // of concurrent streams to each ServerTransport.
 func MaxConcurrentStreams(n uint32) ServerOption {
@@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
 // started to accept requests yet.
 func NewServer(opt ...ServerOption) *Server {
 	var opts options
+	opts.maxMsgSize = defaultMaxMsgSize
 	for _, o := range opt {
 		o(&opts)
 	}
@@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
 		conns: make(map[io.Closer]bool),
 		m:     make(map[string]*service),
 	}
+	s.cv = sync.NewCond(&s.mu)
 	if EnableTracing {
 		_, file, line, _ := runtime.Caller(1)
 		s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@@ -264,8 +281,8 @@ type ServiceInfo struct {
 
 // GetServiceInfo returns a map from service names to ServiceInfo.
 // Service names include the package names, in the form of <package>.<service>.
-func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
-	ret := make(map[string]*ServiceInfo)
+func (s *Server) GetServiceInfo() map[string]ServiceInfo {
+	ret := make(map[string]ServiceInfo)
 	for n, srv := range s.m {
 		methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
 		for m := range srv.md {
@@ -283,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
 			})
 		}
 
-		ret[n] = &ServiceInfo{
+		ret[n] = ServiceInfo{
 			Methods:  methods,
 			Metadata: srv.mdata,
 		}
@@ -468,7 +485,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
 func (s *Server) addConn(c io.Closer) bool {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.conns == nil {
+	if s.conns == nil || s.drain {
 		return false
 	}
 	s.conns[c] = true
@@ -480,6 +497,7 @@ func (s *Server) removeConn(c io.Closer) {
 	defer s.mu.Unlock()
 	if s.conns != nil {
 		delete(s.conns, c)
+		s.cv.Signal()
 	}
 }
 
@@ -520,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 	}
 	p := &parser{r: stream}
 	for {
-		pf, req, err := p.recvMsg()
+		pf, req, err := p.recvMsg(s.opts.maxMsgSize)
 		if err == io.EOF {
 			// The entire stream is done (for unary RPC only).
 			return err
@@ -530,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 		}
 		if err != nil {
 			switch err := err.(type) {
+			case *rpcError:
+				if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
+					grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
+				}
 			case transport.ConnectionError:
 				// Nothing to do here.
 			case transport.StreamError:
@@ -569,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 					return err
 				}
 			}
+			if len(req) > s.opts.maxMsgSize {
+				// TODO: Revisit the error code. Currently kept consistent with the
+				// Java implementation.
+				statusCode = codes.Internal
+				statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
+			}
 			if err := s.opts.codec.Unmarshal(req, v); err != nil {
 				return err
 			}
@@ -628,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 		stream.SetSendCompress(s.opts.cp.Type())
 	}
 	ss := &serverStream{
-		t:      t,
-		s:      stream,
-		p:      &parser{r: stream},
-		codec:  s.opts.codec,
-		cp:     s.opts.cp,
-		dc:     s.opts.dc,
-		trInfo: trInfo,
+		t:          t,
+		s:          stream,
+		p:          &parser{r: stream},
+		codec:      s.opts.codec,
+		cp:         s.opts.cp,
+		dc:         s.opts.dc,
+		maxMsgSize: s.opts.maxMsgSize,
+		trInfo:     trInfo,
 	}
 	if ss.cp != nil {
 		ss.cbuf = new(bytes.Buffer)
@@ -766,14 +795,16 @@ func (s *Server) Stop() {
 	s.mu.Lock()
 	listeners := s.lis
 	s.lis = nil
-	cs := s.conns
+	st := s.conns
 	s.conns = nil
+	// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
+	s.cv.Signal()
 	s.mu.Unlock()
 
 	for lis := range listeners {
 		lis.Close()
 	}
-	for c := range cs {
+	for c := range st {
 		c.Close()
 	}
 
@@ -785,6 +816,32 @@ func (s *Server) Stop() {
 	s.mu.Unlock()
 }
 
+// GracefulStop stops the gRPC server gracefully. It stops the server from accepting new
+// connections and RPCs and blocks until all the pending RPCs are finished.
+func (s *Server) GracefulStop() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.drain || s.conns == nil {
+		return
+	}
+	s.drain = true
+	for lis := range s.lis {
+		lis.Close()
+	}
+	s.lis = nil
+	for c := range s.conns {
+		c.(transport.ServerTransport).Drain()
+	}
+	for len(s.conns) != 0 {
+		s.cv.Wait()
+	}
+	s.conns = nil
+	if s.events != nil {
+		s.events.Finish()
+		s.events = nil
+	}
+}
+
 func init() {
 	internal.TestingCloseConns = func(arg interface{}) {
 		arg.(*Server).testingCloseConns()
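
A short usage sketch tying together the two server-side additions in this file, the MaxMsgSize option and GracefulStop; the address and signal wiring below are illustrative only:

package main

import (
	"net"
	"os"
	"os/signal"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "127.0.0.1:0") // illustrative address
	if err != nil {
		panic(err)
	}
	s := grpc.NewServer(grpc.MaxMsgSize(8 * 1024 * 1024)) // raise the 4MB default
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, os.Interrupt)
	go func() {
		<-ch
		// stop accepting new connections/RPCs and wait for in-flight RPCs
		s.GracefulStop()
	}()
	s.Serve(lis)
}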

+ 71 - 40
cmd/vendor/google.golang.org/grpc/stream.go

@@ -37,6 +37,7 @@ import (
 	"bytes"
 	"errors"
 	"io"
+	"math"
 	"sync"
 	"time"
 
@@ -84,12 +85,9 @@ type ClientStream interface {
 	// Header returns the header metadata received from the server if there
 	// is any. It blocks if the metadata is not ready to read.
 	Header() (metadata.MD, error)
-	// Trailer returns the trailer metadata from the server. It must be called
-	// after stream.Recv() returns non-nil error (including io.EOF) for
-	// bi-directional streaming and server streaming or stream.CloseAndRecv()
-	// returns for client streaming in order to receive trailer metadata if
-	// present. Otherwise, it could returns an empty MD even though trailer
-	// is present.
+	// Trailer returns the trailer metadata from the server, if there is any.
+	// It must only be called after stream.CloseAndRecv has returned, or
+	// stream.Recv has returned a non-nil error (including io.EOF).
 	Trailer() metadata.MD
 	// CloseSend closes the send direction of the stream. It closes the stream
 	// when non-nil error is met.
@@ -99,11 +97,10 @@ type ClientStream interface {
 
 // NewClientStream creates a new Stream for the client side. This is called
 // by generated code.
-func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
+func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
 	var (
 		t   transport.ClientTransport
 		s   *transport.Stream
-		err error
 		put func()
 	)
 	c := defaultCallInfo
@@ -120,27 +117,24 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 	if cc.dopts.cp != nil {
 		callHdr.SendCompress = cc.dopts.cp.Type()
 	}
-	cs := &clientStream{
-		opts:    opts,
-		c:       c,
-		desc:    desc,
-		codec:   cc.dopts.codec,
-		cp:      cc.dopts.cp,
-		dc:      cc.dopts.dc,
-		tracing: EnableTracing,
-	}
-	if cc.dopts.cp != nil {
-		callHdr.SendCompress = cc.dopts.cp.Type()
-		cs.cbuf = new(bytes.Buffer)
-	}
-	if cs.tracing {
-		cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
-		cs.trInfo.firstLine.client = true
+	var trInfo traceInfo
+	if EnableTracing {
+		trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
+		trInfo.firstLine.client = true
 		if deadline, ok := ctx.Deadline(); ok {
-			cs.trInfo.firstLine.deadline = deadline.Sub(time.Now())
+			trInfo.firstLine.deadline = deadline.Sub(time.Now())
 		}
-		cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
-		ctx = trace.NewContext(ctx, cs.trInfo.tr)
+		trInfo.tr.LazyLog(&trInfo.firstLine, false)
+		ctx = trace.NewContext(ctx, trInfo.tr)
+		defer func() {
+			if err != nil {
+				// Need to call tr.Finish() if an error is returned,
+				// because tr will not be returned to the caller.
+				trInfo.tr.LazyPrintf("RPC: [%v]", err)
+				trInfo.tr.SetError()
+				trInfo.tr.Finish()
+			}
+		}()
 	}
 	gopts := BalancerGetOptions{
 		BlockingWait: !c.failFast,
@@ -168,9 +162,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 				put()
 				put = nil
 			}
-			if _, ok := err.(transport.ConnectionError); ok {
+			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
 				if c.failFast {
-					cs.finish(err)
 					return nil, toRPCErr(err)
 				}
 				continue
@@ -179,16 +172,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		}
 		break
 	}
-	cs.put = put
-	cs.t = t
-	cs.s = s
-	cs.p = &parser{r: s}
-	// Listen on ctx.Done() to detect cancellation when there is no pending
-	// I/O operations on this stream.
+	cs := &clientStream{
+		opts:  opts,
+		c:     c,
+		desc:  desc,
+		codec: cc.dopts.codec,
+		cp:    cc.dopts.cp,
+		dc:    cc.dopts.dc,
+
+		put: put,
+		t:   t,
+		s:   s,
+		p:   &parser{r: s},
+
+		tracing: EnableTracing,
+		trInfo:  trInfo,
+	}
+	if cc.dopts.cp != nil {
+		cs.cbuf = new(bytes.Buffer)
+	}
+	// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
+	// when there are no pending I/O operations on this stream.
 	go func() {
 		select {
 		case <-t.Error():
 			// Incurred a transport error; simply exit.
+		case <-s.Done():
+			// TODO: The trace of the RPC is terminated here when there is no pending
+			// I/O, which is probably not the optimal solution.
+			if s.StatusCode() == codes.OK {
+				cs.finish(nil)
+			} else {
+				cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
+			}
+			cs.closeTransportStream(nil)
+		case <-s.GoAway():
+			cs.finish(errConnDrain)
+			cs.closeTransportStream(errConnDrain)
 		case <-s.Context().Done():
 			err := s.Context().Err()
 			cs.finish(err)
@@ -251,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 		if err != nil {
 			cs.finish(err)
 		}
-		if err == nil || err == io.EOF {
+		if err == nil {
+			return
+		}
+		if err == io.EOF {
+			// Specialize the handling for server streaming. SendMsg is only called
+			// once when creating the stream object. io.EOF needs to be skipped when
+			// the RPC finishes early (before the stream object is created).
+			// TODO: It is probably better to move this into the generated code.
+			if !cs.desc.ClientStreams && cs.desc.ServerStreams {
+				err = nil
+			}
 			return
 		}
 		if _, ok := err.(transport.ConnectionError); !ok {
@@ -272,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 }
 
 func (cs *clientStream) RecvMsg(m interface{}) (err error) {
-	err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
+	err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
 	defer func() {
 		// err != nil indicates the termination of the stream.
 		if err != nil {
@@ -291,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 			return
 		}
 		// Special handling for client streaming rpc.
-		err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
+		err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
 		cs.closeTransportStream(err)
 		if err == nil {
 			return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -326,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) {
 		}
 	}()
 	if err == nil || err == io.EOF {
-		return
+		return nil
 	}
 	if _, ok := err.(transport.ConnectionError); !ok {
 		cs.closeTransportStream(err)
@@ -392,6 +422,7 @@ type serverStream struct {
 	cp         Compressor
 	dc         Decompressor
 	cbuf       *bytes.Buffer
+	maxMsgSize int
 	statusCode codes.Code
 	statusDesc string
 	trInfo     *traceInfo
@@ -458,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
 			ss.mu.Unlock()
 		}
 	}()
-	return recv(ss.p, ss.codec, ss.s, ss.dc, m)
+	return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
 }
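
Given the tightened Trailer contract above (valid only after CloseAndRecv returns, or after Recv/RecvMsg returns a non-nil error), a client receive loop looks roughly like this; msgStream is a hypothetical slice of the generated stream interface, not a real gRPC type:

package main

import (
	"fmt"
	"io"

	"google.golang.org/grpc/metadata"
)

// msgStream abstracts the part of a generated server-streaming client we need.
type msgStream interface {
	RecvMsg(m interface{}) error
	Trailer() metadata.MD
}

// drain shows the Trailer contract: read until a non-nil error (io.EOF on
// clean termination), and only then consult Trailer.
func drain(s msgStream) (metadata.MD, error) {
	for {
		var m string // stand-in message type
		if err := s.RecvMsg(&m); err != nil {
			if err == io.EOF {
				return s.Trailer(), nil // trailer is valid now
			}
			return nil, err
		}
		fmt.Println("got:", m)
	}
}

func main() {} // drain would be fed a real stream in practice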

+ 5 - 0
cmd/vendor/google.golang.org/grpc/transport/control.go

@@ -72,6 +72,11 @@ type resetStream struct {
 
 func (*resetStream) item() {}
 
+type goAway struct {
+}
+
+func (*goAway) item() {}
+
 type flushIO struct {
 }
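
The new goAway item rides the transport's control-buffer pattern: control messages implement a marker method and a single writer goroutine drains them in order. A minimal standalone sketch of that pattern (names are illustrative, not the vendored ones):

package main

import "fmt"

// item is the marker interface for control messages.
type item interface{ item() }

type goAway struct{}

func (goAway) item() {}

type flushIO struct{}

func (flushIO) item() {}

func main() {
	ctrl := make(chan item, 2)
	ctrl <- goAway{}
	ctrl <- flushIO{}
	close(ctrl)
	// a single controller goroutine would normally drain this loop
	for i := range ctrl {
		switch i.(type) {
		case goAway:
			fmt.Println("write GOAWAY frame, enter draining state")
		case flushIO:
			fmt.Println("flush pending frames")
		}
	}
}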
 

+ 46 - 0
cmd/vendor/google.golang.org/grpc/transport/go16.go

@@ -0,0 +1,46 @@
+// +build go1.6,!go1.7
+
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package transport
+
+import (
+	"net"
+
+	"golang.org/x/net/context"
+)
+
+// dialContext connects to the address on the named network.
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
+}

+ 46 - 0
cmd/vendor/google.golang.org/grpc/transport/go17.go

@@ -0,0 +1,46 @@
+// +build go1.7
+
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package transport
+
+import (
+	"net"
+
+	"golang.org/x/net/context"
+)
+
+// dialContext connects to the address on the named network.
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	return (&net.Dialer{}).DialContext(ctx, network, address)
+}
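
Both build-tagged variants give the same behavior: a dial that honors ctx cancellation and deadlines (the go1.6 file approximates it with Dialer.Cancel, the go1.7 file uses DialContext directly). A sketch of the go1.7 path, using a context deadline against an address that cannot answer (192.0.2.1 is the reserved TEST-NET-1 block, chosen so the dial times out):

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", "192.0.2.1:2379")
	if err != nil {
		fmt.Println("dial aborted by context:", err) // expected path
		return
	}
	conn.Close()
}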

+ 6 - 2
cmd/vendor/google.golang.org/grpc/transport/handler_server.go

@@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
 	}
 
 	if v := r.Header.Get("grpc-timeout"); v != "" {
-		to, err := timeoutDecode(v)
+		to, err := decodeTimeout(v)
 		if err != nil {
 			return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
 		}
@@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
 		h := ht.rw.Header()
 		h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
 		if statusDesc != "" {
-			h.Set("Grpc-Message", statusDesc)
+			h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
 		}
 		if md := s.Trailer(); len(md) > 0 {
 			for k, vv := range md {
@@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() {
 	}
 }
 
+func (ht *serverHandlerTransport) Drain() {
+	panic("Drain() is not implemented")
+}
+
 // mapRecvMsgError maps the non-nil err to the appropriate
 // error value as expected by callers of *grpc.parser.recvMsg.
 // In particular, it can only be:

+ 141 - 63
cmd/vendor/google.golang.org/grpc/transport/http2_client.go

@@ -35,6 +35,7 @@ package transport
 
 import (
 	"bytes"
+	"fmt"
 	"io"
 	"math"
 	"net"
@@ -71,6 +72,9 @@ type http2Client struct {
 	shutdownChan chan struct{}
 	// errorChan is closed to notify the I/O error to the caller.
 	errorChan chan struct{}
+	// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
+	// that the server sent a GoAway on this transport.
+	goAway chan struct{}
 
 	framer *framer
 	hBuf   *bytes.Buffer  // the buffer for HPACK encoding
@@ -97,41 +101,44 @@ type http2Client struct {
 	maxStreams int
 	// the per-stream outbound flow control window size set by the peer.
 	streamSendQuota uint32
+	// goAwayID records the Last-Stream-ID in the GoAway frame from the server.
+	goAwayID uint32
+	// prevGoAwayID records the Last-Stream-ID in the previous GoAway frame.
+	prevGoAwayID uint32
+}
+
+func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
+	if fn != nil {
+		return fn(ctx, addr)
+	}
+	return dialContext(ctx, "tcp", addr)
 }
 
 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
 // and starts to receive messages on it. Non-nil error returns if construction
 // fails.
-func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
-	if opts.Dialer == nil {
-		// Set the default Dialer.
-		opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
-			return net.DialTimeout("tcp", addr, timeout)
-		}
-	}
+func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
 	scheme := "http"
-	startT := time.Now()
-	timeout := opts.Timeout
-	conn, connErr := opts.Dialer(addr, timeout)
+	conn, connErr := dial(opts.Dialer, ctx, addr)
 	if connErr != nil {
-		return nil, ConnectionErrorf("transport: %v", connErr)
+		return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr)
 	}
+	// Any further errors will close the underlying connection
+	defer func(conn net.Conn) {
+		if err != nil {
+			conn.Close()
+		}
+	}(conn)
 	var authInfo credentials.AuthInfo
-	if opts.TransportCredentials != nil {
+	if creds := opts.TransportCredentials; creds != nil {
 		scheme = "https"
-		if timeout > 0 {
-			timeout -= time.Since(startT)
-		}
-		conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
+		conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
 	}
 	if connErr != nil {
-		return nil, ConnectionErrorf("transport: %v", connErr)
+		// Credentials handshake error is not a temporary error (unless the error
+		// was the connection closing).
+		return nil, ConnectionErrorf(connErr == io.EOF, connErr, "transport: %v", connErr)
 	}
-	defer func() {
-		if err != nil {
-			conn.Close()
-		}
-	}()
 	ua := primaryUA
 	if opts.UserAgent != "" {
 		ua = opts.UserAgent + " " + ua
@@ -147,6 +154,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 		writableChan:    make(chan int, 1),
 		shutdownChan:    make(chan struct{}),
 		errorChan:       make(chan struct{}),
+		goAway:          make(chan struct{}),
 		framer:          newFramer(conn),
 		hBuf:            &buf,
 		hEnc:            hpack.NewEncoder(&buf),
@@ -168,11 +176,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 	n, err := t.conn.Write(clientPreface)
 	if err != nil {
 		t.Close()
-		return nil, ConnectionErrorf("transport: %v", err)
+		return nil, ConnectionErrorf(true, err, "transport: %v", err)
 	}
 	if n != len(clientPreface) {
 		t.Close()
-		return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
+		return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
 	}
 	if initialWindowSize != defaultWindowSize {
 		err = t.framer.writeSettings(true, http2.Setting{
@@ -184,13 +192,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 	}
 	if err != nil {
 		t.Close()
-		return nil, ConnectionErrorf("transport: %v", err)
+		return nil, ConnectionErrorf(true, err, "transport: %v", err)
 	}
 	// Adjust the connection flow control window if needed.
 	if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
 		if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
 			t.Close()
-			return nil, ConnectionErrorf("transport: %v", err)
+			return nil, ConnectionErrorf(true, err, "transport: %v", err)
 		}
 	}
 	go t.controller()
@@ -202,6 +210,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
 	// TODO(zhaoq): Handle uint32 overflow of Stream.id.
 	s := &Stream{
 		id:            t.nextID,
+		done:          make(chan struct{}),
+		goAway:        make(chan struct{}),
 		method:        callHdr.Method,
 		sendCompress:  callHdr.SendCompress,
 		buf:           newRecvBuffer(),
@@ -216,8 +226,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
 	// Make a stream be able to cancel the pending operations by itself.
 	s.ctx, s.cancel = context.WithCancel(ctx)
 	s.dec = &recvBufferReader{
-		ctx:  s.ctx,
-		recv: s.buf,
+		ctx:    s.ctx,
+		goAway: s.goAway,
+		recv:   s.buf,
 	}
 	return s
 }
@@ -271,6 +282,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		t.mu.Unlock()
 		return nil, ErrConnClosing
 	}
+	if t.state == draining {
+		t.mu.Unlock()
+		return nil, ErrStreamDrain
+	}
 	if t.state != reachable {
 		t.mu.Unlock()
 		return nil, ErrConnClosing
@@ -278,7 +293,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 	checkStreamsQuota := t.streamsQuota != nil
 	t.mu.Unlock()
 	if checkStreamsQuota {
-		sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
+		sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
 		if err != nil {
 			return nil, err
 		}
@@ -287,7 +302,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 			t.streamsQuota.add(sq - 1)
 		}
 	}
-	if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
+	if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
 		// Return the quota back now because there is no stream returned to the caller.
 		if _, ok := err.(StreamError); ok && checkStreamsQuota {
 			t.streamsQuota.add(1)
@@ -295,6 +310,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		return nil, err
 	}
 	t.mu.Lock()
+	if t.state == draining {
+		t.mu.Unlock()
+		if checkStreamsQuota {
+			t.streamsQuota.add(1)
+		}
+		// Need to make t writable again so that the RPCs in flight can still proceed.
+		t.writableChan <- 0
+		return nil, ErrStreamDrain
+	}
 	if t.state != reachable {
 		t.mu.Unlock()
 		return nil, ErrConnClosing
@@ -329,7 +353,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
 	}
 	if timeout > 0 {
-		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
+		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
 	}
 	for k, v := range authData {
 		// Capital header names are illegal in HTTP/2.
@@ -384,7 +408,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		}
 		if err != nil {
 			t.notifyError(err)
-			return nil, ConnectionErrorf("transport: %v", err)
+			return nil, ConnectionErrorf(true, err, "transport: %v", err)
 		}
 	}
 	t.writableChan <- 0
@@ -403,22 +427,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 	if t.streamsQuota != nil {
 		updateStreams = true
 	}
-	if t.state == draining && len(t.activeStreams) == 1 {
+	delete(t.activeStreams, s.id)
+	if t.state == draining && len(t.activeStreams) == 0 {
 		// The transport is draining and s was the last live stream on t.
 		t.mu.Unlock()
 		t.Close()
 		return
 	}
-	delete(t.activeStreams, s.id)
 	t.mu.Unlock()
 	if updateStreams {
 		t.streamsQuota.add(1)
 	}
-	// In case stream sending and receiving are invoked in separate
-	// goroutines (e.g., bi-directional streaming), the caller needs
-	// to call cancel on the stream to interrupt the blocking on
-	// other goroutines.
-	s.cancel()
 	s.mu.Lock()
 	if q := s.fc.resetPendingData(); q > 0 {
 		if n := t.fc.onRead(q); n > 0 {
@@ -445,13 +464,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 // accessed any more.
 func (t *http2Client) Close() (err error) {
 	t.mu.Lock()
-	if t.state == reachable {
-		close(t.errorChan)
-	}
 	if t.state == closing {
 		t.mu.Unlock()
 		return
 	}
+	if t.state == reachable || t.state == draining {
+		close(t.errorChan)
+	}
 	t.state = closing
 	t.mu.Unlock()
 	close(t.shutdownChan)
@@ -475,10 +494,35 @@ func (t *http2Client) Close() (err error) {
 
 func (t *http2Client) GracefulClose() error {
 	t.mu.Lock()
-	if t.state == closing {
+	switch t.state {
+	case unreachable:
+		// The server may close the connection concurrently. t is not available for
+		// any streams. Close it now.
+		t.mu.Unlock()
+		t.Close()
+		return nil
+	case closing:
 		t.mu.Unlock()
 		return nil
 	}
+	// Notify the streams which were initiated after the server sent GOAWAY.
+	select {
+	case <-t.goAway:
+		n := t.prevGoAwayID
+		if n == 0 && t.nextID > 1 {
+			n = t.nextID - 2
+		}
+		m := t.goAwayID + 2
+		if m == 2 {
+			m = 1
+		}
+		for i := m; i <= n; i += 2 {
+			if s, ok := t.activeStreams[i]; ok {
+				close(s.goAway)
+			}
+		}
+	default:
+	}
 	if t.state == draining {
 		t.mu.Unlock()
 		return nil
@@ -504,15 +548,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
 			size := http2MaxFrameLen
 			s.sendQuotaPool.add(0)
 			// Wait until the stream has some quota to send the data.
-			sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
+			sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
 			if err != nil {
 				return err
 			}
 			t.sendQuotaPool.add(0)
 			// Wait until the transport has some quota to send the data.
-			tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
+			tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
 			if err != nil {
-				if _, ok := err.(StreamError); ok {
+				if _, ok := err.(StreamError); ok || err == io.EOF {
 					t.sendQuotaPool.cancel()
 				}
 				return err
@@ -544,8 +588,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
 		// Indicate there is a writer who is about to write a data frame.
 		t.framer.adjustNumWriters(1)
 		// Got some quota. Try to acquire writing privilege on the transport.
-		if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
-			if _, ok := err.(StreamError); ok {
+		if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
+			if _, ok := err.(StreamError); ok || err == io.EOF {
 				// Return the connection quota back.
 				t.sendQuotaPool.add(len(p))
 			}
@@ -578,7 +622,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
 		// invoked.
 		if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
 			t.notifyError(err)
-			return ConnectionErrorf("transport: %v", err)
+			return ConnectionErrorf(true, err, "transport: %v", err)
 		}
 		if t.framer.adjustNumWriters(-1) == 0 {
 			t.framer.flushWrite()
@@ -593,11 +637,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
 	}
 	s.mu.Lock()
 	if s.state != streamDone {
-		if s.state == streamReadDone {
-			s.state = streamDone
-		} else {
-			s.state = streamWriteDone
-		}
+		s.state = streamWriteDone
 	}
 	s.mu.Unlock()
 	return nil
@@ -630,7 +670,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
 func (t *http2Client) handleData(f *http2.DataFrame) {
 	size := len(f.Data())
 	if err := t.fc.onData(uint32(size)); err != nil {
-		t.notifyError(ConnectionErrorf("%v", err))
+		t.notifyError(ConnectionErrorf(true, err, "%v", err))
 		return
 	}
 	// Select the right stream to dispatch.
@@ -655,6 +695,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
 			s.state = streamDone
 			s.statusCode = codes.Internal
 			s.statusDesc = err.Error()
+			close(s.done)
 			s.mu.Unlock()
 			s.write(recvMsg{err: io.EOF})
 			t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
@@ -672,13 +713,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
 	// the read direction is closed, and set the status appropriately.
 	if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
 		s.mu.Lock()
-		if s.state == streamWriteDone {
-			s.state = streamDone
-		} else {
-			s.state = streamReadDone
+		if s.state == streamDone {
+			s.mu.Unlock()
+			return
 		}
+		s.state = streamDone
 		s.statusCode = codes.Internal
 		s.statusDesc = "server closed the stream without sending trailers"
+		close(s.done)
 		s.mu.Unlock()
 		s.write(recvMsg{err: io.EOF})
 	}
@@ -704,6 +746,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
 		grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
 		s.statusCode = codes.Unknown
 	}
+	s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
+	close(s.done)
 	s.mu.Unlock()
 	s.write(recvMsg{err: io.EOF})
 }
@@ -728,7 +772,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
 }
 
 func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
-	// TODO(zhaoq): GoAwayFrame handler to be implemented
+	t.mu.Lock()
+	if t.state == reachable || t.state == draining {
+		if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
+			t.mu.Unlock()
+			t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
+			return
+		}
+		select {
+		case <-t.goAway:
+			id := t.goAwayID
+			// t.goAway has been closed (i.e., multiple GoAway frames).
+			if id < f.LastStreamID {
+				t.mu.Unlock()
+				t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously received GOAWAY frame with LastStreamID %d, currently received %d", id, f.LastStreamID))
+				return
+			}
+			t.prevGoAwayID = id
+			t.goAwayID = f.LastStreamID
+			t.mu.Unlock()
+			return
+		default:
+		}
+		t.goAwayID = f.LastStreamID
+		close(t.goAway)
+	}
+	t.mu.Unlock()
 }
 
 func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
@@ -780,11 +849,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
 	if len(state.mdata) > 0 {
 		s.trailer = state.mdata
 	}
-	s.state = streamDone
 	s.statusCode = state.statusCode
 	s.statusDesc = state.statusDesc
+	close(s.done)
+	s.state = streamDone
 	s.mu.Unlock()
-
 	s.write(recvMsg{err: io.EOF})
 }
 
@@ -937,13 +1006,22 @@ func (t *http2Client) Error() <-chan struct{} {
 	return t.errorChan
 }
 
+func (t *http2Client) GoAway() <-chan struct{} {
+	return t.goAway
+}
+
 func (t *http2Client) notifyError(err error) {
 	t.mu.Lock()
-	defer t.mu.Unlock()
 	// make sure t.errorChan is closed only once.
+	if t.state == draining {
+		t.mu.Unlock()
+		t.Close()
+		return
+	}
 	if t.state == reachable {
 		t.state = unreachable
 		close(t.errorChan)
 		grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
 	}
+	t.mu.Unlock()
 }
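
GracefulClose above leans on client-initiated HTTP/2 streams having odd IDs: streams opened after the server's Last-Stream-ID were never processed by the server, so their goAway channel is closed. A standalone sketch of that ID walk, with made-up state mirroring the fields used in GracefulClose:

package main

import "fmt"

func main() {
	var (
		goAwayID     uint32 = 5  // Last-Stream-ID from the server's GOAWAY
		prevGoAwayID uint32 = 0  // no earlier GOAWAY seen
		nextID       uint32 = 13 // next client stream ID to hand out
	)
	n := prevGoAwayID
	if n == 0 && nextID > 1 {
		n = nextID - 2 // highest stream ID the client actually opened
	}
	m := goAwayID + 2
	if m == 2 {
		m = 1
	}
	// streams m..n (odd IDs only) were rejected by the server
	for i := m; i <= n; i += 2 {
		fmt.Println("close goAway channel for stream", i) // prints 7, 9, 11
	}
}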

+ 51 - 25
cmd/vendor/google.golang.org/grpc/transport/http2_server.go

@@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
 			Val: uint32(initialWindowSize)})
 	}
 	if err := framer.writeSettings(true, settings...); err != nil {
-		return nil, ConnectionErrorf("transport: %v", err)
+		return nil, ConnectionErrorf(true, err, "transport: %v", err)
 	}
 	// Adjust the connection flow control window if needed.
 	if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
 		if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
-			return nil, ConnectionErrorf("transport: %v", err)
+			return nil, ConnectionErrorf(true, err, "transport: %v", err)
 		}
 	}
 	var buf bytes.Buffer
@@ -142,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
 }
 
 // operateHeader takes action on the decoded headers.
-func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) {
+func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
 	buf := newRecvBuffer()
 	s := &Stream{
 		id:  frame.Header().StreamID,
@@ -205,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
 		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
 		return
 	}
+	if s.id%2 != 1 || s.id <= t.maxStreamID {
+		t.mu.Unlock()
+		// illegal gRPC stream id.
+		grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
+		return true
+	}
+	t.maxStreamID = s.id
 	s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
 	t.activeStreams[s.id] = s
 	t.mu.Unlock()
@@ -212,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
 		t.updateWindow(s, uint32(n))
 	}
 	handle(s)
+	return
 }
 
 // HandleStreams receives incoming streams using the given handler. This is
@@ -231,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 	}
 
 	frame, err := t.framer.readFrame()
+	if err == io.EOF || err == io.ErrUnexpectedEOF {
+		t.Close()
+		return
+	}
 	if err != nil {
 		grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
 		t.Close()
@@ -257,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 				t.controlBuf.put(&resetStream{se.StreamID, se.Code})
 				continue
 			}
+			if err == io.EOF || err == io.ErrUnexpectedEOF {
+				t.Close()
+				return
+			}
+			grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
 			t.Close()
 			return
 		}
 		switch frame := frame.(type) {
 		case *http2.MetaHeadersFrame:
-			id := frame.Header().StreamID
-			if id%2 != 1 || id <= t.maxStreamID {
-				// illegal gRPC stream id.
-				grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
+			if t.operateHeaders(frame, handle) {
 				t.Close()
 				break
 			}
-			t.maxStreamID = id
-			t.operateHeaders(frame, handle)
 		case *http2.DataFrame:
 			t.handleData(frame)
 		case *http2.RSTStreamFrame:
@@ -282,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 		case *http2.WindowUpdateFrame:
 			t.handleWindowUpdate(frame)
 		case *http2.GoAwayFrame:
-			break
+			// TODO: Handle GoAway from the client appropriately.
 		default:
 			grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
 		}
@@ -364,11 +376,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
 		// Received the end of stream from the client.
 		s.mu.Lock()
 		if s.state != streamDone {
-			if s.state == streamWriteDone {
-				s.state = streamDone
-			} else {
-				s.state = streamReadDone
-			}
+			s.state = streamReadDone
 		}
 		s.mu.Unlock()
 		s.write(recvMsg{err: io.EOF})
@@ -440,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
 		}
 		if err != nil {
 			t.Close()
-			return ConnectionErrorf("transport: %v", err)
+			return ConnectionErrorf(true, err, "transport: %v", err)
 		}
 	}
 	return nil
@@ -455,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
 	}
 	s.headerOk = true
 	s.mu.Unlock()
-	if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+	if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
 		return err
 	}
 	t.hBuf.Reset()
@@ -495,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
 		headersSent = true
 	}
 	s.mu.Unlock()
-	if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+	if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
 		return err
 	}
 	t.hBuf.Reset()
@@ -508,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
 			Name:  "grpc-status",
 			Value: strconv.Itoa(int(statusCode)),
 		})
-	t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc})
+	t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
 	// Attach the trailer metadata.
 	for k, v := range s.trailer {
 		// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@@ -544,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 	}
 	s.mu.Unlock()
 	if writeHeaderFrame {
-		if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+		if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
 			return err
 		}
 		t.hBuf.Reset()
@@ -560,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 		}
 		if err := t.framer.writeHeaders(false, p); err != nil {
 			t.Close()
-			return ConnectionErrorf("transport: %v", err)
+			return ConnectionErrorf(true, err, "transport: %v", err)
 		}
 		t.writableChan <- 0
 	}
@@ -572,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 		size := http2MaxFrameLen
 		s.sendQuotaPool.add(0)
 		// Wait until the stream has some quota to send the data.
-		sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
+		sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
 		if err != nil {
 			return err
 		}
 		t.sendQuotaPool.add(0)
 		// Wait until the transport has some quota to send the data.
-		tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
+		tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
 		if err != nil {
 			if _, ok := err.(StreamError); ok {
 				t.sendQuotaPool.cancel()
@@ -604,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 		t.framer.adjustNumWriters(1)
 		// Got some quota. Try to acquire writing privilege on the
 		// transport.
-		if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
+		if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
 			if _, ok := err.(StreamError); ok {
 				// Return the connection quota back.
 				t.sendQuotaPool.add(ps)
@@ -634,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 		}
 		if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
 			t.Close()
-			return ConnectionErrorf("transport: %v", err)
+			return ConnectionErrorf(true, err, "transport: %v", err)
 		}
 		if t.framer.adjustNumWriters(-1) == 0 {
 			t.framer.flushWrite()
@@ -679,6 +687,17 @@ func (t *http2Server) controller() {
 					}
 				case *resetStream:
 					t.framer.writeRSTStream(true, i.streamID, i.code)
+				case *goAway:
+					t.mu.Lock()
+					if t.state == closing {
+						t.mu.Unlock()
+						// The transport is closing.
+						return
+					}
+					sid := t.maxStreamID
+					t.state = draining
+					t.mu.Unlock()
+					t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
 				case *flushIO:
 					t.framer.flushWrite()
 				case *ping:
@@ -724,6 +743,9 @@ func (t *http2Server) Close() (err error) {
 func (t *http2Server) closeStream(s *Stream) {
 	t.mu.Lock()
 	delete(t.activeStreams, s.id)
+	if t.state == draining && len(t.activeStreams) == 0 {
+		defer t.Close()
+	}
 	t.mu.Unlock()
 	// In case stream sending and receiving are invoked in separate
 	// goroutines (e.g., bi-directional streaming), cancel needs to be
@@ -746,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) {
 func (t *http2Server) RemoteAddr() net.Addr {
 	return t.conn.RemoteAddr()
 }
+
+func (t *http2Server) Drain() {
+	t.controlBuf.put(&goAway{})
+}
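
Drain queues a goAway control item whose handler calls framer.writeGoAway with the highest stream ID handled so far; the underlying frame write is x/net/http2's WriteGoAway. A tiny sketch against an in-memory buffer rather than a live connection:

package main

import (
	"bytes"
	"fmt"

	"golang.org/x/net/http2"
)

func main() {
	var buf bytes.Buffer
	fr := http2.NewFramer(&buf, nil)
	// announce that stream 5 is the last one this server will process
	if err := fr.WriteGoAway(5, http2.ErrCodeNo, nil); err != nil {
		panic(err)
	}
	fmt.Printf("GOAWAY frame on the wire: % x\n", buf.Bytes())
}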

+ 79 - 4
cmd/vendor/google.golang.org/grpc/transport/http_util.go

@@ -35,6 +35,7 @@ package transport
 
 import (
 	"bufio"
+	"bytes"
 	"fmt"
 	"io"
 	"net"
@@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
 		}
 		d.statusCode = codes.Code(code)
 	case "grpc-message":
-		d.statusDesc = f.Value
+		d.statusDesc = decodeGrpcMessage(f.Value)
 	case "grpc-timeout":
 		d.timeoutSet = true
 		var err error
-		d.timeout, err = timeoutDecode(f.Value)
+		d.timeout, err = decodeTimeout(f.Value)
 		if err != nil {
 			d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
 			return
@@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
 }
 
 // TODO(zhaoq): This is simplistic and not bandwidth-efficient. Improve it.
-func timeoutEncode(t time.Duration) string {
+func encodeTimeout(t time.Duration) string {
 	if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
 		return strconv.FormatInt(d, 10) + "n"
 	}
@@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
 	return strconv.FormatInt(div(t, time.Hour), 10) + "H"
 }
 
-func timeoutDecode(s string) (time.Duration, error) {
+func decodeTimeout(s string) (time.Duration, error) {
 	size := len(s)
 	if size < 2 {
 		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
 	return d * time.Duration(t), nil
 }
 
+const (
+	spaceByte   = ' '
+	tildaByte   = '~'
+	percentByte = '%'
+)
+
+// encodeGrpcMessage is used to encode the status description in the
+// "grpc-message" header field.
+// It checks whether each individual byte in msg is an allowable byte,
+// and either percent-encodes it or passes it through.
+// When percent-encoding, the byte is converted into hexadecimal notation
+// with a '%' prepended.
+func encodeGrpcMessage(msg string) string {
+	if msg == "" {
+		return ""
+	}
+	lenMsg := len(msg)
+	for i := 0; i < lenMsg; i++ {
+		c := msg[i]
+		if !(c >= spaceByte && c < tildaByte && c != percentByte) {
+			return encodeGrpcMessageUnchecked(msg)
+		}
+	}
+	return msg
+}
+
+func encodeGrpcMessageUnchecked(msg string) string {
+	var buf bytes.Buffer
+	lenMsg := len(msg)
+	for i := 0; i < lenMsg; i++ {
+		c := msg[i]
+		if c >= spaceByte && c < tildaByte && c != percentByte {
+			buf.WriteByte(c)
+		} else {
+			buf.WriteString(fmt.Sprintf("%%%02X", c))
+		}
+	}
+	return buf.String()
+}
+
+// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
+func decodeGrpcMessage(msg string) string {
+	if msg == "" {
+		return ""
+	}
+	lenMsg := len(msg)
+	for i := 0; i < lenMsg; i++ {
+		if msg[i] == percentByte && i+2 < lenMsg {
+			return decodeGrpcMessageUnchecked(msg)
+		}
+	}
+	return msg
+}
+
+func decodeGrpcMessageUnchecked(msg string) string {
+	var buf bytes.Buffer
+	lenMsg := len(msg)
+	for i := 0; i < lenMsg; i++ {
+		c := msg[i]
+		if c == percentByte && i+2 < lenMsg {
+			parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
+			if err != nil {
+				buf.WriteByte(c)
+			} else {
+				buf.WriteByte(byte(parsed))
+				i += 2
+			}
+		} else {
+			buf.WriteByte(c)
+		}
+	}
+	return buf.String()
+}
+
 type framer struct {
 	numWriters int32
 	reader     io.Reader
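
The grpc-message percent-encoding keeps the header ASCII-safe while round-tripping arbitrary bytes: bytes from space up to but not including '~', except '%', pass through; everything else becomes %XX. A self-contained sketch of the byte rule:

package main

import (
	"bytes"
	"fmt"
)

func encode(msg string) string {
	var buf bytes.Buffer
	for i := 0; i < len(msg); i++ {
		c := msg[i]
		if c >= ' ' && c < '~' && c != '%' {
			buf.WriteByte(c)
		} else {
			fmt.Fprintf(&buf, "%%%02X", c) // hex notation with '%' prepended
		}
	}
	return buf.String()
}

func main() {
	fmt.Println(encode("café closed")) // prints: caf%C3%A9 closed
}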

+ 51 - 0
cmd/vendor/google.golang.org/grpc/transport/pre_go16.go

@@ -0,0 +1,51 @@
+// +build !go1.6
+
+/*
+ * Copyright 2016, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package transport
+
+import (
+	"net"
+	"time"
+
+	"golang.org/x/net/context"
+)
+
+// dialContext connects to the address on the named network.
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	var dialer net.Dialer
+	if deadline, ok := ctx.Deadline(); ok {
+		dialer.Timeout = deadline.Sub(time.Now())
+	}
+	return dialer.Dial(network, address)
+}

+ 78 - 15
cmd/vendor/google.golang.org/grpc/transport/transport.go

@@ -44,7 +44,6 @@ import (
 	"io"
 	"net"
 	"sync"
-	"time"
 
 	"golang.org/x/net/context"
 	"golang.org/x/net/trace"
@@ -120,10 +119,11 @@ func (b *recvBuffer) get() <-chan item {
 // recvBufferReader implements io.Reader interface to read the data from
 // recvBuffer.
 type recvBufferReader struct {
-	ctx  context.Context
-	recv *recvBuffer
-	last *bytes.Reader // Stores the remaining data in the previous calls.
-	err  error
+	ctx    context.Context
+	goAway chan struct{}
+	recv   *recvBuffer
+	last   *bytes.Reader // Stores the remaining data in the previous calls.
+	err    error
 }
 
 // Read reads the next len(p) bytes from last. If last is drained, it tries to
@@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
 	select {
 	case <-r.ctx.Done():
 		return 0, ContextErr(r.ctx.Err())
+	case <-r.goAway:
+		return 0, ErrStreamDrain
 	case i := <-r.recv.get():
 		r.recv.load()
 		m := i.(*recvMsg)
@@ -158,7 +160,7 @@ const (
 	streamActive    streamState = iota
 	streamWriteDone             // EndStream sent
 	streamReadDone              // EndStream received
-	streamDone                  // sendDone and recvDone or RSTStreamFrame is sent or received.
+	streamDone                  // the entire stream is finished.
 )
 
 // Stream represents an RPC in the transport layer.
@@ -169,6 +171,10 @@ type Stream struct {
 	// ctx is the associated context of the stream.
 	ctx    context.Context
 	cancel context.CancelFunc
+	// done is closed when the final status arrives.
+	done chan struct{}
+	// goAway is closed when the server sent a GoAway signal before this stream was initiated.
+	goAway chan struct{}
 	// method records the associated RPC method of the stream.
 	method       string
 	recvCompress string
@@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) {
 	s.sendCompress = str
 }
 
+// Done returns a channel which is closed when it receives the final status
+// from the server.
+func (s *Stream) Done() <-chan struct{} {
+	return s.done
+}
+
+// GoAway returns a channel which is closed when the server sent a GoAway signal
+// before this stream was initiated.
+func (s *Stream) GoAway() <-chan struct{} {
+	return s.goAway
+}
+
 // Header acquires the key-value pairs of header metadata once it
 // is available. It blocks until i) the metadata is ready or ii) there is no
 // header metadata or iii) the stream is cancelled/expired.
@@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) {
 	select {
 	case <-s.ctx.Done():
 		return nil, ContextErr(s.ctx.Err())
+	case <-s.goAway:
+		return nil, ErrStreamDrain
 	case <-s.headerChan:
 		return s.header.Copy(), nil
 	}
@@ -335,19 +355,17 @@ type ConnectOptions struct {
 	// UserAgent is the application user agent.
 	UserAgent string
 	// Dialer specifies how to dial a network address.
-	Dialer func(string, time.Duration) (net.Conn, error)
+	Dialer func(context.Context, string) (net.Conn, error)
 	// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
 	PerRPCCredentials []credentials.PerRPCCredentials
 	// TransportCredentials stores the Authenticator required to setup a client connection.
 	TransportCredentials credentials.TransportCredentials
-	// Timeout specifies the timeout for dialing a ClientTransport.
-	Timeout time.Duration
 }
 
 // NewClientTransport establishes the transport with the required ConnectOptions
 // and returns it to the caller.
-func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
-	return newHTTP2Client(target, opts)
+func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
+	return newHTTP2Client(ctx, target, opts)
 }
 
 // Options provides additional hints and information for message
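
With the Timeout field removed, the dial deadline now travels on the caller's context. A minimal sketch of a Dialer matching the new signature, assuming Go 1.7+ where net.Dialer.DialContext is available (the go16.go/pre_go16.go shims elsewhere in this diff cover older toolchains); the endpoint address is hypothetical.

package main

import (
	"context"
	"net"
	"time"
)

// dialContext satisfies the new Dialer field: the dial honors the
// context's deadline and cancellation instead of a separate Timeout.
func dialContext(ctx context.Context, addr string) (net.Conn, error) {
	var d net.Dialer
	return d.DialContext(ctx, "tcp", addr)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	// The deadline now flows through ctx rather than ConnectOptions.Timeout.
	if conn, err := dialContext(ctx, "localhost:2379"); err == nil {
		conn.Close()
	}
}
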
@@ -417,6 +435,11 @@ type ClientTransport interface {
 	// and create a new one) in error case. It should not return nil
 	// once the transport is initiated.
 	Error() <-chan struct{}
+
+	// GoAway returns a channel that is closed when the ClientTransport
+	// receives the draining signal from the server (e.g., GOAWAY frame in
+	// HTTP/2).
+	GoAway() <-chan struct{}
 }
 
 // ServerTransport is the common interface for all gRPC server-side transport
@@ -448,6 +471,9 @@ type ServerTransport interface {
 
 	// RemoteAddr returns the remote network address.
 	RemoteAddr() net.Addr
+
+	// Drain notifies the client that this ServerTransport stops accepting new RPCs.
+	Drain()
 }
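
Drain is the server half of graceful shutdown: each live transport tells its client (a GOAWAY frame in HTTP/2) to stop opening new RPCs while in-flight ones complete. A sketch of that loop follows; serverTransport, fakeTransport, and gracefulStop are illustrative stand-ins for the vendored types.

package main

import "fmt"

// serverTransport is a local stand-in for the single method this diff
// adds to transport.ServerTransport.
type serverTransport interface {
	Drain()
}

type fakeTransport struct{ id int }

func (t *fakeTransport) Drain() { fmt.Printf("transport %d: draining, no new RPCs\n", t.id) }

// gracefulStop signals every live transport to reject new RPCs; a real
// server would then wait for active streams before closing listeners.
func gracefulStop(conns []serverTransport) {
	for _, st := range conns {
		st.Drain()
	}
}

func main() {
	gracefulStop([]serverTransport{&fakeTransport{id: 1}, &fakeTransport{id: 2}})
}
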
 
 // StreamErrorf creates a StreamError with the specified error code and description.
@@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
 }
 
 // ConnectionErrorf creates a ConnectionError with the specified error description.
-func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
+func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
 	return ConnectionError{
 		Desc: fmt.Sprintf(format, a...),
+		temp: temp,
+		err:  e,
 	}
 }
 
@@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
 // entire connection and the retry of all the active streams.
 type ConnectionError struct {
 	Desc string
+	temp bool
+	err  error
 }
 
 func (e ConnectionError) Error() string {
 	return fmt.Sprintf("connection error: desc = %q", e.Desc)
 }
 
-// ErrConnClosing indicates that the transport is closing.
-var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
+// Temporary indicates if this connection error is temporary or fatal.
+func (e ConnectionError) Temporary() bool {
+	return e.temp
+}
+
+// Origin returns the original error of this connection error.
+func (e ConnectionError) Origin() error {
+	// Never return a nil error here.
+	// If the original error is nil, return the ConnectionError itself.
+	if e.err == nil {
+		return e
+	}
+	return e.err
+}
+
+var (
+	// ErrConnClosing indicates that the transport is closing.
+	ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true}
+	// ErrStreamDrain indicates that the stream was rejected by the server
+	// because the server has stopped accepting new RPCs.
+	ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
+)
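
Temporary lets dial and retry loops distinguish transient transport failures (such as ErrConnClosing) from fatal ones without matching on error strings. A sketch of that check follows; connError and shouldReconnect are local stand-ins so the example compiles outside the vendored package.

package main

import "fmt"

// connError mirrors the shape this diff gives ConnectionError.
type connError struct {
	desc string
	temp bool
}

func (e connError) Error() string   { return fmt.Sprintf("connection error: desc = %q", e.desc) }
func (e connError) Temporary() bool { return e.temp }

// shouldReconnect retries only errors that declare themselves temporary,
// the way a dial loop can now treat ErrConnClosing as retryable.
func shouldReconnect(err error) bool {
	if te, ok := err.(interface{ Temporary() bool }); ok {
		return te.Temporary()
	}
	return false
}

func main() {
	fmt.Println(shouldReconnect(connError{desc: "transport is closing", temp: true})) // true
	fmt.Println(shouldReconnect(connError{desc: "handshake failed", temp: false}))    // false
}
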
 
 // StreamError is an error that only affects one stream within a connection.
 type StreamError struct {
@@ -501,12 +551,25 @@ func ContextErr(err error) StreamError {
 
 // wait blocks until it can receive from ctx.Done, closing, or proceed.
 // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
+// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
+// it returns the StreamError for ctx.Err.
+// If it receives from goAway, it returns 0, ErrStreamDrain.
 // If it receives from closing, it returns 0, ErrConnClosing.
 // If it receives from proceed, it returns the received integer, nil.
-func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) {
+func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
 	select {
 	case <-ctx.Done():
 		return 0, ContextErr(ctx.Err())
+	case <-done:
+		// User cancellation has precedence.
+		select {
+		case <-ctx.Done():
+			return 0, ContextErr(ctx.Err())
+		default:
+		}
+		return 0, io.EOF
+	case <-goAway:
+		return 0, ErrStreamDrain
 	case <-closing:
 		return 0, ErrConnClosing
 	case i := <-proceed:

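The nested select inside the done case is the subtle part of wait: if the context was already cancelled when done fired, the caller must observe the cancellation rather than io.EOF. A standalone sketch of that precedence rule; waitSketch is illustrative and uses the standard-library context package for brevity.

package main

import (
	"context"
	"fmt"
	"io"
)

// waitSketch reproduces the precedence logic from wait: when done fires
// but the context is already cancelled, the context error wins over io.EOF.
func waitSketch(ctx context.Context, done <-chan struct{}) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
		select {
		case <-ctx.Done():
			return ctx.Err() // user cancellation has precedence
		default:
		}
		return io.EOF
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	close(done)
	cancel()
	// Both channels are ready; the nested select guarantees the caller
	// sees the cancellation, never io.EOF.
	fmt.Println(waitSketch(ctx, done)) // context canceled
}
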
+ 3 - 3
glide.lock

@@ -1,5 +1,5 @@
-hash: d56fcba6f86aae631325ec24baa27b102e915f533ed63391807795f78e8c25f6
-updated: 2016-08-15T12:13:05.482527088-07:00
+hash: c04a6b4fa79b3e780e38f875e3863551c8a7e5a1d1020b9a3a06268d50e592e3
+updated: 2016-08-15T15:51:12.533454067-07:00
 imports:
 - name: bitbucket.org/ww/goautoneg
   version: 75cd24fc2f2c2a2088577d12123ddee5f54e0675
@@ -136,7 +136,7 @@ imports:
   subpackages:
   - rate
 - name: google.golang.org/grpc
-  version: 02fca896ff5f50c6bbbee0860345a49344b37a03
+  version: c2781963b3af261a37e0f14fdcb7c1fa13259e1f
   subpackages:
   - codes
   - credentials

+ 1 - 1
glide.yaml

@@ -134,7 +134,7 @@ import:
   subpackages:
   - rate
 - package: google.golang.org/grpc
-  version: 02fca896ff5f50c6bbbee0860345a49344b37a03
+  version: c2781963b3af261a37e0f14fdcb7c1fa13259e1f
   subpackages:
   - codes
   - credentials