Browse Source

Merge pull request #1339 from coreos/checkcid

etcdserver: checking clusterID
Xiang Li 11 years ago
parent
commit
0398a31b16

+ 25 - 6
etcdserver/cluster_store.go

@@ -22,6 +22,7 @@ import (
 	"fmt"
 	"log"
 	"net/http"
+	"strconv"
 	"time"
 
 	etcdErr "github.com/coreos/etcd/error"
@@ -46,6 +47,9 @@ type ClusterStore interface {
 
 type clusterStore struct {
 	Store store.Store
+	// TODO: write the id into the actual store?
+	// TODO: save the id as string?
+	id uint64
 }
 
 // Add puts a new Member into the store.
@@ -72,6 +76,7 @@ func (s *clusterStore) Add(m Member) {
 // lock here.
 func (s *clusterStore) Get() Cluster {
 	c := NewCluster()
+	c.id = s.id
 	e, err := s.Store.Get(membersKVPrefix, true, true)
 	if err != nil {
 		if v, ok := err.(*etcdErr.Error); ok && v.ErrorCode == etcdErr.EcodeKeyNotFound {
@@ -141,6 +146,7 @@ func Sender(t *http.Transport, cls ClusterStore, ss *stats.ServerStats, ls *stat
 // ClusterStore, retrying up to 3 times for each message. The given
 // ServerStats and LeaderStats are updated appropriately.
 func send(c *http.Client, cls ClusterStore, m raftpb.Message, ss *stats.ServerStats, ls *stats.LeaderStats) {
+	cid := cls.Get().ID()
 	// TODO (xiangli): reasonable retry logic
 	for i := 0; i < 3; i++ {
 		u := cls.Get().Pick(m.To)
@@ -167,7 +173,7 @@ func send(c *http.Client, cls ClusterStore, m raftpb.Message, ss *stats.ServerSt
 		fs := ls.Follower(to)
 
 		start := time.Now()
-		sent := httpPost(c, u, data)
+		sent := httpPost(c, u, cid, data)
 		end := time.Now()
 		if sent {
 			fs.Succ(end.Sub(start))
@@ -180,16 +186,29 @@ func send(c *http.Client, cls ClusterStore, m raftpb.Message, ss *stats.ServerSt
 
 // httpPost POSTs a data payload to a url using the given client. Returns true
 // if the POST succeeds, false on any failure.
-func httpPost(c *http.Client, url string, data []byte) bool {
-	resp, err := c.Post(url, "application/protobuf", bytes.NewBuffer(data))
+func httpPost(c *http.Client, url string, cid uint64, data []byte) bool {
+	req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
 	if err != nil {
 		// TODO: log the error?
 		return false
 	}
-	resp.Body.Close()
-	if resp.StatusCode != http.StatusNoContent {
+	req.Header.Set("Content-Type", "application/protobuf")
+	req.Header.Set("X-Etcd-Cluster-ID", strconv.FormatUint(cid, 16))
+	resp, err := c.Do(req)
+	if err != nil {
 		// TODO: log the error?
 		return false
 	}
-	return true
+	resp.Body.Close()
+
+	switch resp.StatusCode {
+	case http.StatusPreconditionFailed:
+		// TODO: shut down the etcdserver gracefully?
+		log.Panicf("clusterID mismatch")
+		return false
+	case http.StatusNoContent:
+		return true
+	default:
+		return false
+	}
 }

+ 5 - 4
etcdserver/cluster_store_test.go

@@ -92,14 +92,15 @@ func TestClusterStoreGet(t *testing.T) {
 		},
 	}
 	for i, tt := range tests {
-		cs := &clusterStore{Store: newGetAllStore()}
-		for _, m := range tt.mems {
-			cs.Add(m)
-		}
 		c := NewCluster()
 		if err := c.AddSlice(tt.mems); err != nil {
 			t.Fatal(err)
 		}
+		c.GenID(nil)
+		cs := &clusterStore{Store: newGetAllStore(), id: c.id}
+		for _, m := range tt.mems {
+			cs.Add(m)
+		}
 		if g := cs.Get(); !reflect.DeepEqual(&g, c) {
 			t.Errorf("#%d: mems = %v, want %v", i, &g, c)
 		}

+ 11 - 2
etcdserver/etcdhttp/http.go

@@ -80,8 +80,9 @@ func NewClientHandler(server *etcdserver.EtcdServer) http.Handler {
 // NewPeerHandler generates an http.Handler to handle etcd peer (raft) requests.
 func NewPeerHandler(server *etcdserver.EtcdServer) http.Handler {
 	sh := &serverHandler{
-		server: server,
-		stats:  server,
+		server:       server,
+		stats:        server,
+		clusterStore: server.ClusterStore,
 	}
 	mux := http.NewServeMux()
 	mux.HandleFunc(raftPrefix, sh.serveRaft)
@@ -215,6 +216,14 @@ func (h serverHandler) serveRaft(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	gcid := r.Header.Get("X-Etcd-Cluster-ID")
+	wcid := strconv.FormatUint(h.clusterStore.Get().ID(), 16)
+	if gcid != wcid {
+		log.Printf("etcdhttp: request ignored: clusterID mismatch got %s want %x", gcid, wcid)
+		http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed)
+		return
+	}
+
 	b, err := ioutil.ReadAll(r.Body)
 	if err != nil {
 		log.Println("etcdhttp: error reading raft message:", err)

+ 25 - 2
etcdserver/etcdhttp/http_test.go

@@ -862,6 +862,7 @@ func TestServeRaft(t *testing.T) {
 		method    string
 		body      io.Reader
 		serverErr error
+		clusterID string
 
 		wcode int
 	}{
@@ -875,6 +876,7 @@ func TestServeRaft(t *testing.T) {
 				),
 			),
 			nil,
+			"0",
 			http.StatusMethodNotAllowed,
 		},
 		{
@@ -887,6 +889,7 @@ func TestServeRaft(t *testing.T) {
 				),
 			),
 			nil,
+			"0",
 			http.StatusMethodNotAllowed,
 		},
 		{
@@ -899,6 +902,7 @@ func TestServeRaft(t *testing.T) {
 				),
 			),
 			nil,
+			"0",
 			http.StatusMethodNotAllowed,
 		},
 		{
@@ -906,6 +910,7 @@ func TestServeRaft(t *testing.T) {
 			"POST",
 			&errReader{},
 			nil,
+			"0",
 			http.StatusBadRequest,
 		},
 		{
@@ -913,6 +918,7 @@ func TestServeRaft(t *testing.T) {
 			"POST",
 			strings.NewReader("malformed garbage"),
 			nil,
+			"0",
 			http.StatusBadRequest,
 		},
 		{
@@ -925,6 +931,7 @@ func TestServeRaft(t *testing.T) {
 				),
 			),
 			errors.New("some error"),
+			"0",
 			http.StatusInternalServerError,
 		},
 		{
@@ -937,6 +944,20 @@ func TestServeRaft(t *testing.T) {
 				),
 			),
 			nil,
+			"1",
+			http.StatusPreconditionFailed,
+		},
+		{
+			// good request
+			"POST",
+			bytes.NewReader(
+				mustMarshalMsg(
+					t,
+					raftpb.Message{},
+				),
+			),
+			nil,
+			"0",
 			http.StatusNoContent,
 		},
 	}
@@ -945,9 +966,11 @@ func TestServeRaft(t *testing.T) {
 		if err != nil {
 			t.Fatalf("#%d: could not create request: %#v", i, err)
 		}
+		req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID)
 		h := &serverHandler{
-			timeout: time.Hour,
-			server:  &errServer{tt.serverErr},
+			timeout:      time.Hour,
+			server:       &errServer{tt.serverErr},
+			clusterStore: &fakeCluster{},
 		}
 		rw := httptest.NewRecorder()
 		h.serveRaft(rw, req)

+ 1 - 1
etcdserver/server.go

@@ -204,7 +204,7 @@ func NewServer(cfg *ServerConfig) *EtcdServer {
 		id, cid, n, w = restartNode(cfg, index, snapshot)
 	}
 
-	cls := &clusterStore{Store: st}
+	cls := &clusterStore{Store: st, id: cid}
 
 	sstats := &stats.ServerStats{
 		Name: cfg.Name,