Browse Source

Merge pull request #3917 from xiang90/raft_stepdown

raft: support quorum check when raft is leader
Xiang Li 10 years ago
parent
commit
9d7be9ec6f
8 changed files with 134 additions and 25 deletions
  1. 1 1
      raft/multinode_test.go
  2. 9 9
      raft/node_test.go
  3. 6 0
      raft/progress.go
  4. 77 13
      raft/raft.go
  5. 36 1
      raft/raft_test.go
  6. 3 0
      raft/raftpb/raft.pb.go
  7. 1 0
      raft/raftpb/raft.proto
  8. 1 1
      raft/util.go

+ 1 - 1
raft/multinode_test.go

@@ -42,7 +42,7 @@ func TestMultiNodeStep(t *testing.T) {
 				t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn)
 				t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn)
 			}
 			}
 		} else {
 		} else {
-			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup || msgt == raftpb.MsgUnreachable || msgt == raftpb.MsgSnapStatus {
+			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup || msgt == raftpb.MsgUnreachable || msgt == raftpb.MsgSnapStatus || msgt == raftpb.MsgCheckQuorum {
 				select {
 				select {
 				case <-mn.recvc:
 				case <-mn.recvc:
 					t.Errorf("%d: step should ignore %s", msgt, msgn)
 					t.Errorf("%d: step should ignore %s", msgt, msgn)

+ 9 - 9
raft/node_test.go

@@ -42,7 +42,7 @@ func TestNodeStep(t *testing.T) {
 				t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn)
 				t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn)
 			}
 			}
 		} else {
 		} else {
-			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup || msgt == raftpb.MsgUnreachable || msgt == raftpb.MsgSnapStatus {
+			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup || msgt == raftpb.MsgUnreachable || msgt == raftpb.MsgSnapStatus || msgt == raftpb.MsgCheckQuorum {
 				select {
 				select {
 				case <-n.recvc:
 				case <-n.recvc:
 					t.Errorf("%d: step should ignore %s", msgt, msgn)
 					t.Errorf("%d: step should ignore %s", msgt, msgn)
@@ -225,11 +225,11 @@ func TestNodeTick(t *testing.T) {
 	s := NewMemoryStorage()
 	s := NewMemoryStorage()
 	r := newTestRaft(1, []uint64{1}, 10, 1, s)
 	r := newTestRaft(1, []uint64{1}, 10, 1, s)
 	go n.run(r)
 	go n.run(r)
-	elapsed := r.elapsed
+	elapsed := r.electionElapsed
 	n.Tick()
 	n.Tick()
 	n.Stop()
 	n.Stop()
-	if r.elapsed != elapsed+1 {
-		t.Errorf("elapsed = %d, want %d", r.elapsed, elapsed+1)
+	if r.electionElapsed != elapsed+1 {
+		t.Errorf("elapsed = %d, want %d", r.electionElapsed, elapsed+1)
 	}
 	}
 }
 }
 
 
@@ -246,7 +246,7 @@ func TestNodeStop(t *testing.T) {
 		close(donec)
 		close(donec)
 	}()
 	}()
 
 
-	elapsed := r.elapsed
+	elapsed := r.electionElapsed
 	n.Tick()
 	n.Tick()
 	n.Stop()
 	n.Stop()
 
 
@@ -256,13 +256,13 @@ func TestNodeStop(t *testing.T) {
 		t.Fatalf("timed out waiting for node to stop!")
 		t.Fatalf("timed out waiting for node to stop!")
 	}
 	}
 
 
-	if r.elapsed != elapsed+1 {
-		t.Errorf("elapsed = %d, want %d", r.elapsed, elapsed+1)
+	if r.electionElapsed != elapsed+1 {
+		t.Errorf("elapsed = %d, want %d", r.electionElapsed, elapsed+1)
 	}
 	}
 	// Further ticks should have no effect, the node is stopped.
 	// Further ticks should have no effect, the node is stopped.
 	n.Tick()
 	n.Tick()
-	if r.elapsed != elapsed+1 {
-		t.Errorf("elapsed = %d, want %d", r.elapsed, elapsed+1)
+	if r.electionElapsed != elapsed+1 {
+		t.Errorf("elapsed = %d, want %d", r.electionElapsed, elapsed+1)
 	}
 	}
 	// Subsequent Stops should have no effect.
 	// Subsequent Stops should have no effect.
 	n.Stop()
 	n.Stop()

+ 6 - 0
raft/progress.go

@@ -56,6 +56,11 @@ type Progress struct {
 	// is reported to be failed.
 	// is reported to be failed.
 	PendingSnapshot uint64
 	PendingSnapshot uint64
 
 
+	// recentActive is true if the progress is recently active. Receiving any messages
+	// from the corresponding follower indicates the progress is active.
+	// recentActive can be reset to false after an election timeout.
+	recentActive bool
+
 	// inflights is a sliding window for the inflight messages.
 	// inflights is a sliding window for the inflight messages.
 	// When inflights is full, no more message should be sent.
 	// When inflights is full, no more message should be sent.
 	// When a leader sends out a message, the index of the last
 	// When a leader sends out a message, the index of the last
@@ -68,6 +73,7 @@ type Progress struct {
 
 
 func (pr *Progress) resetState(state ProgressStateType) {
 func (pr *Progress) resetState(state ProgressStateType) {
 	pr.Paused = false
 	pr.Paused = false
+	pr.recentActive = false
 	pr.PendingSnapshot = 0
 	pr.PendingSnapshot = 0
 	pr.State = state
 	pr.State = state
 	pr.ins.reset()
 	pr.ins.reset()

+ 77 - 13
raft/raft.go

@@ -99,6 +99,10 @@ type Config struct {
 	// TODO (xiangli): feedback to application to limit the proposal rate?
 	// TODO (xiangli): feedback to application to limit the proposal rate?
 	MaxInflightMsgs int
 	MaxInflightMsgs int
 
 
+	// CheckQuorum specifies if the leader should check quorum activity. Leader steps down when
+	// quorum is not active for an electionTimeout.
+	CheckQuorum bool
+
 	// logger is the logger used for raft log. For multinode which
 	// logger is the logger used for raft log. For multinode which
 	// can host multiple raft group, each raft group can have its
 	// can host multiple raft group, each raft group can have its
 	// own logger
 	// own logger
@@ -157,7 +161,18 @@ type raft struct {
 	// New configuration is ignored if there exists unapplied configuration.
 	// New configuration is ignored if there exists unapplied configuration.
 	pendingConf bool
 	pendingConf bool
 
 
-	elapsed          int // number of ticks since the last msg
+	// As leader or candidate: number of ticks since it reached the last
+	// electionTimeout.
+	// As follower: number of ticks since it reached the last electionTimeout
+	// or received a valid message from the current leader.
+	electionElapsed int
+
+	// number of ticks since it reached last heartbeatTimeout.
+	// only leader keeps heartbeatElapsed.
+	heartbeatElapsed int
+
+	checkQuorum bool
+
 	heartbeatTimeout int
 	heartbeatTimeout int
 	electionTimeout  int
 	electionTimeout  int
 	rand             *rand.Rand
 	rand             *rand.Rand
@@ -196,6 +211,7 @@ func newRaft(c *Config) *raft {
 		electionTimeout:  c.ElectionTick,
 		electionTimeout:  c.ElectionTick,
 		heartbeatTimeout: c.HeartbeatTick,
 		heartbeatTimeout: c.HeartbeatTick,
 		logger:           c.Logger,
 		logger:           c.Logger,
+		checkQuorum:      c.CheckQuorum,
 	}
 	}
 	r.rand = rand.New(rand.NewSource(int64(c.ID)))
 	r.rand = rand.New(rand.NewSource(int64(c.ID)))
 	for _, p := range peers {
 	for _, p := range peers {
@@ -356,7 +372,10 @@ func (r *raft) reset(term uint64) {
 		r.Vote = None
 		r.Vote = None
 	}
 	}
 	r.lead = None
 	r.lead = None
-	r.elapsed = 0
+
+	r.electionElapsed = 0
+	r.heartbeatElapsed = 0
+
 	r.votes = make(map[uint64]bool)
 	r.votes = make(map[uint64]bool)
 	for i := range r.prs {
 	for i := range r.prs {
 		r.prs[i] = &Progress{Next: r.raftLog.lastIndex() + 1, ins: newInflights(r.maxInflight)}
 		r.prs[i] = &Progress{Next: r.raftLog.lastIndex() + 1, ins: newInflights(r.maxInflight)}
@@ -381,21 +400,34 @@ func (r *raft) appendEntry(es ...pb.Entry) {
 // tickElection is run by followers and candidates after r.electionTimeout.
 // tickElection is run by followers and candidates after r.electionTimeout.
 func (r *raft) tickElection() {
 func (r *raft) tickElection() {
 	if !r.promotable() {
 	if !r.promotable() {
-		r.elapsed = 0
+		r.electionElapsed = 0
 		return
 		return
 	}
 	}
-	r.elapsed++
+	r.electionElapsed++
 	if r.isElectionTimeout() {
 	if r.isElectionTimeout() {
-		r.elapsed = 0
+		r.electionElapsed = 0
 		r.Step(pb.Message{From: r.id, Type: pb.MsgHup})
 		r.Step(pb.Message{From: r.id, Type: pb.MsgHup})
 	}
 	}
 }
 }
 
 
 // tickHeartbeat is run by leaders to send a MsgBeat after r.heartbeatTimeout.
 // tickHeartbeat is run by leaders to send a MsgBeat after r.heartbeatTimeout.
 func (r *raft) tickHeartbeat() {
 func (r *raft) tickHeartbeat() {
-	r.elapsed++
-	if r.elapsed >= r.heartbeatTimeout {
-		r.elapsed = 0
+	r.heartbeatElapsed++
+	r.electionElapsed++
+
+	if r.electionElapsed >= r.electionTimeout {
+		r.electionElapsed = 0
+		if r.checkQuorum {
+			r.Step(pb.Message{From: r.id, Type: pb.MsgCheckQuorum})
+		}
+	}
+
+	if r.state != StateLeader {
+		return
+	}
+
+	if r.heartbeatElapsed >= r.heartbeatTimeout {
+		r.heartbeatElapsed = 0
 		r.Step(pb.Message{From: r.id, Type: pb.MsgBeat})
 		r.Step(pb.Message{From: r.id, Type: pb.MsgBeat})
 	}
 	}
 }
 }
@@ -525,6 +557,11 @@ func stepLeader(r *raft, m pb.Message) {
 	switch m.Type {
 	switch m.Type {
 	case pb.MsgBeat:
 	case pb.MsgBeat:
 		r.bcastHeartbeat()
 		r.bcastHeartbeat()
+	case pb.MsgCheckQuorum:
+		if !r.checkQuorumActive() {
+			r.logger.Warningf("%x stepped down to follower since quorum is not active", r.id)
+			r.becomeFollower(r.Term, None)
+		}
 	case pb.MsgProp:
 	case pb.MsgProp:
 		if len(m.Entries) == 0 {
 		if len(m.Entries) == 0 {
 			r.logger.Panicf("%x stepped empty MsgProp", r.id)
 			r.logger.Panicf("%x stepped empty MsgProp", r.id)
@@ -546,6 +583,8 @@ func stepLeader(r *raft, m pb.Message) {
 		r.appendEntry(m.Entries...)
 		r.appendEntry(m.Entries...)
 		r.bcastAppend()
 		r.bcastAppend()
 	case pb.MsgAppResp:
 	case pb.MsgAppResp:
+		pr.recentActive = true
+
 		if m.Reject {
 		if m.Reject {
 			r.logger.Debugf("%x received msgApp rejection(lastindex: %d) from %x for index %d",
 			r.logger.Debugf("%x received msgApp rejection(lastindex: %d) from %x for index %d",
 				r.id, m.RejectHint, m.From, m.Index)
 				r.id, m.RejectHint, m.From, m.Index)
@@ -579,6 +618,8 @@ func stepLeader(r *raft, m pb.Message) {
 			}
 			}
 		}
 		}
 	case pb.MsgHeartbeatResp:
 	case pb.MsgHeartbeatResp:
+		pr.recentActive = true
+
 		// free one slot for the full inflights window to allow progress.
 		// free one slot for the full inflights window to allow progress.
 		if pr.State == ProgressStateReplicate && pr.ins.full() {
 		if pr.State == ProgressStateReplicate && pr.ins.full() {
 			pr.ins.freeFirstOne()
 			pr.ins.freeFirstOne()
@@ -657,19 +698,19 @@ func stepFollower(r *raft, m pb.Message) {
 		m.To = r.lead
 		m.To = r.lead
 		r.send(m)
 		r.send(m)
 	case pb.MsgApp:
 	case pb.MsgApp:
-		r.elapsed = 0
+		r.electionElapsed = 0
 		r.lead = m.From
 		r.lead = m.From
 		r.handleAppendEntries(m)
 		r.handleAppendEntries(m)
 	case pb.MsgHeartbeat:
 	case pb.MsgHeartbeat:
-		r.elapsed = 0
+		r.electionElapsed = 0
 		r.lead = m.From
 		r.lead = m.From
 		r.handleHeartbeat(m)
 		r.handleHeartbeat(m)
 	case pb.MsgSnap:
 	case pb.MsgSnap:
-		r.elapsed = 0
+		r.electionElapsed = 0
 		r.handleSnapshot(m)
 		r.handleSnapshot(m)
 	case pb.MsgVote:
 	case pb.MsgVote:
 		if (r.Vote == None || r.Vote == m.From) && r.raftLog.isUpToDate(m.Index, m.LogTerm) {
 		if (r.Vote == None || r.Vote == m.From) && r.raftLog.isUpToDate(m.Index, m.LogTerm) {
-			r.elapsed = 0
+			r.electionElapsed = 0
 			r.logger.Infof("%x [logterm: %d, index: %d, vote: %x] voted for %x [logterm: %d, index: %d] at term %d",
 			r.logger.Infof("%x [logterm: %d, index: %d, vote: %x] voted for %x [logterm: %d, index: %d] at term %d",
 				r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), r.Vote, m.From, m.LogTerm, m.Index, r.Term)
 				r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), r.Vote, m.From, m.LogTerm, m.Index, r.Term)
 			r.Vote = m.From
 			r.Vote = m.From
@@ -793,9 +834,32 @@ func (r *raft) loadState(state pb.HardState) {
 // randomized election timeout in (electiontimeout, 2 * electiontimeout - 1).
 // randomized election timeout in (electiontimeout, 2 * electiontimeout - 1).
 // Otherwise, it returns false.
 // Otherwise, it returns false.
 func (r *raft) isElectionTimeout() bool {
 func (r *raft) isElectionTimeout() bool {
-	d := r.elapsed - r.electionTimeout
+	d := r.electionElapsed - r.electionTimeout
 	if d < 0 {
 	if d < 0 {
 		return false
 		return false
 	}
 	}
 	return d > r.rand.Int()%r.electionTimeout
 	return d > r.rand.Int()%r.electionTimeout
 }
 }
+
+// checkQuorumActive returns true if the quorum is active from
+// the view of the local raft state machine. Otherwise, it returns
+// false.
+// checkQuorumActive also resets all recentActive to false.
+func (r *raft) checkQuorumActive() bool {
+	var act int
+
+	for id := range r.prs {
+		if id == r.id { // self is always active
+			act += 1
+			continue
+		}
+
+		if r.prs[id].recentActive {
+			act += 1
+		}
+
+		r.prs[id].recentActive = false
+	}
+
+	return act >= r.q()
+}

+ 36 - 1
raft/raft_test.go

@@ -762,7 +762,7 @@ func TestIsElectionTimeout(t *testing.T) {
 
 
 	for i, tt := range tests {
 	for i, tt := range tests {
 		sm := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 		sm := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
-		sm.elapsed = tt.elapse
+		sm.electionElapsed = tt.elapse
 		c := 0
 		c := 0
 		for j := 0; j < 10000; j++ {
 		for j := 0; j < 10000; j++ {
 			if sm.isElectionTimeout() {
 			if sm.isElectionTimeout() {
@@ -1172,6 +1172,41 @@ func TestAllServerStepdown(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestLeaderStepdownWhenQuorumActive(t *testing.T) {
+	sm := newTestRaft(1, []uint64{1, 2, 3}, 5, 1, NewMemoryStorage())
+
+	sm.checkQuorum = true
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	for i := 0; i < sm.electionTimeout+1; i++ {
+		sm.Step(pb.Message{From: 2, Type: pb.MsgHeartbeatResp, Term: sm.Term})
+		sm.tick()
+	}
+
+	if sm.state != StateLeader {
+		t.Errorf("state = %v, want %v", sm.state, StateLeader)
+	}
+}
+
+func TestLeaderStepdownWhenQuorumLost(t *testing.T) {
+	sm := newTestRaft(1, []uint64{1, 2, 3}, 5, 1, NewMemoryStorage())
+
+	sm.checkQuorum = true
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	for i := 0; i < sm.electionTimeout+1; i++ {
+		sm.tick()
+	}
+
+	if sm.state != StateFollower {
+		t.Errorf("state = %v, want %v", sm.state, StateFollower)
+	}
+}
+
 func TestLeaderAppResp(t *testing.T) {
 func TestLeaderAppResp(t *testing.T) {
 	// initial progress: match = 0; next = 3
 	// initial progress: match = 0; next = 3
 	tests := []struct {
 	tests := []struct {

+ 3 - 0
raft/raftpb/raft.pb.go

@@ -79,6 +79,7 @@ const (
 	MsgHeartbeatResp MessageType = 9
 	MsgHeartbeatResp MessageType = 9
 	MsgUnreachable   MessageType = 10
 	MsgUnreachable   MessageType = 10
 	MsgSnapStatus    MessageType = 11
 	MsgSnapStatus    MessageType = 11
+	MsgCheckQuorum   MessageType = 12
 )
 )
 
 
 var MessageType_name = map[int32]string{
 var MessageType_name = map[int32]string{
@@ -94,6 +95,7 @@ var MessageType_name = map[int32]string{
 	9:  "MsgHeartbeatResp",
 	9:  "MsgHeartbeatResp",
 	10: "MsgUnreachable",
 	10: "MsgUnreachable",
 	11: "MsgSnapStatus",
 	11: "MsgSnapStatus",
+	12: "MsgCheckQuorum",
 }
 }
 var MessageType_value = map[string]int32{
 var MessageType_value = map[string]int32{
 	"MsgHup":           0,
 	"MsgHup":           0,
@@ -108,6 +110,7 @@ var MessageType_value = map[string]int32{
 	"MsgHeartbeatResp": 9,
 	"MsgHeartbeatResp": 9,
 	"MsgUnreachable":   10,
 	"MsgUnreachable":   10,
 	"MsgSnapStatus":    11,
 	"MsgSnapStatus":    11,
+	"MsgCheckQuorum":   12,
 }
 }
 
 
 func (x MessageType) Enum() *MessageType {
 func (x MessageType) Enum() *MessageType {

+ 1 - 0
raft/raftpb/raft.proto

@@ -45,6 +45,7 @@ enum MessageType {
 	MsgHeartbeatResp   = 9;
 	MsgHeartbeatResp   = 9;
 	MsgUnreachable     = 10;
 	MsgUnreachable     = 10;
 	MsgSnapStatus      = 11;
 	MsgSnapStatus      = 11;
+	MsgCheckQuorum     = 12;
 }
 }
 
 
 message Message {
 message Message {

+ 1 - 1
raft/util.go

@@ -47,7 +47,7 @@ func max(a, b uint64) uint64 {
 }
 }
 
 
 func IsLocalMsg(m pb.Message) bool {
 func IsLocalMsg(m pb.Message) bool {
-	return m.Type == pb.MsgHup || m.Type == pb.MsgBeat || m.Type == pb.MsgUnreachable || m.Type == pb.MsgSnapStatus
+	return m.Type == pb.MsgHup || m.Type == pb.MsgBeat || m.Type == pb.MsgUnreachable || m.Type == pb.MsgSnapStatus || m.Type == pb.MsgCheckQuorum
 }
 }
 
 
 func IsResponseMsg(m pb.Message) bool {
 func IsResponseMsg(m pb.Message) bool {