Browse Source

raft: check voteFor

Xiang Li 11 years ago
parent
commit
93b08502e4
2 changed files with 37 additions and 23 deletions
  1. 12 3
      raft/raft.go
  2. 25 20
      raft/raft_test.go

+ 12 - 3
raft/raft.go

@@ -288,6 +288,8 @@ func (sm *stateMachine) Step(m Message) {
 		case msgApp:
 		case msgApp:
 			sm.becomeFollower(sm.term, m.From)
 			sm.becomeFollower(sm.term, m.From)
 			handleAppendEntries()
 			handleAppendEntries()
+		case msgVote:
+			sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
 		case msgVoteResp:
 		case msgVoteResp:
 			gr := sm.poll(m.From, m.Index >= 0)
 			gr := sm.poll(m.From, m.Index >= 0)
 			switch sm.q() {
 			switch sm.q() {
@@ -303,11 +305,18 @@ func (sm *stateMachine) Step(m Message) {
 		case msgApp:
 		case msgApp:
 			handleAppendEntries()
 			handleAppendEntries()
 		case msgVote:
 		case msgVote:
-			if sm.log.isUpToDate(m.Index, m.LogTerm) {
+			switch sm.vote {
+			case m.From:
 				sm.send(Message{To: m.From, Type: msgVoteResp, Index: sm.log.lastIndex()})
 				sm.send(Message{To: m.From, Type: msgVoteResp, Index: sm.log.lastIndex()})
-			} else {
-				sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
+				return
+			case none:
+				if sm.log.isUpToDate(m.Index, m.LogTerm) {
+					sm.vote = m.From
+					sm.send(Message{To: m.From, Type: msgVoteResp, Index: sm.log.lastIndex()})
+					return
+				}
 			}
 			}
+			sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
 		}
 		}
 	}
 	}
 }
 }

+ 25 - 20
raft/raft_test.go

@@ -346,32 +346,37 @@ func TestCommit(t *testing.T) {
 func TestVote(t *testing.T) {
 func TestVote(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		i, term int
 		i, term int
+		voteFor int
 		w       int
 		w       int
 	}{
 	}{
-		{0, 0, -1},
-		{0, 1, -1},
-		{0, 2, -1},
-		{0, 3, 2},
-
-		{1, 0, -1},
-		{1, 1, -1},
-		{1, 2, -1},
-		{1, 3, 2},
-
-		{2, 0, -1},
-		{2, 1, -1},
-		{2, 2, 2},
-		{2, 3, 2},
-
-		{3, 0, -1},
-		{3, 1, -1},
-		{3, 2, 2},
-		{3, 3, 2},
+		{0, 0, none, -1},
+		{0, 1, none, -1},
+		{0, 2, none, -1},
+		{0, 3, none, 2},
+
+		{1, 0, none, -1},
+		{1, 1, none, -1},
+		{1, 2, none, -1},
+		{1, 3, none, 2},
+
+		{2, 0, none, -1},
+		{2, 1, none, -1},
+		{2, 2, none, 2},
+		{2, 3, none, 2},
+
+		{3, 0, none, -1},
+		{3, 1, none, -1},
+		{3, 2, none, 2},
+		{3, 3, none, 2},
+
+		{3, 2, 0, 2},
+		{3, 2, 1, -1},
 	}
 	}
 
 
 	for i, tt := range tests {
 	for i, tt := range tests {
 		called := false
 		called := false
-		sm := &nsm{stateMachine{log: &log{ents: []Entry{{}, {Term: 2}, {Term: 2}}}}, nil}
+		sm := &nsm{stateMachine{vote: tt.voteFor, log: &log{ents: []Entry{{}, {Term: 2}, {Term: 2}}}}, nil}
+
 		sm.next = stepperFunc(func(m Message) {
 		sm.next = stepperFunc(func(m Message) {
 			called = true
 			called = true
 			if m.Index != tt.w {
 			if m.Index != tt.w {