Browse Source

raft: update lead to none when receives vaild msgVote

Xiang Li 11 years ago
parent
commit
3921295b21
2 changed files with 13 additions and 2 deletions
  1. 5 1
      raft/raft.go
  2. 8 1
      raft/raft_test.go

+ 5 - 1
raft/raft.go

@@ -311,7 +311,11 @@ func (sm *stateMachine) Step(m Message) (ok bool) {
 	case m.Term == 0:
 	case m.Term == 0:
 		// local message
 		// local message
 	case m.Term > sm.term.Get():
 	case m.Term > sm.term.Get():
-		sm.becomeFollower(m.Term, m.From)
+		lead := m.From
+		if m.Type == msgVote {
+			lead = none
+		}
+		sm.becomeFollower(m.Term, lead)
 	case m.Term < sm.term.Get():
 	case m.Term < sm.term.Get():
 		// ignore
 		// ignore
 		return true
 		return true

+ 8 - 1
raft/raft_test.go

@@ -685,7 +685,7 @@ func TestAllServerStepdown(t *testing.T) {
 		}
 		}
 
 
 		for j, msgType := range tmsgTypes {
 		for j, msgType := range tmsgTypes {
-			sm.Step(Message{Type: msgType, Term: tterm, LogTerm: tterm})
+			sm.Step(Message{From: 1, Type: msgType, Term: tterm, LogTerm: tterm})
 
 
 			if sm.state != tt.wstate {
 			if sm.state != tt.wstate {
 				t.Errorf("#%d.%d state = %v , want %v", i, j, sm.state, tt.wstate)
 				t.Errorf("#%d.%d state = %v , want %v", i, j, sm.state, tt.wstate)
@@ -696,6 +696,13 @@ func TestAllServerStepdown(t *testing.T) {
 			if int64(len(sm.log.ents)) != tt.windex {
 			if int64(len(sm.log.ents)) != tt.windex {
 				t.Errorf("#%d.%d index = %v , want %v", i, j, len(sm.log.ents), tt.windex)
 				t.Errorf("#%d.%d index = %v , want %v", i, j, len(sm.log.ents), tt.windex)
 			}
 			}
+			wlead := int64(1)
+			if msgType == msgVote {
+				wlead = none
+			}
+			if sm.lead.Get() != wlead {
+				t.Errorf("#%d, sm.lead = %d, want %d", i, sm.lead.Get(), none)
+			}
 		}
 		}
 	}
 	}
 }
 }