Procházet zdrojové kódy

Merge pull request #2762 from yichengq/343

rafthttp: stop etcd if it is found removed when stream dial
Yicheng Qin před 10 roky
rodič
revize
d080c33c07

+ 2 - 0
etcdserver/server.go

@@ -348,6 +348,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
 	return s.r.Step(ctx, m)
 }
 
+func (s *EtcdServer) IsIDRemoved(id uint64) bool { return s.Cluster.IsIDRemoved(types.ID(id)) }
+
 func (s *EtcdServer) ReportUnreachable(id uint64) { s.r.ReportUnreachable(id) }
 
 func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {

+ 5 - 2
rafthttp/functional_test.go

@@ -134,8 +134,9 @@ func waitStreamWorking(p *peer) bool {
 }
 
 type fakeRaft struct {
-	recvc chan<- raftpb.Message
-	err   error
+	recvc     chan<- raftpb.Message
+	err       error
+	removedID uint64
 }
 
 func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
@@ -146,6 +147,8 @@ func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
 	return p.err
 }
 
+func (p *fakeRaft) IsIDRemoved(id uint64) bool { return id == p.removedID }
+
 func (p *fakeRaft) ReportUnreachable(id uint64) {}
 
 func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}

+ 8 - 1
rafthttp/http.go

@@ -46,9 +46,10 @@ type peerGetter interface {
 	Get(id types.ID) Peer
 }
 
-func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
+func newStreamHandler(peerGetter peerGetter, r Raft, id, cid types.ID) http.Handler {
 	return &streamHandler{
 		peerGetter: peerGetter,
+		r:          r,
 		id:         id,
 		cid:        cid,
 	}
@@ -112,6 +113,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 type streamHandler struct {
 	peerGetter peerGetter
+	r          Raft
 	id         types.ID
 	cid        types.ID
 }
@@ -145,6 +147,11 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "invalid from", http.StatusNotFound)
 		return
 	}
+	if h.r.IsIDRemoved(uint64(from)) {
+		log.Printf("rafthttp: reject the stream from peer %s since it was removed", from)
+		http.Error(w, "removed member", http.StatusGone)
+		return
+	}
 	p := h.peerGetter.Get(from)
 	if p == nil {
 		log.Printf("rafthttp: fail to find sender %s", from)

+ 13 - 2
rafthttp/http_test.go

@@ -17,6 +17,7 @@ package rafthttp
 import (
 	"bytes"
 	"errors"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
@@ -185,7 +186,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
 
 		peer := newFakePeer()
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
-		h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
+		h := newStreamHandler(peerGetter, &fakeRaft{}, types.ID(2), types.ID(1))
 
 		rw := httptest.NewRecorder()
 		go h.ServeHTTP(rw, req)
@@ -207,6 +208,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
 }
 
 func TestServeRaftStreamPrefixBad(t *testing.T) {
+	removedID := uint64(5)
 	tests := []struct {
 		method    string
 		path      string
@@ -263,6 +265,14 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
 			"1",
 			http.StatusNotFound,
 		},
+		// removed peer
+		{
+			"GET",
+			RaftStreamPrefix + "/message/" + fmt.Sprint(removedID),
+			"1",
+			"1",
+			http.StatusGone,
+		},
 		// wrong cluster ID
 		{
 			"GET",
@@ -289,7 +299,8 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
 		req.Header.Set("X-Raft-To", tt.remote)
 		rw := httptest.NewRecorder()
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
-		h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
+		r := &fakeRaft{removedID: removedID}
+		h := newStreamHandler(peerGetter, r, types.ID(1), types.ID(1))
 		h.ServeHTTP(rw, req)
 
 		if rw.Code != tt.wcode {

+ 2 - 2
rafthttp/peer.go

@@ -149,8 +149,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
 
 	go func() {
 		var paused bool
-		msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc)
-		reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc)
+		msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc, errorc)
+		reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc, errorc)
 		for {
 			select {
 			case m := <-p.sendc:

+ 15 - 3
rafthttp/stream.go

@@ -226,6 +226,7 @@ type streamReader struct {
 	cid      types.ID
 	recvc    chan<- raftpb.Message
 	propc    chan<- raftpb.Message
+	errorc   chan<- error
 
 	mu         sync.Mutex
 	msgAppTerm uint64
@@ -235,7 +236,7 @@ type streamReader struct {
 	done       chan struct{}
 }
 
-func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message) *streamReader {
+func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
 	r := &streamReader{
 		tr:     tr,
 		picker: picker,
@@ -245,6 +246,7 @@ func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, fr
 		cid:    cid,
 		recvc:  recvc,
 		propc:  propc,
+		errorc: errorc,
 		stopc:  make(chan struct{}),
 		done:   make(chan struct{}),
 	}
@@ -367,11 +369,21 @@ func (cr *streamReader) dial() (io.ReadCloser, error) {
 		cr.picker.unreachable(u)
 		return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
 	}
-	if resp.StatusCode != http.StatusOK {
+	switch resp.StatusCode {
+	case http.StatusGone:
+		resp.Body.Close()
+		err := fmt.Errorf("the member has been permanently removed from the cluster")
+		select {
+		case cr.errorc <- err:
+		default:
+		}
+		return nil, err
+	case http.StatusOK:
+		return resp.Body, nil
+	default:
 		resp.Body.Close()
 		return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
 	}
-	return resp.Body, nil
 }
 
 func (cr *streamReader) cancelRequest() {

+ 15 - 9
rafthttp/stream_test.go

@@ -119,15 +119,17 @@ func TestStreamReaderDialRequest(t *testing.T) {
 // HTTP response received.
 func TestStreamReaderDialResult(t *testing.T) {
 	tests := []struct {
-		code int
-		err  error
-		wok  bool
+		code  int
+		err   error
+		wok   bool
+		whalt bool
 	}{
-		{0, errors.New("blah"), false},
-		{http.StatusOK, nil, true},
-		{http.StatusMethodNotAllowed, nil, false},
-		{http.StatusNotFound, nil, false},
-		{http.StatusPreconditionFailed, nil, false},
+		{0, errors.New("blah"), false, false},
+		{http.StatusOK, nil, true, false},
+		{http.StatusMethodNotAllowed, nil, false, false},
+		{http.StatusNotFound, nil, false, false},
+		{http.StatusPreconditionFailed, nil, false, false},
+		{http.StatusGone, nil, false, true},
 	}
 	for i, tt := range tests {
 		tr := newRespRoundTripper(tt.code, tt.err)
@@ -138,12 +140,16 @@ func TestStreamReaderDialResult(t *testing.T) {
 			from:   types.ID(1),
 			to:     types.ID(2),
 			cid:    types.ID(1),
+			errorc: make(chan error, 1),
 		}
 
 		_, err := sr.dial()
 		if ok := err == nil; ok != tt.wok {
 			t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
 		}
+		if halt := len(sr.errorc) > 0; halt != tt.whalt {
+			t.Errorf("#%d: halt = %v, want %v", i, halt, tt.whalt)
+		}
 	}
 }
 
@@ -203,7 +209,7 @@ func TestStream(t *testing.T) {
 		h.sw = sw
 
 		picker := mustNewURLPicker(t, []string{srv.URL})
-		sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc)
+		sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc, nil)
 		defer sr.stop()
 		if tt.t == streamTypeMsgApp {
 			sr.updateMsgAppTerm(tt.term)

+ 2 - 1
rafthttp/transport.go

@@ -28,6 +28,7 @@ import (
 
 type Raft interface {
 	Process(ctx context.Context, m raftpb.Message) error
+	IsIDRemoved(id uint64) bool
 	ReportUnreachable(id uint64)
 	ReportSnapshot(id uint64, status raft.SnapshotStatus)
 }
@@ -98,7 +99,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
 
 func (t *transport) Handler() http.Handler {
 	pipelineHandler := NewHandler(t.raft, t.clusterID)
-	streamHandler := newStreamHandler(t, t.id, t.clusterID)
+	streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID)
 	mux := http.NewServeMux()
 	mux.Handle(RaftPrefix, pipelineHandler)
 	mux.Handle(RaftStreamPrefix+"/", streamHandler)

+ 2 - 0
rafthttp/transport_bench_test.go

@@ -88,6 +88,8 @@ func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
 	return nil
 }
 
+func (r *countRaft) IsIDRemoved(id uint64) bool { return false }
+
 func (r *countRaft) ReportUnreachable(id uint64) {}
 
 func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}