Browse Source

Merge pull request #8987 from gyuho/tls-shutdown

embed: fix *grpc.Server panic on GracefulStop with TLS-enabled server
Gyuho Lee 8 years ago
parent
commit
015c04bcf5
3 changed files with 82 additions and 29 deletions
  1. 42 14
      embed/etcd.go
  2. 13 8
      embed/serve.go
  3. 27 7
      integration/embed_test.go

+ 42 - 14
embed/etcd.go

@@ -100,9 +100,9 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
 			return
 		}
 		if !serving {
-			// errored before starting gRPC server for serveCtx.grpcServerC
+			// errored before starting gRPC server for serveCtx.serversC
 			for _, sctx := range e.sctxs {
-				close(sctx.grpcServerC)
+				close(sctx.serversC)
 			}
 		}
 		e.Close()
@@ -219,23 +219,35 @@ func (e *Etcd) Config() Config {
 	return e.cfg
 }
 
+// Close gracefully shuts down all servers/listeners.
+// Client requests will be terminated with request timeout.
+// After timeout, enforce remaning requests be closed immediately.
 func (e *Etcd) Close() {
 	e.closeOnce.Do(func() { close(e.stopc) })
 
+	// close client requests with request timeout
+	timeout := 2 * time.Second
+	if e.Server != nil {
+		timeout = e.Server.Cfg.ReqTimeout()
+	}
 	for _, sctx := range e.sctxs {
-		for gs := range sctx.grpcServerC {
-			e.stopGRPCServer(gs)
+		for ss := range sctx.serversC {
+			ctx, cancel := context.WithTimeout(context.Background(), timeout)
+			stopServers(ctx, ss)
+			cancel()
 		}
 	}
 
 	for _, sctx := range e.sctxs {
 		sctx.cancel()
 	}
+
 	for i := range e.Clients {
 		if e.Clients[i] != nil {
 			e.Clients[i].Close()
 		}
 	}
+
 	for i := range e.metricsListeners {
 		e.metricsListeners[i].Close()
 	}
@@ -255,25 +267,38 @@ func (e *Etcd) Close() {
 	}
 }
 
-func (e *Etcd) stopGRPCServer(gs *grpc.Server) {
-	timeout := 2 * time.Second
-	if e.Server != nil {
-		timeout = e.Server.Cfg.ReqTimeout()
+func stopServers(ctx context.Context, ss *servers) {
+	shutdownNow := func() {
+		// first, close the http.Server
+		ss.http.Shutdown(ctx)
+		// then close grpc.Server; cancels all active RPCs
+		ss.grpc.Stop()
+	}
+
+	// do not grpc.Server.GracefulStop with TLS enabled etcd server
+	// See https://github.com/grpc/grpc-go/issues/1384#issuecomment-317124531
+	// and https://github.com/coreos/etcd/issues/8916
+	if ss.secure {
+		shutdownNow()
+		return
 	}
+
 	ch := make(chan struct{})
 	go func() {
 		defer close(ch)
 		// close listeners to stop accepting new connections,
 		// will block on any existing transports
-		gs.GracefulStop()
+		ss.grpc.GracefulStop()
 	}()
+
 	// wait until all pending RPCs are finished
 	select {
 	case <-ch:
-	case <-time.After(timeout):
+	case <-ctx.Done():
 		// took too long, manually close open transports
 		// e.g. watch streams
-		gs.Stop()
+		shutdownNow()
+
 		// concurrent GracefulStop should be interrupted
 		<-ch
 	}
@@ -297,7 +322,9 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
 		for i := range peers {
 			if peers[i] != nil && peers[i].close != nil {
 				plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String())
-				peers[i].close(context.Background())
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+				peers[i].close(ctx)
+				cancel()
 			}
 		}
 	}()
@@ -334,6 +361,7 @@ func (e *Etcd) servePeers() (err error) {
 			return err
 		}
 	}
+
 	for _, p := range e.Peers {
 		gs := v3rpc.Server(e.Server, peerTLScfg)
 		m := cmux.New(p.Listener)
@@ -349,8 +377,8 @@ func (e *Etcd) servePeers() (err error) {
 			// gracefully shutdown http.Server
 			// close open listeners, idle connections
 			// until context cancel or time-out
-			e.stopGRPCServer(gs)
-			return srv.Shutdown(ctx)
+			stopServers(ctx, &servers{secure: peerTLScfg != nil, grpc: gs, http: srv})
+			return nil
 		}
 	}
 

+ 13 - 8
embed/serve.go

@@ -54,13 +54,19 @@ type serveCtx struct {
 
 	userHandlers    map[string]http.Handler
 	serviceRegister func(*grpc.Server)
-	grpcServerC     chan *grpc.Server
+	serversC        chan *servers
+}
+
+type servers struct {
+	secure bool
+	grpc   *grpc.Server
+	http   *http.Server
 }
 
 func newServeCtx() *serveCtx {
 	ctx, cancel := context.WithCancel(context.Background())
 	return &serveCtx{ctx: ctx, cancel: cancel, userHandlers: make(map[string]http.Handler),
-		grpcServerC: make(chan *grpc.Server, 2), // in case sctx.insecure,sctx.secure true
+		serversC: make(chan *servers, 2), // in case sctx.insecure,sctx.secure true
 	}
 }
 
@@ -84,7 +90,6 @@ func (sctx *serveCtx) serve(
 
 	if sctx.insecure {
 		gs := v3rpc.Server(s, nil, gopts...)
-		sctx.grpcServerC <- gs
 		v3electionpb.RegisterElectionServer(gs, servElection)
 		v3lockpb.RegisterLockServer(gs, servLock)
 		if sctx.serviceRegister != nil {
@@ -93,9 +98,7 @@ func (sctx *serveCtx) serve(
 		grpcl := m.Match(cmux.HTTP2())
 		go func() { errHandler(gs.Serve(grpcl)) }()
 
-		opts := []grpc.DialOption{
-			grpc.WithInsecure(),
-		}
+		opts := []grpc.DialOption{grpc.WithInsecure()}
 		gwmux, err := sctx.registerGateway(opts)
 		if err != nil {
 			return err
@@ -109,6 +112,8 @@ func (sctx *serveCtx) serve(
 		}
 		httpl := m.Match(cmux.HTTP1())
 		go func() { errHandler(srvhttp.Serve(httpl)) }()
+
+		sctx.serversC <- &servers{grpc: gs, http: srvhttp}
 		plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.l.Addr().String())
 	}
 
@@ -118,7 +123,6 @@ func (sctx *serveCtx) serve(
 			return tlsErr
 		}
 		gs := v3rpc.Server(s, tlscfg, gopts...)
-		sctx.grpcServerC <- gs
 		v3electionpb.RegisterElectionServer(gs, servElection)
 		v3lockpb.RegisterLockServer(gs, servLock)
 		if sctx.serviceRegister != nil {
@@ -150,10 +154,11 @@ func (sctx *serveCtx) serve(
 		}
 		go func() { errHandler(srv.Serve(tlsl)) }()
 
+		sctx.serversC <- &servers{secure: true, grpc: gs, http: srv}
 		plog.Infof("serving client requests on %s", sctx.l.Addr().String())
 	}
 
-	close(sctx.grpcServerC)
+	close(sctx.serversC)
 	return m.Serve()
 }
 

+ 27 - 7
integration/embed_test.go

@@ -47,7 +47,7 @@ func TestEmbedEtcd(t *testing.T) {
 		{werr: "expected IP"},
 	}
 
-	urls := newEmbedURLs(10)
+	urls := newEmbedURLs(false, 10)
 
 	// setup defaults
 	for i := range tests {
@@ -105,12 +105,19 @@ func TestEmbedEtcd(t *testing.T) {
 	}
 }
 
-// TestEmbedEtcdGracefulStop ensures embedded server stops
+func TestEmbedEtcdGracefulStopSecure(t *testing.T)   { testEmbedEtcdGracefulStop(t, true) }
+func TestEmbedEtcdGracefulStopInsecure(t *testing.T) { testEmbedEtcdGracefulStop(t, false) }
+
+// testEmbedEtcdGracefulStop ensures embedded server stops
 // cutting existing transports.
-func TestEmbedEtcdGracefulStop(t *testing.T) {
+func testEmbedEtcdGracefulStop(t *testing.T, secure bool) {
 	cfg := embed.NewConfig()
+	if secure {
+		cfg.ClientTLSInfo = testTLSInfo
+		cfg.PeerTLSInfo = testTLSInfo
+	}
 
-	urls := newEmbedURLs(2)
+	urls := newEmbedURLs(secure, 2)
 	setupEmbedCfg(cfg, []url.URL{urls[0]}, []url.URL{urls[1]})
 
 	cfg.Dir = filepath.Join(os.TempDir(), fmt.Sprintf("embed-etcd"))
@@ -123,7 +130,16 @@ func TestEmbedEtcdGracefulStop(t *testing.T) {
 	}
 	<-e.Server.ReadyNotify() // wait for e.Server to join the cluster
 
-	cli, err := clientv3.New(clientv3.Config{Endpoints: []string{urls[0].String()}})
+	clientCfg := clientv3.Config{
+		Endpoints: []string{urls[0].String()},
+	}
+	if secure {
+		clientCfg.TLS, err = testTLSInfo.ClientConfig()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	cli, err := clientv3.New(clientCfg)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -146,9 +162,13 @@ func TestEmbedEtcdGracefulStop(t *testing.T) {
 	}
 }
 
-func newEmbedURLs(n int) (urls []url.URL) {
+func newEmbedURLs(secure bool, n int) (urls []url.URL) {
+	scheme := "unix"
+	if secure {
+		scheme = "unixs"
+	}
 	for i := 0; i < n; i++ {
-		u, _ := url.Parse(fmt.Sprintf("unix://localhost:%d%06d", os.Getpid(), i))
+		u, _ := url.Parse(fmt.Sprintf("%s://localhost:%d%06d", scheme, os.Getpid(), i))
 		urls = append(urls, *u)
 	}
 	return urls