瀏覽代碼

rafthttp: stop etcd if it is found to be removed during stream dial

Previously, etcd was stopped only when a pipeline message revealed that the
local member had been removed. After this PR, stream dial detects this too.
This lets etcd stop faster: it no longer needs to wait for the stream to break
and fall back to pipeline, and then wait for an election timeout to send out a
message, in order to detect its own removal.
Yicheng Qin 10 年之前
父節點
當前提交
1c1cccd236

+ 2 - 0
etcdserver/server.go

@@ -348,6 +348,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
 	return s.r.Step(ctx, m)
 }
 
+func (s *EtcdServer) IsIDRemoved(id uint64) bool { return s.Cluster.IsIDRemoved(types.ID(id)) }
+
 func (s *EtcdServer) ReportUnreachable(id uint64) { s.r.ReportUnreachable(id) }
 
 func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {

+ 5 - 2
rafthttp/functional_test.go

@@ -134,8 +134,9 @@ func waitStreamWorking(p *peer) bool {
 }
 
 type fakeRaft struct {
-	recvc chan<- raftpb.Message
-	err   error
+	recvc     chan<- raftpb.Message
+	err       error
+	removedID uint64
 }
 
 func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
@@ -146,6 +147,8 @@ func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
 	return p.err
 }
 
+func (p *fakeRaft) IsIDRemoved(id uint64) bool { return id == p.removedID }
+
 func (p *fakeRaft) ReportUnreachable(id uint64) {}
 
 func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}

+ 8 - 1
rafthttp/http.go

@@ -46,9 +46,10 @@ type peerGetter interface {
 	Get(id types.ID) Peer
 }
 
-func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
+func newStreamHandler(peerGetter peerGetter, r Raft, id, cid types.ID) http.Handler {
 	return &streamHandler{
 		peerGetter: peerGetter,
+		r:          r,
 		id:         id,
 		cid:        cid,
 	}
@@ -112,6 +113,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 type streamHandler struct {
 	peerGetter peerGetter
+	r          Raft
 	id         types.ID
 	cid        types.ID
 }
@@ -145,6 +147,11 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "invalid from", http.StatusNotFound)
 		return
 	}
+	if h.r.IsIDRemoved(uint64(from)) {
+		log.Printf("rafthttp: reject the stream from peer %s since it was removed", from)
+		http.Error(w, "removed member", http.StatusGone)
+		return
+	}
 	p := h.peerGetter.Get(from)
 	if p == nil {
 		log.Printf("rafthttp: fail to find sender %s", from)

+ 13 - 2
rafthttp/http_test.go

@@ -17,6 +17,7 @@ package rafthttp
 import (
 	"bytes"
 	"errors"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
@@ -185,7 +186,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
 
 		peer := newFakePeer()
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
-		h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
+		h := newStreamHandler(peerGetter, &fakeRaft{}, types.ID(2), types.ID(1))
 
 		rw := httptest.NewRecorder()
 		go h.ServeHTTP(rw, req)
@@ -207,6 +208,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
 }
 
 func TestServeRaftStreamPrefixBad(t *testing.T) {
+	removedID := uint64(5)
 	tests := []struct {
 		method    string
 		path      string
@@ -263,6 +265,14 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
 			"1",
 			http.StatusNotFound,
 		},
+		// removed peer
+		{
+			"GET",
+			RaftStreamPrefix + "/message/" + fmt.Sprint(removedID),
+			"1",
+			"1",
+			http.StatusGone,
+		},
 		// wrong cluster ID
 		{
 			"GET",
@@ -289,7 +299,8 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
 		req.Header.Set("X-Raft-To", tt.remote)
 		rw := httptest.NewRecorder()
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
-		h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
+		r := &fakeRaft{removedID: removedID}
+		h := newStreamHandler(peerGetter, r, types.ID(1), types.ID(1))
 		h.ServeHTTP(rw, req)
 
 		if rw.Code != tt.wcode {

+ 2 - 2
rafthttp/peer.go

@@ -149,8 +149,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
 
 	go func() {
 		var paused bool
-		msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc)
-		reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc)
+		msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc, errorc)
+		reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc, errorc)
 		for {
 			select {
 			case m := <-p.sendc:

+ 15 - 3
rafthttp/stream.go

@@ -226,6 +226,7 @@ type streamReader struct {
 	cid      types.ID
 	recvc    chan<- raftpb.Message
 	propc    chan<- raftpb.Message
+	errorc   chan<- error
 
 	mu         sync.Mutex
 	msgAppTerm uint64
@@ -235,7 +236,7 @@ type streamReader struct {
 	done       chan struct{}
 }
 
-func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message) *streamReader {
+func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
 	r := &streamReader{
 		tr:     tr,
 		picker: picker,
@@ -245,6 +246,7 @@ func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, fr
 		cid:    cid,
 		recvc:  recvc,
 		propc:  propc,
+		errorc: errorc,
 		stopc:  make(chan struct{}),
 		done:   make(chan struct{}),
 	}
@@ -367,11 +369,21 @@ func (cr *streamReader) dial() (io.ReadCloser, error) {
 		cr.picker.unreachable(u)
 		return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
 	}
-	if resp.StatusCode != http.StatusOK {
+	switch resp.StatusCode {
+	case http.StatusGone:
+		resp.Body.Close()
+		err := fmt.Errorf("the member has been permanently removed from the cluster")
+		select {
+		case cr.errorc <- err:
+		default:
+		}
+		return nil, err
+	case http.StatusOK:
+		return resp.Body, nil
+	default:
 		resp.Body.Close()
 		return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
 	}
-	return resp.Body, nil
 }
 
 func (cr *streamReader) cancelRequest() {

+ 15 - 9
rafthttp/stream_test.go

@@ -119,15 +119,17 @@ func TestStreamReaderDialRequest(t *testing.T) {
 // HTTP response received.
 func TestStreamReaderDialResult(t *testing.T) {
 	tests := []struct {
-		code int
-		err  error
-		wok  bool
+		code  int
+		err   error
+		wok   bool
+		whalt bool
 	}{
-		{0, errors.New("blah"), false},
-		{http.StatusOK, nil, true},
-		{http.StatusMethodNotAllowed, nil, false},
-		{http.StatusNotFound, nil, false},
-		{http.StatusPreconditionFailed, nil, false},
+		{0, errors.New("blah"), false, false},
+		{http.StatusOK, nil, true, false},
+		{http.StatusMethodNotAllowed, nil, false, false},
+		{http.StatusNotFound, nil, false, false},
+		{http.StatusPreconditionFailed, nil, false, false},
+		{http.StatusGone, nil, false, true},
 	}
 	for i, tt := range tests {
 		tr := newRespRoundTripper(tt.code, tt.err)
@@ -138,12 +140,16 @@ func TestStreamReaderDialResult(t *testing.T) {
 			from:   types.ID(1),
 			to:     types.ID(2),
 			cid:    types.ID(1),
+			errorc: make(chan error, 1),
 		}
 
 		_, err := sr.dial()
 		if ok := err == nil; ok != tt.wok {
 			t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
 		}
+		if halt := len(sr.errorc) > 0; halt != tt.whalt {
+			t.Errorf("#%d: halt = %v, want %v", i, halt, tt.whalt)
+		}
 	}
 }
 
@@ -203,7 +209,7 @@ func TestStream(t *testing.T) {
 		h.sw = sw
 
 		picker := mustNewURLPicker(t, []string{srv.URL})
-		sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc)
+		sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc, nil)
 		defer sr.stop()
 		if tt.t == streamTypeMsgApp {
 			sr.updateMsgAppTerm(tt.term)

+ 2 - 1
rafthttp/transport.go

@@ -28,6 +28,7 @@ import (
 
 type Raft interface {
 	Process(ctx context.Context, m raftpb.Message) error
+	IsIDRemoved(id uint64) bool
 	ReportUnreachable(id uint64)
 	ReportSnapshot(id uint64, status raft.SnapshotStatus)
 }
@@ -98,7 +99,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
 
 func (t *transport) Handler() http.Handler {
 	pipelineHandler := NewHandler(t.raft, t.clusterID)
-	streamHandler := newStreamHandler(t, t.id, t.clusterID)
+	streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID)
 	mux := http.NewServeMux()
 	mux.Handle(RaftPrefix, pipelineHandler)
 	mux.Handle(RaftStreamPrefix+"/", streamHandler)

+ 2 - 0
rafthttp/transport_bench_test.go

@@ -88,6 +88,8 @@ func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
 	return nil
 }
 
+func (r *countRaft) IsIDRemoved(id uint64) bool { return false }
+
 func (r *countRaft) ReportUnreachable(id uint64) {}
 
 func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}