Browse Source

Merge pull request #2289 from fabxc/feature/graceful_shutdown

main: shutdown gracefully.
Xiang Li 11 years ago
parent
commit
b06e43b803
3 changed files with 102 additions and 0 deletions
  1. 5 0
      etcdmain/etcd.go
  2. 54 0
      pkg/osutil/osutil.go
  3. 43 0
      pkg/osutil/osutil_test.go

+ 5 - 0
etcdmain/etcd.go

@@ -31,6 +31,7 @@ import (
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/etcdhttp"
 	"github.com/coreos/etcd/etcdserver/etcdhttp"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/cors"
+	"github.com/coreos/etcd/pkg/osutil"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/proxy"
 	"github.com/coreos/etcd/proxy"
@@ -73,7 +74,10 @@ func Main() {
 		}
 		}
 	}
 	}
 
 
+	osutil.HandleInterrupts()
+
 	<-stopped
 	<-stopped
+	osutil.Exit(0)
 }
 }
 
 
 // startEtcd launches the etcd server and HTTP handlers for client/server communication.
 // startEtcd launches the etcd server and HTTP handlers for client/server communication.
@@ -160,6 +164,7 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 	s.Start()
 	s.Start()
+	osutil.RegisterInterruptHandler(s.Stop)
 
 
 	if cfg.corsInfo.String() != "" {
 	if cfg.corsInfo.String() != "" {
 		log.Printf("etcd: cors = %s", cfg.corsInfo)
 		log.Printf("etcd: cors = %s", cfg.corsInfo)

+ 54 - 0
pkg/osutil/osutil.go

@@ -15,8 +15,12 @@
 package osutil
 package osutil
 
 
 import (
 import (
+	"log"
 	"os"
 	"os"
+	"os/signal"
 	"strings"
 	"strings"
+	"sync"
+	"syscall"
 )
 )
 
 
 func Unsetenv(key string) error {
 func Unsetenv(key string) error {
@@ -33,3 +37,53 @@ func Unsetenv(key string) error {
 	}
 	}
 	return nil
 	return nil
 }
 }
+
+// InterruptHandler is a function that is called on receiving a
+// SIGTERM or SIGINT signal.
+type InterruptHandler func()
+
+var (
+	interruptRegisterMu, interruptExitMu sync.Mutex
+	// interruptHandlers holds all registered InterruptHandlers in order
+	// they will be executed.
+	interruptHandlers = []InterruptHandler{}
+)
+
+// RegisterInterruptHandler registers a new InterruptHandler. Handlers registered
+// after interrupt handing was initiated will not be executed.
+func RegisterInterruptHandler(h InterruptHandler) {
+	interruptRegisterMu.Lock()
+	defer interruptRegisterMu.Unlock()
+	interruptHandlers = append(interruptHandlers, h)
+}
+
+// HandleInterrupts calls the handler functions on receiving a SIGINT or SIGTERM.
+func HandleInterrupts() {
+	notifier := make(chan os.Signal, 1)
+	signal.Notify(notifier, syscall.SIGINT, syscall.SIGTERM)
+
+	go func() {
+		sig := <-notifier
+
+		interruptRegisterMu.Lock()
+		ihs := make([]InterruptHandler, len(interruptHandlers))
+		copy(ihs, interruptHandlers)
+		interruptRegisterMu.Unlock()
+
+		interruptExitMu.Lock()
+
+		log.Printf("received %v signal, shutting down", sig)
+
+		for _, h := range ihs {
+			h()
+		}
+		signal.Stop(notifier)
+		syscall.Kill(syscall.Getpid(), sig.(syscall.Signal))
+	}()
+}
+
+// Exit relays to os.Exit if no interrupt handlers are running, blocks otherwise.
+func Exit(code int) {
+	interruptExitMu.Lock()
+	os.Exit(code)
+}

+ 43 - 0
pkg/osutil/osutil_test.go

@@ -16,8 +16,11 @@ package osutil
 
 
 import (
 import (
 	"os"
 	"os"
+	"os/signal"
 	"reflect"
 	"reflect"
+	"syscall"
 	"testing"
 	"testing"
+	"time"
 )
 )
 
 
 func TestUnsetenv(t *testing.T) {
 func TestUnsetenv(t *testing.T) {
@@ -43,3 +46,43 @@ func TestUnsetenv(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
+	select {
+	case s := <-c:
+		if s != sig {
+			t.Fatalf("signal was %v, want %v", s, sig)
+		}
+	case <-time.After(1 * time.Second):
+		t.Fatalf("timeout waiting for %v", sig)
+	}
+}
+
+func TestHandleInterrupts(t *testing.T) {
+	for _, sig := range []syscall.Signal{syscall.SIGINT, syscall.SIGTERM} {
+		n := 1
+		RegisterInterruptHandler(func() { n++ })
+		RegisterInterruptHandler(func() { n *= 2 })
+
+		c := make(chan os.Signal, 2)
+		signal.Notify(c, sig)
+
+		HandleInterrupts()
+		syscall.Kill(syscall.Getpid(), sig)
+
+		// we should receive the signal once from our own kill and
+		// a second time from HandleInterrupts
+		waitSig(t, c, sig)
+		waitSig(t, c, sig)
+
+		if n == 3 {
+			t.Fatalf("interrupt handlers were called in wrong order")
+		}
+		if n != 4 {
+			t.Fatalf("interrupt handlers were not called properly")
+		}
+		// reset interrupt handlers
+		interruptHandlers = interruptHandlers[:0]
+		interruptExitMu.Unlock()
+	}
+}