Browse Source

raft: refactor recover

Xiang Li 11 years ago
parent
commit
b383cd5acf
4 changed files with 20 additions and 19 deletions
  1. 4 0
      raft/log.go
  2. 7 6
      raft/node.go
  3. 2 3
      raft/node_test.go
  4. 7 10
      raft/raft.go

+ 4 - 0
raft/log.go

@@ -46,6 +46,10 @@ func newLog() *raftLog {
 	}
 }
 
+func (l *raftLog) isEmpty() bool {
+	return l.offset == 0 && len(l.ents) == 1
+}
+
 func (l *raftLog) String() string {
 	return fmt.Sprintf("offset=%d committed=%d applied=%d len(ents)=%d", l.offset, l.committed, l.applied, len(l.ents))
 }

+ 7 - 6
raft/node.go

@@ -51,6 +51,13 @@ func New(id int64, heartbeat, election tick) *Node {
 	return n
 }
 
+func Recover(id int64, ents []Entry, state State, heartbeat, election tick) *Node {
+	n := New(id, heartbeat, election)
+	n.sm.loadEnts(ents)
+	n.sm.loadState(state)
+	return n
+}
+
 func (n *Node) Id() int64 { return n.sm.id }
 
 func (n *Node) ClusterId() int64 { return n.sm.clusterId }
@@ -219,9 +226,3 @@ func (n *Node) UnstableState() State {
 	n.sm.clearState()
 	return s
 }
-
-// Load loads saved info and recovers the node.
-// It should only be called for new node.
-func (n *Node) Load(ents []Entry, state State) {
-	n.sm.load(ents, state)
-}

+ 2 - 3
raft/node_test.go

@@ -188,12 +188,11 @@ func TestDenial(t *testing.T) {
 	}
 }
 
-func TestLoad(t *testing.T) {
+func TestRecover(t *testing.T) {
 	ents := []Entry{{Term: 1}, {Term: 2}, {Term: 3}}
 	state := State{Term: 500, Vote: 1, Commit: 3}
 
-	n := New(0, defaultHeartbeat, defaultElection)
-	n.Load(ents, state)
+	n := Recover(0, ents, state, defaultHeartbeat, defaultElection)
 	if g := n.Next(); !reflect.DeepEqual(g, ents) {
 		t.Errorf("ents = %+v, want %+v", g, ents)
 	}

+ 7 - 10
raft/raft.go

@@ -590,19 +590,16 @@ func (sm *stateMachine) setState(vote, term, commit int64) {
 	sm.unstableState.Commit = commit
 }
 
-func (sm *stateMachine) load(ents []Entry, state State) {
-	sm.loadEnts(ents)
-	sm.loadState(state)
-}
-
 func (sm *stateMachine) loadEnts(ents []Entry) {
-	sm.raftLog.append(sm.raftLog.lastIndex(), ents...)
+	if !sm.raftLog.isEmpty() {
+		panic("cannot load entries when log is not empty")
+	}
+	sm.raftLog.append(0, ents...)
+	sm.raftLog.unstable = sm.raftLog.lastIndex() + 1
 }
 
 func (sm *stateMachine) loadState(state State) {
-	sm.term.Set(state.Term)
-	sm.vote = state.Vote
-	sm.raftLog.unstable = state.Commit + 1
 	sm.raftLog.committed = state.Commit
-	sm.saveState()
+	sm.setTerm(state.Term)
+	sm.setVote(state.Vote)
 }