Browse Source

Merge pull request #2407 from yichengq/334

rafthttp: report unreachable status of the peer
Yicheng Qin 10 years ago
parent
commit
9989bf1d36
8 changed files with 43 additions and 22 deletions
  1. etcdserver/server.go (+2 -0)
  2. raft/raft.go (+21 -11)
  3. rafthttp/http_test.go (+2 -0)
  4. rafthttp/peer.go (+3 -3)
  5. rafthttp/pipeline.go (+4 -1)
  6. rafthttp/pipeline_test.go (+6 -6)
  7. rafthttp/stream.go (+4 -1)
  8. rafthttp/transport.go (+1 -0)

+ 2 - 0
etcdserver/server.go

@@ -326,6 +326,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
 	return s.r.Step(ctx, m)
 }
 
+func (s *EtcdServer) ReportUnreachable(id uint64) { s.r.ReportUnreachable(id) }
+
 func (s *EtcdServer) run() {
 	var syncC <-chan time.Time
 	var shouldstop bool

+ 21 - 11
raft/raft.go

@@ -117,11 +117,12 @@ func (pr *Progress) waitDecr(i int) {
 		pr.Wait = 0
 	}
 }
-func (pr *Progress) waitSet(w int)    { pr.Wait = w }
-func (pr *Progress) waitReset()       { pr.Wait = 0 }
-func (pr *Progress) reachable()       { pr.Unreachable = false }
-func (pr *Progress) unreachable()     { pr.Unreachable = true }
-func (pr *Progress) shouldWait() bool { return (pr.Unreachable || pr.Match == 0) && pr.Wait > 0 }
+func (pr *Progress) waitSet(w int)       { pr.Wait = w }
+func (pr *Progress) waitReset()          { pr.Wait = 0 }
+func (pr *Progress) isUnreachable() bool { return pr.Unreachable }
+func (pr *Progress) reachable()          { pr.Unreachable = false }
+func (pr *Progress) unreachable()        { pr.Unreachable = true }
+func (pr *Progress) shouldWait() bool    { return (pr.Unreachable || pr.Match == 0) && pr.Wait > 0 }
 
 func (pr *Progress) hasPendingSnapshot() bool    { return pr.PendingSnapshot != 0 }
 func (pr *Progress) setPendingSnapshot(i uint64) { pr.PendingSnapshot = i }
@@ -269,7 +270,7 @@ func (r *raft) sendAppend(to uint64) {
 	m := pb.Message{}
 	m.To = to
 	if r.needSnapshot(pr.Next) {
-		if pr.Unreachable {
+		if pr.isUnreachable() {
 			// do not try to send snapshot until the Progress is
 			// reachable
 			return
@@ -297,9 +298,9 @@ func (r *raft) sendAppend(to uint64) {
 		m.Commit = r.raftLog.committed
 		// optimistically increase the next if the follower
 		// has been matched.
-		if n := len(m.Entries); pr.Match != 0 && !pr.Unreachable && n != 0 {
+		if n := len(m.Entries); pr.Match != 0 && !pr.isUnreachable() && n != 0 {
 			pr.optimisticUpdate(m.Entries[n-1].Index)
-		} else if pr.Match == 0 || pr.Unreachable {
+		} else if pr.Match == 0 || pr.isUnreachable() {
 			pr.waitSet(r.heartbeatTimeout)
 		}
 	}
@@ -535,7 +536,10 @@ func stepLeader(r *raft, m pb.Message) {
 		r.appendEntry(m.Entries...)
 		r.bcastAppend()
 	case pb.MsgAppResp:
-		pr.reachable()
+		if pr.isUnreachable() {
+			pr.reachable()
+			log.Printf("raft: %x received msgAppResp from %x and changed it to be reachable [%s]", r.id, m.From, pr)
+		}
 		if m.Reject {
 			log.Printf("raft: %x received msgApp rejection(lastindex: %d) from %x for index %d",
 				r.id, m.RejectHint, m.From, m.Index)
@@ -558,7 +562,10 @@ func stepLeader(r *raft, m pb.Message) {
 			}
 		}
 	case pb.MsgHeartbeatResp:
-		pr.reachable()
+		if pr.isUnreachable() {
+			pr.reachable()
+			log.Printf("raft: %x received msgHeartbeatResp from %x and changed it to be reachable [%s]", r.id, m.From, pr)
+		}
 		if pr.Match < r.raftLog.lastIndex() {
 			r.sendAppend(m.From)
 		}
@@ -581,7 +588,10 @@ func stepLeader(r *raft, m pb.Message) {
 			pr.waitSet(r.electionTimeout)
 		}
 	case pb.MsgUnreachable:
-		r.prs[m.From].unreachable()
+		if !pr.isUnreachable() {
+			pr.unreachable()
+			log.Printf("raft: %x failed to send message to %x and changed it to be unreachable [%s]", r.id, m.From, pr)
+		}
 	}
 }
 

+ 2 - 0
rafthttp/http_test.go

@@ -162,12 +162,14 @@ func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some er
 type nopProcessor struct{}
 
 func (p *nopProcessor) Process(ctx context.Context, m raftpb.Message) error { return nil }
+func (p *nopProcessor) ReportUnreachable(id uint64)                         {}
 
 type errProcessor struct {
 	err error
 }
 
 func (p *errProcessor) Process(ctx context.Context, m raftpb.Message) error { return p.err }
+func (p *errProcessor) ReportUnreachable(id uint64)                         {}
 
 type resWriterToError struct {
 	code int

+ 3 - 3
rafthttp/peer.go

@@ -65,9 +65,9 @@ type peer struct {
 func startPeer(tr http.RoundTripper, u string, local, to, cid types.ID, r Raft, fs *stats.FollowerStats, errorc chan error) *peer {
 	p := &peer{
 		id:           to,
-		msgAppWriter: startStreamWriter(fs),
-		writer:       startStreamWriter(fs),
-		pipeline:     newPipeline(tr, u, to, cid, fs, errorc),
+		msgAppWriter: startStreamWriter(fs, r),
+		writer:       startStreamWriter(fs, r),
+		pipeline:     newPipeline(tr, u, to, cid, fs, r, errorc),
 		sendc:        make(chan raftpb.Message),
 		recvc:        make(chan raftpb.Message, recvBufSize),
 		newURLc:      make(chan string),

+ 4 - 1
rafthttp/pipeline.go

@@ -45,6 +45,7 @@ type pipeline struct {
 	// the url this pipeline sends to
 	u      string
 	fs     *stats.FollowerStats
+	r      Raft
 	errorc chan error
 
 	msgc chan raftpb.Message
@@ -57,13 +58,14 @@ type pipeline struct {
 	errored error
 }
 
-func newPipeline(tr http.RoundTripper, u string, id, cid types.ID, fs *stats.FollowerStats, errorc chan error) *pipeline {
+func newPipeline(tr http.RoundTripper, u string, id, cid types.ID, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline {
 	p := &pipeline{
 		id:     id,
 		cid:    cid,
 		tr:     tr,
 		u:      u,
 		fs:     fs,
+		r:      r,
 		errorc: errorc,
 		msgc:   make(chan raftpb.Message, pipelineBufSize),
 		active: true,
@@ -102,6 +104,7 @@ func (p *pipeline) handle() {
 			if m.Type == raftpb.MsgApp {
 				p.fs.Fail()
 			}
+			p.r.ReportUnreachable(m.To)
 		} else {
 			if !p.active {
 				log.Printf("pipeline: the connection with %s became active", p.id)

+ 6 - 6
rafthttp/pipeline_test.go

@@ -32,7 +32,7 @@ import (
 func TestPipelineSend(t *testing.T) {
 	tr := &roundTripperRecorder{}
 	fs := &stats.FollowerStats{}
-	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, nil)
+	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, &nopProcessor{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
 	p.stop()
@@ -50,7 +50,7 @@ func TestPipelineSend(t *testing.T) {
 func TestPipelineExceedMaximalServing(t *testing.T) {
 	tr := newRoundTripperBlocker()
 	fs := &stats.FollowerStats{}
-	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, nil)
+	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, &nopProcessor{}, nil)
 
 	// keep the sender busy and make the buffer full
 	// nothing can go out as we block the sender
@@ -89,7 +89,7 @@ func TestPipelineExceedMaximalServing(t *testing.T) {
 // it increases fail count in stats.
 func TestPipelineSendFailed(t *testing.T) {
 	fs := &stats.FollowerStats{}
-	p := newPipeline(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), types.ID(1), fs, nil)
+	p := newPipeline(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), types.ID(1), fs, &nopProcessor{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
 	p.stop()
@@ -103,7 +103,7 @@ func TestPipelineSendFailed(t *testing.T) {
 
 func TestPipelinePost(t *testing.T) {
 	tr := &roundTripperRecorder{}
-	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), nil, nil)
+	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), nil, &nopProcessor{}, nil)
 	if err := p.post([]byte("some data")); err != nil {
 		t.Fatalf("unexpect post error: %v", err)
 	}
@@ -145,7 +145,7 @@ func TestPipelinePostBad(t *testing.T) {
 		{"http://10.0.0.1", http.StatusCreated, nil},
 	}
 	for i, tt := range tests {
-		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, make(chan error))
+		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, &nopProcessor{}, make(chan error))
 		err := p.post([]byte("some data"))
 		p.stop()
 
 
@@ -166,7 +166,7 @@ func TestPipelinePostErrorc(t *testing.T) {
 	}
 	for i, tt := range tests {
 		errorc := make(chan error, 1)
-		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, errorc)
+		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, &nopProcessor{}, errorc)
 		p.post([]byte("some data"))
 		p.stop()
 		select {

+ 4 - 1
rafthttp/stream.go

@@ -63,6 +63,7 @@ type outgoingConn struct {
 // attached outgoingConn.
 type streamWriter struct {
 	fs *stats.FollowerStats
+	r  Raft
 
 	mu      sync.Mutex // guard field working and closer
 	closer  io.Closer
@@ -74,9 +75,10 @@ type streamWriter struct {
 	done  chan struct{}
 }
 
-func startStreamWriter(fs *stats.FollowerStats) *streamWriter {
+func startStreamWriter(fs *stats.FollowerStats, r Raft) *streamWriter {
 	w := &streamWriter{
 		fs:    fs,
+		r:     r,
 		msgc:  make(chan raftpb.Message, streamBufSize),
 		connc: make(chan *outgoingConn),
 		stopc: make(chan struct{}),
@@ -118,6 +120,7 @@ func (cw *streamWriter) run() {
 				log.Printf("rafthttp: failed to send message on stream %s due to %v. waiting for a new stream to be established.", t, err)
 				cw.resetCloser()
 				heartbeatc, msgc = nil, nil
+				cw.r.ReportUnreachable(m.To)
 				continue
 			}
 			flusher.Flush()

+ 1 - 0
rafthttp/transport.go

@@ -29,6 +29,7 @@ import (
 
 type Raft interface {
 	Process(ctx context.Context, m raftpb.Message) error
+	ReportUnreachable(id uint64)
 }
 
 type Transporter interface {