Browse Source

raft: add unstableState

Xiang Li 11 years ago
parent
commit
311db876b0
3 changed files with 100 additions and 8 deletions
  1. 9 0
      raft/node.go
  2. 56 8
      raft/raft.go
  3. 35 0
      raft/raft_test.go

+ 9 - 0
raft/node.go

@@ -209,3 +209,12 @@ func (n *Node) UpdateConf(t int64, c *Config) {
 func (n *Node) UnstableEnts() []Entry {
 func (n *Node) UnstableEnts() []Entry {
 	return n.sm.raftLog.unstableEnts()
 	return n.sm.raftLog.unstableEnts()
 }
 }
+
+func (n *Node) UnstableState() State {
+	if n.sm.unstableState == emptyState {
+		return emptyState
+	}
+	s := n.sm.unstableState
+	n.sm.clearState()
+	return s
+}

+ 56 - 8
raft/raft.go

@@ -66,6 +66,14 @@ func (st stateType) String() string {
 	return stmap[int64(st)]
 	return stmap[int64(st)]
 }
 }
 
 
+type State struct {
+	Term   int64
+	Vote   int64
+	Commit int64
+}
+
+var emptyState = State{}
+
 type Message struct {
 type Message struct {
 	Type      messageType
 	Type      messageType
 	ClusterId int64
 	ClusterId int64
@@ -151,6 +159,8 @@ type stateMachine struct {
 	pendingConf bool
 	pendingConf bool
 
 
 	snapshoter Snapshoter
 	snapshoter Snapshoter
+
+	unstableState State
 }
 }
 
 
 func newStateMachine(id int64, peers []int64) *stateMachine {
 func newStateMachine(id int64, peers []int64) *stateMachine {
@@ -273,9 +283,9 @@ func (sm *stateMachine) nextEnts() (ents []Entry) {
 }
 }
 
 
 func (sm *stateMachine) reset(term int64) {
 func (sm *stateMachine) reset(term int64) {
-	sm.term.Set(term)
+	sm.setTerm(term)
 	sm.lead.Set(none)
 	sm.lead.Set(none)
-	sm.vote = none
+	sm.setVote(none)
 	sm.votes = make(map[int64]bool)
 	sm.votes = make(map[int64]bool)
 	for i := range sm.ins {
 	for i := range sm.ins {
 		sm.ins[i] = &index{next: sm.raftLog.lastIndex() + 1}
 		sm.ins[i] = &index{next: sm.raftLog.lastIndex() + 1}
@@ -316,7 +326,7 @@ func (sm *stateMachine) becomeCandidate() {
 		panic("invalid transition [leader -> candidate]")
 		panic("invalid transition [leader -> candidate]")
 	}
 	}
 	sm.reset(sm.term.Get() + 1)
 	sm.reset(sm.term.Get() + 1)
-	sm.vote = sm.id
+	sm.setVote(sm.id)
 	sm.state = stateCandidate
 	sm.state = stateCandidate
 }
 }
 
 
@@ -399,12 +409,12 @@ func (sm *stateMachine) handleSnapshot(m Message) {
 }
 }
 
 
 func (sm *stateMachine) addNode(id int64) {
 func (sm *stateMachine) addNode(id int64) {
-	sm.ins[id] = &index{next: sm.raftLog.lastIndex() + 1}
+	sm.addIns(id, 0, sm.raftLog.lastIndex()+1)
 	sm.pendingConf = false
 	sm.pendingConf = false
 }
 }
 
 
 func (sm *stateMachine) removeNode(id int64) {
 func (sm *stateMachine) removeNode(id int64) {
-	delete(sm.ins, id)
+	sm.deleteIns(id)
 	sm.pendingConf = false
 	sm.pendingConf = false
 }
 }
 
 
@@ -483,7 +493,7 @@ func stepFollower(sm *stateMachine, m Message) bool {
 		sm.handleSnapshot(m)
 		sm.handleSnapshot(m)
 	case msgVote:
 	case msgVote:
 		if (sm.vote == none || sm.vote == m.From) && sm.raftLog.isUpToDate(m.Index, m.LogTerm) {
 		if (sm.vote == none || sm.vote == m.From) && sm.raftLog.isUpToDate(m.Index, m.LogTerm) {
-			sm.vote = m.From
+			sm.setVote(m.From)
 			sm.send(Message{To: m.From, Type: msgVoteResp, Index: sm.raftLog.lastIndex()})
 			sm.send(Message{To: m.From, Type: msgVoteResp, Index: sm.raftLog.lastIndex()})
 		} else {
 		} else {
 			sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
 			sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
@@ -515,9 +525,10 @@ func (sm *stateMachine) restore(s Snapshot) {
 	sm.index.Set(sm.raftLog.lastIndex())
 	sm.index.Set(sm.raftLog.lastIndex())
 	sm.ins = make(map[int64]*index)
 	sm.ins = make(map[int64]*index)
 	for _, n := range s.Nodes {
 	for _, n := range s.Nodes {
-		sm.ins[n] = &index{next: sm.raftLog.lastIndex() + 1}
 		if n == sm.id {
 		if n == sm.id {
-			sm.ins[n].match = sm.raftLog.lastIndex()
+			sm.addIns(n, sm.raftLog.lastIndex(), sm.raftLog.lastIndex()+1)
+		} else {
+			sm.addIns(n, 0, sm.raftLog.lastIndex()+1)
 		}
 		}
 	}
 	}
 	sm.pendingConf = false
 	sm.pendingConf = false
@@ -541,3 +552,40 @@ func (sm *stateMachine) nodes() []int64 {
 	}
 	}
 	return nodes
 	return nodes
 }
 }
+
+func (sm *stateMachine) setTerm(term int64) {
+	sm.term.Set(term)
+	sm.saveState()
+}
+
+func (sm *stateMachine) setVote(vote int64) {
+	sm.vote = vote
+	sm.saveState()
+}
+
+func (sm *stateMachine) addIns(id, match, next int64) {
+	sm.ins[id] = &index{next: next, match: match}
+	sm.saveState()
+}
+
+func (sm *stateMachine) deleteIns(id int64) {
+	delete(sm.ins, id)
+	sm.saveState()
+}
+
+// saveState saves the state to sm.unstableState
+// When there is a term change, vote change or configuration change, raft
+// must call saveState.
+func (sm *stateMachine) saveState() {
+	sm.setState(sm.vote, sm.term.Get(), sm.raftLog.committed)
+}
+
+func (sm *stateMachine) clearState() {
+	sm.setState(0, 0, 0)
+}
+
+func (sm *stateMachine) setState(vote, term, commit int64) {
+	sm.unstableState.Vote = vote
+	sm.unstableState.Term = term
+	sm.unstableState.Commit = commit
+}

+ 35 - 0
raft/raft_test.go

@@ -954,6 +954,41 @@ func TestSlowNodeRestore(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestUnstableState(t *testing.T) {
+	sm := newStateMachine(0, []int64{0})
+	w := State{}
+
+	sm.setVote(1)
+	w.Vote = 1
+	if !reflect.DeepEqual(sm.unstableState, w) {
+		t.Errorf("unstableState = %v, want %v", sm.unstableState, w)
+	}
+	sm.clearState()
+
+	sm.setTerm(1)
+	w.Term = 1
+	if !reflect.DeepEqual(sm.unstableState, w) {
+		t.Errorf("unstableState = %v, want %v", sm.unstableState, w)
+	}
+	sm.clearState()
+
+	sm.raftLog.committed = 1
+	sm.addIns(1, 0, 0)
+	w.Commit = 1
+	if !reflect.DeepEqual(sm.unstableState, w) {
+		t.Errorf("unstableState = %v, want %v", sm.unstableState, w)
+	}
+	sm.clearState()
+
+	sm.raftLog.committed = 2
+	sm.deleteIns(1)
+	w.Commit = 2
+	if !reflect.DeepEqual(sm.unstableState, w) {
+		t.Errorf("unstableState = %v, want %v", sm.unstableState, w)
+	}
+	sm.clearState()
+}
+
 func ents(terms ...int64) *stateMachine {
 func ents(terms ...int64) *stateMachine {
 	ents := []Entry{{}}
 	ents := []Entry{{}}
 	for _, term := range terms {
 	for _, term := range terms {