Sfoglia il codice sorgente

embed: fix gRPC server panic on GracefulStop

Cherry-pick https://github.com/coreos/etcd/pull/8987.

Signed-off-by: Gyuho Lee <gyuhox@gmail.com>
Gyuho Lee 8 anni fa
parent
commit
288ef7d6fc
2 file modificati con 128 aggiunte e 65 eliminazioni
  1. 111 56
      embed/etcd.go
  2. 17 9
      embed/serve.go

+ 111 - 56
embed/etcd.go

@@ -29,12 +29,15 @@ import (
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/api/etcdhttp"
 	"github.com/coreos/etcd/etcdserver/api/v2http"
+	"github.com/coreos/etcd/etcdserver/api/v3rpc"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/debugutil"
 	runtimeutil "github.com/coreos/etcd/pkg/runtime"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/rafthttp"
+
+	"github.com/cockroachdb/cmux"
 	"github.com/coreos/pkg/capnslog"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/keepalive"
@@ -60,12 +63,14 @@ const (
 type Etcd struct {
 	Peers   []*peerListener
 	Clients []net.Listener
-	Server  *etcdserver.EtcdServer
+	// a map of contexts for the servers that serve client requests.
+	sctxs map[string]*serveCtx
+
+	Server *etcdserver.EtcdServer
 
 	cfg   Config
 	stopc chan struct{}
 	errc  chan error
-	sctxs map[string]*serveCtx
 
 	closeOnce sync.Once
 }
@@ -91,9 +96,9 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
 			return
 		}
 		if !serving {
-			// errored before starting gRPC server for serveCtx.grpcServerC
+			// errored before starting gRPC server for serveCtx.serversC
 			for _, sctx := range e.sctxs {
-				close(sctx.grpcServerC)
+				close(sctx.serversC)
 			}
 		}
 		e.Close()
@@ -101,10 +106,10 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
 	}()
 
 	if e.Peers, err = startPeerListeners(cfg); err != nil {
-		return
+		return e, err
 	}
 	if e.sctxs, err = startClientListeners(cfg); err != nil {
-		return
+		return e, err
 	}
 	for _, sctx := range e.sctxs {
 		e.Clients = append(e.Clients, sctx.l)
@@ -150,37 +155,23 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
 	}
 
 	if e.Server, err = etcdserver.NewServer(srvcfg); err != nil {
-		return
-	}
-
-	// configure peer handlers after rafthttp.Transport started
-	ph := etcdhttp.NewPeerHandler(e.Server)
-	for _, p := range e.Peers {
-		srv := &http.Server{
-			Handler:     ph,
-			ReadTimeout: 5 * time.Minute,
-			ErrorLog:    defaultLog.New(ioutil.Discard, "", 0), // do not log user error
-		}
-
-		l := p.Listener
-		p.serve = func() error { return srv.Serve(l) }
-		p.close = func(ctx context.Context) error {
-			// gracefully shutdown http.Server
-			// close open listeners, idle connections
-			// until context cancel or time-out
-			return srv.Shutdown(ctx)
-		}
+		return e, err
 	}
 
 	// buffer channel so goroutines on closed connections won't wait forever
 	e.errc = make(chan error, len(e.Peers)+len(e.Clients)+2*len(e.sctxs))
 
 	e.Server.Start()
-	if err = e.serve(); err != nil {
-		return
+
+	if err = e.servePeers(); err != nil {
+		return e, err
 	}
+	if err = e.serveClients(); err != nil {
+		return e, err
+	}
+
 	serving = true
-	return
+	return e, nil
 }
 
 // Config returns the current configuration.
@@ -188,38 +179,29 @@ func (e *Etcd) Config() Config {
 	return e.cfg
 }
 
+// Close gracefully shuts down all servers/listeners.
+// Client requests will be terminated with request timeout.
+// After timeout, enforce remaining requests be closed immediately.
 func (e *Etcd) Close() {
 	e.closeOnce.Do(func() { close(e.stopc) })
 
+	// close client requests with request timeout
 	timeout := 2 * time.Second
 	if e.Server != nil {
 		timeout = e.Server.Cfg.ReqTimeout()
 	}
 	for _, sctx := range e.sctxs {
-		for gs := range sctx.grpcServerC {
-			ch := make(chan struct{})
-			go func() {
-				defer close(ch)
-				// close listeners to stop accepting new connections,
-				// will block on any existing transports
-				gs.GracefulStop()
-			}()
-			// wait until all pending RPCs are finished
-			select {
-			case <-ch:
-			case <-time.After(timeout):
-				// took too long, manually close open transports
-				// e.g. watch streams
-				gs.Stop()
-				// concurrent GracefulStop should be interrupted
-				<-ch
-			}
+		for ss := range sctx.serversC {
+			ctx, cancel := context.WithTimeout(context.Background(), timeout)
+			stopServers(ctx, ss)
+			cancel()
 		}
 	}
 
 	for _, sctx := range e.sctxs {
 		sctx.cancel()
 	}
+
 	for i := range e.Clients {
 		if e.Clients[i] != nil {
 			e.Clients[i].Close()
@@ -241,6 +223,43 @@ func (e *Etcd) Close() {
 	}
 }
 
+func stopServers(ctx context.Context, ss *servers) {
+	shutdownNow := func() {
+		// first, close the http.Server
+		ss.http.Shutdown(ctx)
+		// then close grpc.Server; cancels all active RPCs
+		ss.grpc.Stop()
+	}
+
+	// do not call grpc.Server.GracefulStop on a TLS-enabled etcd server
+	// See https://github.com/grpc/grpc-go/issues/1384#issuecomment-317124531
+	// and https://github.com/coreos/etcd/issues/8916
+	if ss.secure {
+		shutdownNow()
+		return
+	}
+
+	ch := make(chan struct{})
+	go func() {
+		defer close(ch)
+		// close listeners to stop accepting new connections,
+		// will block on any existing transports
+		ss.grpc.GracefulStop()
+	}()
+
+	// wait until all pending RPCs are finished
+	select {
+	case <-ch:
+	case <-ctx.Done():
+		// took too long, manually close open transports
+		// e.g. watch streams
+		shutdownNow()
+
+		// concurrent GracefulStop should be interrupted
+		<-ch
+	}
+}
+
 func (e *Etcd) Err() <-chan error { return e.errc }
 
 func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
@@ -269,7 +288,9 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
 		for i := range peers {
 			if peers[i] != nil && peers[i].close != nil {
 				plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String())
-				peers[i].close(context.Background())
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+				peers[i].close(ctx)
+				cancel()
 			}
 		}
 	}()
@@ -297,6 +318,45 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
 	return peers, nil
 }
 
+// configure peer handlers after rafthttp.Transport started
+func (e *Etcd) servePeers() (err error) {
+	ph := etcdhttp.NewPeerHandler(e.Server)
+	var peerTLScfg *tls.Config
+	if !e.cfg.PeerTLSInfo.Empty() {
+		if peerTLScfg, err = e.cfg.PeerTLSInfo.ServerConfig(); err != nil {
+			return err
+		}
+	}
+
+	for _, p := range e.Peers {
+		gs := v3rpc.Server(e.Server, peerTLScfg)
+		m := cmux.New(p.Listener)
+		go gs.Serve(m.Match(cmux.HTTP2()))
+		srv := &http.Server{
+			Handler:     grpcHandlerFunc(gs, ph),
+			ReadTimeout: 5 * time.Minute,
+			ErrorLog:    defaultLog.New(ioutil.Discard, "", 0), // do not log user error
+		}
+		go srv.Serve(m.Match(cmux.Any()))
+		p.serve = func() error { return m.Serve() }
+		p.close = func(ctx context.Context) error {
+			// gracefully shutdown http.Server
+			// close open listeners, idle connections
+			// until context cancel or time-out
+			stopServers(ctx, &servers{secure: peerTLScfg != nil, grpc: gs, http: srv})
+			return nil
+		}
+	}
+
+	// start peer servers in a goroutine
+	for _, pl := range e.Peers {
+		go func(l *peerListener) {
+			e.errHandler(l.serve())
+		}(pl)
+	}
+	return nil
+}
+
 func startClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err error) {
 	if cfg.ClientAutoTLS && cfg.ClientTLSInfo.Empty() {
 		chosts := make([]string, len(cfg.LCUrls))
@@ -388,7 +448,7 @@ func startClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err error) {
 	return sctxs, nil
 }
 
-func (e *Etcd) serve() (err error) {
+func (e *Etcd) serveClients() (err error) {
 	var ctlscfg *tls.Config
 	if !e.cfg.ClientTLSInfo.Empty() {
 		plog.Infof("ClientTLS: %s", e.cfg.ClientTLSInfo)
@@ -401,13 +461,6 @@ func (e *Etcd) serve() (err error) {
 		plog.Infof("cors = %s", e.cfg.CorsInfo)
 	}
 
-	// Start the peer server in a goroutine
-	for _, pl := range e.Peers {
-		go func(l *peerListener) {
-			e.errHandler(l.serve())
-		}(pl)
-	}
-
 	// Start a client server goroutine for each listen address
 	var h http.Handler
 	if e.Config().EnableV2 {
@@ -433,6 +486,8 @@ func (e *Etcd) serve() (err error) {
 			Timeout: e.cfg.GRPCKeepAliveTimeout,
 		}))
 	}
+
+	// start client servers in a goroutine
 	for _, sctx := range e.sctxs {
 		go func(s *serveCtx) {
 			e.errHandler(s.serve(e.Server, ctlscfg, h, e.errHandler, gopts...))

+ 17 - 9
embed/serve.go

@@ -53,13 +53,22 @@ type serveCtx struct {
 
 	userHandlers    map[string]http.Handler
 	serviceRegister func(*grpc.Server)
-	grpcServerC     chan *grpc.Server
+	serversC        chan *servers
+}
+
+type servers struct {
+	secure bool
+	grpc   *grpc.Server
+	http   *http.Server
 }
 
 func newServeCtx() *serveCtx {
 	ctx, cancel := context.WithCancel(context.Background())
-	return &serveCtx{ctx: ctx, cancel: cancel, userHandlers: make(map[string]http.Handler),
-		grpcServerC: make(chan *grpc.Server, 2), // in case sctx.insecure,sctx.secure true
+	return &serveCtx{
+		ctx:          ctx,
+		cancel:       cancel,
+		userHandlers: make(map[string]http.Handler),
+		serversC:     make(chan *servers, 2), // in case sctx.insecure,sctx.secure true
 	}
 }
 
@@ -83,7 +92,6 @@ func (sctx *serveCtx) serve(
 
 	if sctx.insecure {
 		gs := v3rpc.Server(s, nil, gopts...)
-		sctx.grpcServerC <- gs
 		v3electionpb.RegisterElectionServer(gs, servElection)
 		v3lockpb.RegisterLockServer(gs, servLock)
 		if sctx.serviceRegister != nil {
@@ -92,9 +100,7 @@ func (sctx *serveCtx) serve(
 		grpcl := m.Match(cmux.HTTP2())
 		go func() { errHandler(gs.Serve(grpcl)) }()
 
-		opts := []grpc.DialOption{
-			grpc.WithInsecure(),
-		}
+		opts := []grpc.DialOption{grpc.WithInsecure()}
 		gwmux, err := sctx.registerGateway(opts)
 		if err != nil {
 			return err
@@ -108,12 +114,13 @@ func (sctx *serveCtx) serve(
 		}
 		httpl := m.Match(cmux.HTTP1())
 		go func() { errHandler(srvhttp.Serve(httpl)) }()
+
+		sctx.serversC <- &servers{grpc: gs, http: srvhttp}
 		plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.l.Addr().String())
 	}
 
 	if sctx.secure {
 		gs := v3rpc.Server(s, tlscfg, gopts...)
-		sctx.grpcServerC <- gs
 		v3electionpb.RegisterElectionServer(gs, servElection)
 		v3lockpb.RegisterLockServer(gs, servLock)
 		if sctx.serviceRegister != nil {
@@ -142,10 +149,11 @@ func (sctx *serveCtx) serve(
 		}
 		go func() { errHandler(srv.Serve(tlsl)) }()
 
+		sctx.serversC <- &servers{secure: true, grpc: gs, http: srv}
 		plog.Infof("serving client requests on %s", sctx.l.Addr().String())
 	}
 
-	close(sctx.grpcServerC)
+	close(sctx.serversC)
 	return m.Serve()
 }