Browse Source

rafthttp: stop etcd if it is found to have been removed during stream dial

Previously, etcd was stopped only when a pipeline message revealed that the
local member had been removed. After this PR, stream dial detects this too.
This lets etcd stop quickly: it no longer needs to wait for the stream to
break and fall back to pipeline, and then wait for an election timeout to
send out a message that detects its own removal.
Yicheng Qin 10 years ago
parent
commit
1c1cccd236

+ 2 - 0
etcdserver/server.go

@@ -348,6 +348,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
 	return s.r.Step(ctx, m)
 	return s.r.Step(ctx, m)
 }
 }
 
 
+func (s *EtcdServer) IsIDRemoved(id uint64) bool { return s.Cluster.IsIDRemoved(types.ID(id)) }
+
 func (s *EtcdServer) ReportUnreachable(id uint64) { s.r.ReportUnreachable(id) }
 func (s *EtcdServer) ReportUnreachable(id uint64) { s.r.ReportUnreachable(id) }
 
 
 func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {
 func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {

+ 5 - 2
rafthttp/functional_test.go

@@ -134,8 +134,9 @@ func waitStreamWorking(p *peer) bool {
 }
 }
 
 
 type fakeRaft struct {
 type fakeRaft struct {
-	recvc chan<- raftpb.Message
-	err   error
+	recvc     chan<- raftpb.Message
+	err       error
+	removedID uint64
 }
 }
 
 
 func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
 func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
@@ -146,6 +147,8 @@ func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
 	return p.err
 	return p.err
 }
 }
 
 
+func (p *fakeRaft) IsIDRemoved(id uint64) bool { return id == p.removedID }
+
 func (p *fakeRaft) ReportUnreachable(id uint64) {}
 func (p *fakeRaft) ReportUnreachable(id uint64) {}
 
 
 func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
 func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}

+ 8 - 1
rafthttp/http.go

@@ -46,9 +46,10 @@ type peerGetter interface {
 	Get(id types.ID) Peer
 	Get(id types.ID) Peer
 }
 }
 
 
-func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
+func newStreamHandler(peerGetter peerGetter, r Raft, id, cid types.ID) http.Handler {
 	return &streamHandler{
 	return &streamHandler{
 		peerGetter: peerGetter,
 		peerGetter: peerGetter,
+		r:          r,
 		id:         id,
 		id:         id,
 		cid:        cid,
 		cid:        cid,
 	}
 	}
@@ -112,6 +113,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 
 type streamHandler struct {
 type streamHandler struct {
 	peerGetter peerGetter
 	peerGetter peerGetter
+	r          Raft
 	id         types.ID
 	id         types.ID
 	cid        types.ID
 	cid        types.ID
 }
 }
@@ -145,6 +147,11 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "invalid from", http.StatusNotFound)
 		http.Error(w, "invalid from", http.StatusNotFound)
 		return
 		return
 	}
 	}
+	if h.r.IsIDRemoved(uint64(from)) {
+		log.Printf("rafthttp: reject the stream from peer %s since it was removed", from)
+		http.Error(w, "removed member", http.StatusGone)
+		return
+	}
 	p := h.peerGetter.Get(from)
 	p := h.peerGetter.Get(from)
 	if p == nil {
 	if p == nil {
 		log.Printf("rafthttp: fail to find sender %s", from)
 		log.Printf("rafthttp: fail to find sender %s", from)

+ 13 - 2
rafthttp/http_test.go

@@ -17,6 +17,7 @@ package rafthttp
 import (
 import (
 	"bytes"
 	"bytes"
 	"errors"
 	"errors"
+	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
@@ -185,7 +186,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
 
 
 		peer := newFakePeer()
 		peer := newFakePeer()
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
-		h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
+		h := newStreamHandler(peerGetter, &fakeRaft{}, types.ID(2), types.ID(1))
 
 
 		rw := httptest.NewRecorder()
 		rw := httptest.NewRecorder()
 		go h.ServeHTTP(rw, req)
 		go h.ServeHTTP(rw, req)
@@ -207,6 +208,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
 }
 }
 
 
 func TestServeRaftStreamPrefixBad(t *testing.T) {
 func TestServeRaftStreamPrefixBad(t *testing.T) {
+	removedID := uint64(5)
 	tests := []struct {
 	tests := []struct {
 		method    string
 		method    string
 		path      string
 		path      string
@@ -263,6 +265,14 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
 			"1",
 			"1",
 			http.StatusNotFound,
 			http.StatusNotFound,
 		},
 		},
+		// removed peer
+		{
+			"GET",
+			RaftStreamPrefix + "/message/" + fmt.Sprint(removedID),
+			"1",
+			"1",
+			http.StatusGone,
+		},
 		// wrong cluster ID
 		// wrong cluster ID
 		{
 		{
 			"GET",
 			"GET",
@@ -289,7 +299,8 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
 		req.Header.Set("X-Raft-To", tt.remote)
 		req.Header.Set("X-Raft-To", tt.remote)
 		rw := httptest.NewRecorder()
 		rw := httptest.NewRecorder()
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
 		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
-		h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
+		r := &fakeRaft{removedID: removedID}
+		h := newStreamHandler(peerGetter, r, types.ID(1), types.ID(1))
 		h.ServeHTTP(rw, req)
 		h.ServeHTTP(rw, req)
 
 
 		if rw.Code != tt.wcode {
 		if rw.Code != tt.wcode {

+ 2 - 2
rafthttp/peer.go

@@ -149,8 +149,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
 
 
 	go func() {
 	go func() {
 		var paused bool
 		var paused bool
-		msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc)
-		reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc)
+		msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc, errorc)
+		reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc, errorc)
 		for {
 		for {
 			select {
 			select {
 			case m := <-p.sendc:
 			case m := <-p.sendc:

+ 15 - 3
rafthttp/stream.go

@@ -226,6 +226,7 @@ type streamReader struct {
 	cid      types.ID
 	cid      types.ID
 	recvc    chan<- raftpb.Message
 	recvc    chan<- raftpb.Message
 	propc    chan<- raftpb.Message
 	propc    chan<- raftpb.Message
+	errorc   chan<- error
 
 
 	mu         sync.Mutex
 	mu         sync.Mutex
 	msgAppTerm uint64
 	msgAppTerm uint64
@@ -235,7 +236,7 @@ type streamReader struct {
 	done       chan struct{}
 	done       chan struct{}
 }
 }
 
 
-func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message) *streamReader {
+func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
 	r := &streamReader{
 	r := &streamReader{
 		tr:     tr,
 		tr:     tr,
 		picker: picker,
 		picker: picker,
@@ -245,6 +246,7 @@ func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, fr
 		cid:    cid,
 		cid:    cid,
 		recvc:  recvc,
 		recvc:  recvc,
 		propc:  propc,
 		propc:  propc,
+		errorc: errorc,
 		stopc:  make(chan struct{}),
 		stopc:  make(chan struct{}),
 		done:   make(chan struct{}),
 		done:   make(chan struct{}),
 	}
 	}
@@ -367,11 +369,21 @@ func (cr *streamReader) dial() (io.ReadCloser, error) {
 		cr.picker.unreachable(u)
 		cr.picker.unreachable(u)
 		return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
 		return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
 	}
 	}
-	if resp.StatusCode != http.StatusOK {
+	switch resp.StatusCode {
+	case http.StatusGone:
+		resp.Body.Close()
+		err := fmt.Errorf("the member has been permanently removed from the cluster")
+		select {
+		case cr.errorc <- err:
+		default:
+		}
+		return nil, err
+	case http.StatusOK:
+		return resp.Body, nil
+	default:
 		resp.Body.Close()
 		resp.Body.Close()
 		return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
 		return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
 	}
 	}
-	return resp.Body, nil
 }
 }
 
 
 func (cr *streamReader) cancelRequest() {
 func (cr *streamReader) cancelRequest() {

+ 15 - 9
rafthttp/stream_test.go

@@ -119,15 +119,17 @@ func TestStreamReaderDialRequest(t *testing.T) {
 // HTTP response received.
 // HTTP response received.
 func TestStreamReaderDialResult(t *testing.T) {
 func TestStreamReaderDialResult(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
-		code int
-		err  error
-		wok  bool
+		code  int
+		err   error
+		wok   bool
+		whalt bool
 	}{
 	}{
-		{0, errors.New("blah"), false},
-		{http.StatusOK, nil, true},
-		{http.StatusMethodNotAllowed, nil, false},
-		{http.StatusNotFound, nil, false},
-		{http.StatusPreconditionFailed, nil, false},
+		{0, errors.New("blah"), false, false},
+		{http.StatusOK, nil, true, false},
+		{http.StatusMethodNotAllowed, nil, false, false},
+		{http.StatusNotFound, nil, false, false},
+		{http.StatusPreconditionFailed, nil, false, false},
+		{http.StatusGone, nil, false, true},
 	}
 	}
 	for i, tt := range tests {
 	for i, tt := range tests {
 		tr := newRespRoundTripper(tt.code, tt.err)
 		tr := newRespRoundTripper(tt.code, tt.err)
@@ -138,12 +140,16 @@ func TestStreamReaderDialResult(t *testing.T) {
 			from:   types.ID(1),
 			from:   types.ID(1),
 			to:     types.ID(2),
 			to:     types.ID(2),
 			cid:    types.ID(1),
 			cid:    types.ID(1),
+			errorc: make(chan error, 1),
 		}
 		}
 
 
 		_, err := sr.dial()
 		_, err := sr.dial()
 		if ok := err == nil; ok != tt.wok {
 		if ok := err == nil; ok != tt.wok {
 			t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
 			t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
 		}
 		}
+		if halt := len(sr.errorc) > 0; halt != tt.whalt {
+			t.Errorf("#%d: halt = %v, want %v", i, halt, tt.whalt)
+		}
 	}
 	}
 }
 }
 
 
@@ -203,7 +209,7 @@ func TestStream(t *testing.T) {
 		h.sw = sw
 		h.sw = sw
 
 
 		picker := mustNewURLPicker(t, []string{srv.URL})
 		picker := mustNewURLPicker(t, []string{srv.URL})
-		sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc)
+		sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc, nil)
 		defer sr.stop()
 		defer sr.stop()
 		if tt.t == streamTypeMsgApp {
 		if tt.t == streamTypeMsgApp {
 			sr.updateMsgAppTerm(tt.term)
 			sr.updateMsgAppTerm(tt.term)

+ 2 - 1
rafthttp/transport.go

@@ -28,6 +28,7 @@ import (
 
 
 type Raft interface {
 type Raft interface {
 	Process(ctx context.Context, m raftpb.Message) error
 	Process(ctx context.Context, m raftpb.Message) error
+	IsIDRemoved(id uint64) bool
 	ReportUnreachable(id uint64)
 	ReportUnreachable(id uint64)
 	ReportSnapshot(id uint64, status raft.SnapshotStatus)
 	ReportSnapshot(id uint64, status raft.SnapshotStatus)
 }
 }
@@ -98,7 +99,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
 
 
 func (t *transport) Handler() http.Handler {
 func (t *transport) Handler() http.Handler {
 	pipelineHandler := NewHandler(t.raft, t.clusterID)
 	pipelineHandler := NewHandler(t.raft, t.clusterID)
-	streamHandler := newStreamHandler(t, t.id, t.clusterID)
+	streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID)
 	mux := http.NewServeMux()
 	mux := http.NewServeMux()
 	mux.Handle(RaftPrefix, pipelineHandler)
 	mux.Handle(RaftPrefix, pipelineHandler)
 	mux.Handle(RaftStreamPrefix+"/", streamHandler)
 	mux.Handle(RaftStreamPrefix+"/", streamHandler)

+ 2 - 0
rafthttp/transport_bench_test.go

@@ -88,6 +88,8 @@ func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
 	return nil
 	return nil
 }
 }
 
 
+func (r *countRaft) IsIDRemoved(id uint64) bool { return false }
+
 func (r *countRaft) ReportUnreachable(id uint64) {}
 func (r *countRaft) ReportUnreachable(id uint64) {}
 
 
 func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
 func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}