Browse Source

raft: add recover

Yicheng Qin 11 years ago
parent
commit
ba63cf666d
6 changed files with 66 additions and 115 deletions
  1. 15 12
      raft/log.go
  2. 1 1
      raft/log_test.go
  3. 19 0
      raft/node.go
  4. 5 28
      raft/raft.go
  5. 22 68
      raft/raft_test.go
  6. 4 6
      raft/snapshot.go

+ 15 - 12
raft/log.go

@@ -19,12 +19,13 @@ func (e *Entry) isConfig() bool {
 }
 
 type raftLog struct {
-	ents      []Entry
-	unstable  int64
-	committed int64
-	applied   int64
-	offset    int64
-	snapshot  Snapshot
+	ents             []Entry
+	unstable         int64
+	committed        int64
+	applied          int64
+	offset           int64
+	snapshot         Snapshot
+	unstableSnapshot Snapshot
 
 	// want a compact after the number of entries exceeds the threshold
 	// TODO(xiangli) size might be a better criterion
@@ -163,12 +164,14 @@ func (l *raftLog) shouldCompact() bool {
 	return (l.applied - l.offset) > l.compactThreshold
 }
 
-func (l *raftLog) restore(index, term int64) {
-	l.ents = []Entry{{Term: term}}
-	l.unstable = index + 1
-	l.committed = index
-	l.applied = index
-	l.offset = index
+func (l *raftLog) restore(s Snapshot) {
+	l.ents = []Entry{{Term: s.Term}}
+	l.unstable = s.Index + 1
+	l.committed = s.Index
+	l.applied = s.Index
+	l.offset = s.Index
+	l.snapshot = s
+	l.unstableSnapshot = s
 }
 
 func (l *raftLog) at(i int64) *Entry {

+ 1 - 1
raft/log_test.go

@@ -192,7 +192,7 @@ func TestLogRestore(t *testing.T) {
 
 	index := int64(1000)
 	term := int64(1000)
-	raftLog.restore(index, term)
+	raftLog.restore(Snapshot{Index: index, Term: term})
 
 	// only has the guard entry
 	if len(raftLog.ents) != 1 {

+ 19 - 0
raft/node.go

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"log"
 	"math/rand"
+	"sort"
 	"time"
 )
 
@@ -76,6 +77,15 @@ func (n *Node) Leader() int64 { return n.sm.lead.Get() }
 
 func (n *Node) IsRemoved() bool { return n.removed }
 
+func (n *Node) Nodes() []int64 {
+	nodes := make(int64Slice, 0, len(n.sm.ins))
+	for k := range n.sm.ins {
+		nodes = append(nodes, k)
+	}
+	sort.Sort(nodes)
+	return nodes
+}
+
 // Propose asynchronously proposes data be applied to the underlying state machine.
 func (n *Node) Propose(data []byte) { n.propose(Normal, data) }
 
@@ -232,6 +242,15 @@ func (n *Node) UnstableState() State {
 	return s
 }
 
+func (n *Node) UnstableSnapshot() Snapshot {
+	if n.sm.raftLog.unstableSnapshot.IsEmpty() {
+		return emptySnapshot
+	}
+	s := n.sm.raftLog.unstableSnapshot
+	n.sm.raftLog.unstableSnapshot = emptySnapshot
+	return s
+}
+
 func (n *Node) GetSnap() Snapshot {
 	return n.sm.raftLog.snapshot
 }

+ 5 - 28
raft/raft.go

@@ -157,8 +157,6 @@ type stateMachine struct {
 	// pending reconfiguration
 	pendingConf bool
 
-	snapshoter Snapshoter
-
 	unstableState State
 }
 
@@ -187,10 +185,6 @@ func (sm *stateMachine) String() string {
 	return s
 }
 
-func (sm *stateMachine) setSnapshoter(snapshoter Snapshoter) {
-	sm.snapshoter = snapshoter
-}
-
 func (sm *stateMachine) poll(id int64, v bool) (granted int) {
 	if _, ok := sm.votes[id]; !ok {
 		sm.votes[id] = v
@@ -220,7 +214,7 @@ func (sm *stateMachine) sendAppend(to int64) {
 	m.Index = in.next - 1
 	if sm.needSnapshot(m.Index) {
 		m.Type = msgSnap
-		m.Snapshot = sm.snapshoter.GetSnap()
+		m.Snapshot = sm.raftLog.snapshot
 	} else {
 		m.Type = msgApp
 		m.LogTerm = sm.raftLog.term(in.next - 1)
@@ -502,31 +496,15 @@ func stepFollower(sm *stateMachine, m Message) bool {
 	return true
 }
 
-// maybeCompact tries to compact the log. It calls the snapshoter to take a snapshot and
-// then compact the log up-to the index at which the snapshot was taken.
-func (sm *stateMachine) maybeCompact() bool {
-	if sm.snapshoter == nil || !sm.raftLog.shouldCompact() {
-		return false
-	}
-	sm.snapshoter.Snap(sm.raftLog.applied, sm.raftLog.term(sm.raftLog.applied), sm.nodes())
-	sm.raftLog.compact(sm.raftLog.applied)
-	return true
-}
-
 func (sm *stateMachine) compact(d []byte) {
 	sm.raftLog.snap(d, sm.raftLog.applied, sm.raftLog.term(sm.raftLog.applied), sm.nodes())
 	sm.raftLog.compact(sm.raftLog.applied)
 }
 
 // restore recovers the statemachine from a snapshot. It restores the log and the
-// configuration of statemachine. It calls the snapshoter to restore from the given
-// snapshot.
+// configuration of statemachine.
 func (sm *stateMachine) restore(s Snapshot) {
-	if sm.snapshoter == nil {
-		panic("try to restore from snapshot, but snapshoter is nil")
-	}
-
-	sm.raftLog.restore(s.Index, s.Term)
+	sm.raftLog.restore(s)
 	sm.index.Set(sm.raftLog.lastIndex())
 	sm.ins = make(map[int64]*index)
 	for _, n := range s.Nodes {
@@ -537,13 +515,12 @@ func (sm *stateMachine) restore(s Snapshot) {
 		}
 	}
 	sm.pendingConf = false
-	sm.snapshoter.Restore(s)
 }
 
 func (sm *stateMachine) needSnapshot(i int64) bool {
 	if i < sm.raftLog.offset {
-		if sm.snapshoter == nil {
-			panic("need snapshot but snapshoter is nil")
+		if sm.raftLog.snapshot.IsEmpty() {
+			panic("need non-empty snapshot")
 		}
 		return true
 	}

+ 22 - 68
raft/raft_test.go

@@ -2,7 +2,6 @@ package raft
 
 import (
 	"bytes"
-	"fmt"
 	"math/rand"
 	"reflect"
 	"sort"
@@ -781,45 +780,24 @@ func TestRestore(t *testing.T) {
 		Nodes: []int64{0, 1, 2},
 	}
 
-	tests := []struct {
-		snapshoter Snapshoter
-		wallow     bool
-	}{
-		{nil, false},
-		{new(logSnapshoter), true},
-	}
-
-	for i, tt := range tests {
-		func() {
-			defer func() {
-				if r := recover(); r != nil {
-					if tt.wallow == true {
-						t.Errorf("%d: allow = %v, want %v", i, false, true)
-					}
-				}
-			}()
-
-			sm := newStateMachine(0, []int64{0, 1})
-			sm.setSnapshoter(tt.snapshoter)
-			sm.restore(s)
+	sm := newStateMachine(0, []int64{0, 1})
+	sm.restore(s)
 
-			if sm.raftLog.lastIndex() != s.Index {
-				t.Errorf("#%d: log.lastIndex = %d, want %d", i, sm.raftLog.lastIndex(), s.Index)
-			}
-			if sm.raftLog.term(s.Index) != s.Term {
-				t.Errorf("#%d: log.lastTerm = %d, want %d", i, sm.raftLog.term(s.Index), s.Term)
-			}
-			sg := int64Slice(sm.nodes())
-			sw := int64Slice(s.Nodes)
-			sort.Sort(sg)
-			sort.Sort(sw)
-			if !reflect.DeepEqual(sg, sw) {
-				t.Errorf("#%d: sm.Nodes = %+v, want %+v", i, sg, sw)
-			}
-			if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) {
-				t.Errorf("%d: snapshoter.getSnap = %+v, want %+v", sm.snapshoter.GetSnap(), s)
-			}
-		}()
+	if sm.raftLog.lastIndex() != s.Index {
+		t.Errorf("log.lastIndex = %d, want %d", sm.raftLog.lastIndex(), s.Index)
+	}
+	if sm.raftLog.term(s.Index) != s.Term {
+		t.Errorf("log.lastTerm = %d, want %d", sm.raftLog.term(s.Index), s.Term)
+	}
+	sg := int64Slice(sm.nodes())
+	sw := int64Slice(s.Nodes)
+	sort.Sort(sg)
+	sort.Sort(sw)
+	if !reflect.DeepEqual(sg, sw) {
+		t.Errorf("sm.Nodes = %+v, want %+v", sg, sw)
+	}
+	if !reflect.DeepEqual(sm.raftLog.snapshot, s) {
+		t.Errorf("snapshot = %+v, want %+v", sm.raftLog.snapshot, s)
 	}
 }
 
@@ -830,7 +808,6 @@ func TestProvideSnap(t *testing.T) {
 		Nodes: []int64{0, 1},
 	}
 	sm := newStateMachine(0, []int64{0})
-	sm.setSnapshoter(new(logSnapshoter))
 	// restore the statemachine from a snapshot
 	// so it has a compacted log and a snapshot
 	sm.restore(s)
@@ -872,11 +849,10 @@ func TestRestoreFromSnapMsg(t *testing.T) {
 	m := Message{Type: msgSnap, From: 0, Term: 1, Snapshot: s}
 
 	sm := newStateMachine(1, []int64{0, 1})
-	sm.setSnapshoter(new(logSnapshoter))
 	sm.Step(m)
 
-	if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) {
-		t.Errorf("snapshot = %+v, want %+v", sm.snapshoter.GetSnap(), s)
+	if !reflect.DeepEqual(sm.raftLog.snapshot, s) {
+		t.Errorf("snapshot = %+v, want %+v", sm.raftLog.snapshot, s)
 	}
 }
 
@@ -890,16 +866,14 @@ func TestSlowNodeRestore(t *testing.T) {
 	}
 	lead := nt.peers[0].(*stateMachine)
 	lead.nextEnts()
-	if !lead.maybeCompact() {
-		t.Errorf("compacted = false, want true")
-	}
+	lead.compact(nil)
 
 	nt.recover()
 	nt.send(Message{From: 0, To: 0, Type: msgBeat})
 
 	follower := nt.peers[2].(*stateMachine)
-	if !reflect.DeepEqual(follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap()) {
-		t.Errorf("follower.snap = %+v, want %+v", follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap())
+	if !reflect.DeepEqual(follower.raftLog.snapshot, lead.raftLog.snapshot) {
+		t.Errorf("follower.snap = %+v, want %+v", follower.raftLog.snapshot, lead.raftLog.snapshot)
 	}
 
 	committed := follower.raftLog.lastIndex()
@@ -979,7 +953,6 @@ func newNetwork(peers ...Interface) *network {
 		switch v := p.(type) {
 		case nil:
 			sm := newStateMachine(nid, defaultPeerAddrs)
-			sm.setSnapshoter(new(logSnapshoter))
 			npeers[nid] = sm
 		case *stateMachine:
 			v.id = nid
@@ -1070,22 +1043,3 @@ func (blackHole) Step(Message) bool { return true }
 func (blackHole) Msgs() []Message   { return nil }
 
 var nopStepper = &blackHole{}
-
-type logSnapshoter struct {
-	snapshot Snapshot
-}
-
-func (s *logSnapshoter) Snap(index, term int64, nodes []int64) {
-	s.snapshot = Snapshot{
-		Index: index,
-		Term:  term,
-		Nodes: nodes,
-		Data:  []byte(fmt.Sprintf("%d:%d", term, index)),
-	}
-}
-func (s *logSnapshoter) Restore(ss Snapshot) {
-	s.snapshot = ss
-}
-func (s *logSnapshoter) GetSnap() Snapshot {
-	return s.snapshot
-}

+ 4 - 6
raft/snapshot.go

@@ -1,5 +1,7 @@
 package raft
 
+var emptySnapshot = Snapshot{}
+
 type Snapshot struct {
 	Data []byte
 
@@ -11,10 +13,6 @@ type Snapshot struct {
 	Term int64
 }
 
-// A snapshoter can make a snapshot of its current state atomically.
-// It can restore from a snapshot and get the latest snapshot it took.
-type Snapshoter interface {
-	Snap(index, term int64, nodes []int64)
-	Restore(snap Snapshot)
-	GetSnap() Snapshot
+func (s Snapshot) IsEmpty() bool {
+	return s.Term == 0
 }