瀏覽代碼

Merge pull request #1740 from xiang90/handleheartbeat

raft: add handleHeartbeat
Xiang Li 11 年之前
父節點
當前提交
e07ef6991c
共有 5 個文件被更改,包括 82 次插入8 次删除
  1. 12 6
      raft/log.go
  2. 32 0
      raft/log_test.go
  3. 9 1
      raft/raft.go
  4. 0 1
      raft/raft_paper_test.go
  5. 29 0
      raft/raft_test.go

+ 12 - 6
raft/log.go

@@ -67,11 +67,7 @@ func (l *raftLog) maybeAppend(index, logTerm, committed uint64, ents ...pb.Entry
 		default:
 			l.append(ci-1, ents[ci-from:]...)
 		}
-		tocommit := min(committed, lastnewi)
-		// if toCommit > commitIndex, set commitIndex = toCommit
-		if l.committed < tocommit {
-			l.committed = tocommit
-		}
+		l.commitTo(min(committed, lastnewi))
 		return lastnewi, true
 	}
 	return 0, false
@@ -125,6 +121,16 @@ func (l *raftLog) nextEnts() (ents []pb.Entry) {
 	return nil
 }
 
+func (l *raftLog) commitTo(tocommit uint64) {
+	// never decrease commit
+	if l.committed < tocommit {
+		if l.lastIndex() < tocommit {
+			panic("committed out of range")
+		}
+		l.committed = tocommit
+	}
+}
+
 func (l *raftLog) appliedTo(i uint64) {
 	if i == 0 {
 		return
@@ -179,7 +185,7 @@ func (l *raftLog) matchTerm(i, term uint64) bool {
 
 func (l *raftLog) maybeCommit(maxIndex, term uint64) bool {
 	if maxIndex > l.committed && l.term(maxIndex) == term {
-		l.committed = maxIndex
+		l.commitTo(maxIndex)
 		return true
 	}
 	return false

+ 32 - 0
raft/log_test.go

@@ -386,6 +386,38 @@ func TestUnstableEnts(t *testing.T) {
 	}
 }
 
+func TestCommitTo(t *testing.T) {
+	previousEnts := []pb.Entry{{Term: 1, Index: 1}, {Term: 2, Index: 2}, {Term: 3, Index: 3}}
+	commit := uint64(2)
+	tests := []struct {
+		commit  uint64
+		wcommit uint64
+		wpanic  bool
+	}{
+		{3, 3, false},
+		{1, 2, false}, // never decrease
+		{4, 0, true},  // commit out of range -> panic
+	}
+	for i, tt := range tests {
+		func() {
+			defer func() {
+				if r := recover(); r != nil {
+					if tt.wpanic != true {
+						t.Errorf("%d: panic = %v, want %v", i, true, tt.wpanic)
+					}
+				}
+			}()
+			raftLog := newLog()
+			raftLog.append(0, previousEnts...)
+			raftLog.committed = commit
+			raftLog.commitTo(tt.commit)
+			if raftLog.committed != tt.wcommit {
+				t.Errorf("#%d: committed = %d, want %d", i, raftLog.committed, tt.wcommit)
+			}
+		}()
+	}
+}
+
 func TestStableTo(t *testing.T) {
 	tests := []struct {
 		stable    uint64

+ 9 - 1
raft/raft.go

@@ -387,6 +387,10 @@ func (r *raft) handleAppendEntries(m pb.Message) {
 	}
 }
 
+func (r *raft) handleHeartbeat(m pb.Message) {
+	r.raftLog.commitTo(m.Commit)
+}
+
 func (r *raft) handleSnapshot(m pb.Message) {
 	if r.restore(m.Snapshot) {
 		r.send(pb.Message{To: m.From, Type: pb.MsgAppResp, Index: r.raftLog.lastIndex()})
@@ -482,7 +486,11 @@ func stepFollower(r *raft, m pb.Message) {
 	case pb.MsgApp:
 		r.elapsed = 0
 		r.lead = m.From
-		r.handleAppendEntries(m)
+		if m.LogTerm == 0 && m.Index == 0 && len(m.Entries) == 0 {
+			r.handleHeartbeat(m)
+		} else {
+			r.handleAppendEntries(m)
+		}
 	case pb.MsgSnap:
 		r.elapsed = 0
 		r.handleSnapshot(m)

+ 0 - 1
raft/raft_paper_test.go

@@ -593,7 +593,6 @@ func TestFollowerCheckMsgApp(t *testing.T) {
 		index   uint64
 		wreject bool
 	}{
-		{ents[0].Term, ents[0].Index, false},
 		{ents[1].Term, ents[1].Index, false},
 		{ents[2].Term, ents[2].Index, false},
 		{ents[1].Term, ents[1].Index + 1, true},

+ 29 - 0
raft/raft_test.go

@@ -674,6 +674,35 @@ func TestHandleMsgApp(t *testing.T) {
 	}
 }
 
+// TestHandleHeartbeat ensures that the follower commits to the commit in the message.
+func TestHandleHeartbeat(t *testing.T) {
+	commit := uint64(2)
+	tests := []struct {
+		m       pb.Message
+		wCommit uint64
+	}{
+		{pb.Message{Type: pb.MsgApp, Term: 2, Commit: commit + 1}, commit + 1},
+		{pb.Message{Type: pb.MsgApp, Term: 2, Commit: commit - 1}, commit}, // do not decrease commit
+	}
+
+	for i, tt := range tests {
+		sm := &raft{
+			state:     StateFollower,
+			HardState: pb.HardState{Term: 2},
+			raftLog:   &raftLog{committed: 0, ents: []pb.Entry{{}, {Term: 1}, {Term: 2}, {Term: 3}}},
+		}
+		sm.raftLog.commitTo(commit)
+		sm.handleHeartbeat(tt.m)
+		if sm.raftLog.committed != tt.wCommit {
+			t.Errorf("#%d: committed = %d, want %d", i, sm.raftLog.committed, tt.wCommit)
+		}
+		m := sm.readMessages()
+		if len(m) != 0 {
+			t.Fatalf("#%d: msg = nil, want 0", i)
+		}
+	}
+}
+
 func TestRecvMsgVote(t *testing.T) {
 	tests := []struct {
 		state   StateType