Browse Source

raft: sm.compact and sm.restore

Xiang Li 11 years ago
parent
commit
2a11c1487c
5 changed files with 220 additions and 1 deletions
  1. 14 1
      raft/log.go
  2. 28 0
      raft/log_test.go
  3. 45 0
      raft/raft.go
  4. 113 0
      raft/raft_test.go
  5. 20 0
      raft/snapshot.go

+ 14 - 1
raft/log.go

@@ -69,6 +69,12 @@ func (l *log) term(i int) int {
 }
 
 func (l *log) entries(i int) []Entry {
+	// never send out the first entry
+	// first entry is only used for matching
+	// prevLogTerm
+	if i == l.offset {
+		panic("cannot return the first entry in log")
+	}
 	return l.slice(i, l.lastIndex()+1)
 }
 
@@ -116,7 +122,14 @@ func (l *log) compact(i int) int {
 }
 
 func (l *log) shouldCompact() bool {
-	return (l.committed - l.offset) > l.compactThreshold
+	return (l.applied - l.offset) > l.compactThreshold
+}
+
+func (l *log) restore(index, term int) {
+	l.ents = []Entry{{Term: term}}
+	l.committed = index
+	l.applied = index
+	l.offset = index
 }
 
 func (l *log) at(i int) *Entry {

+ 28 - 0
raft/log_test.go

@@ -85,6 +85,34 @@ func TestCompaction(t *testing.T) {
 	}
 }
 
+func TestLogRestore(t *testing.T) {
+	log := newLog()
+	for i := 0; i < 100; i++ {
+		log.append(i, Entry{Term: i + 1})
+	}
+
+	index := 1000
+	term := 1000
+	log.restore(index, term)
+
+	// only has the guard entry
+	if len(log.ents) != 1 {
+		t.Errorf("len = %d, want 0", len(log.ents))
+	}
+	if log.offset != index {
+		t.Errorf("offset = %d, want %d", log.offset, index)
+	}
+	if log.applied != index {
+		t.Errorf("applied = %d, want %d", log.applied, index)
+	}
+	if log.committed != index {
+		t.Errorf("comitted = %d, want %d", log.committed, index)
+	}
+	if log.term(index) != term {
+		t.Errorf("term = %d, want %d", log.term(index), term)
+	}
+}
+
 func TestIsOutOfBounds(t *testing.T) {
 	offset := 100
 	num := 100

+ 45 - 0
raft/raft.go

@@ -111,6 +111,8 @@ type stateMachine struct {
 
 	// pending reconfiguration
 	pendingConf bool
+
+	snapshoter Snapshoter
 }
 
 func newStateMachine(id int, peers []int) *stateMachine {
@@ -122,6 +124,10 @@ func newStateMachine(id int, peers []int) *stateMachine {
 	return sm
 }
 
+func (sm *stateMachine) setSnapshoter(snapshoter Snapshoter) {
+	sm.snapshoter = snapshoter
+}
+
 func (sm *stateMachine) poll(id int, v bool) (granted int) {
 	if _, ok := sm.votes[id]; !ok {
 		sm.votes[id] = v
@@ -379,3 +385,42 @@ func stepFollower(sm *stateMachine, m Message) bool {
 	}
 	return true
 }
+
+// maybeCompact tries to compact the log. It calls the snapshoter to take a snapshot and
+// then compact the log up-to the index at which the snapshot was taken.
+func (sm *stateMachine) maybeCompact() bool {
+	if sm.snapshoter == nil || !sm.log.shouldCompact() {
+		return false
+	}
+	sm.snapshoter.Snap(sm.log.applied, sm.log.term(sm.log.applied), sm.nodes())
+	sm.log.compact(sm.log.applied)
+	return true
+}
+
+// restore recovers the statemachine from a snapshot. It restores the log and the
+// configuration of statemachine. It calls the snapshoter to restore from the given
+// snapshot.
+func (sm *stateMachine) restore(s Snapshot) {
+	if sm.snapshoter == nil {
+		panic("try to restore from snapshot, but snapshoter is nil")
+	}
+
+	sm.log.restore(s.Index, s.Term)
+	sm.ins = make(map[int]*index)
+	for _, n := range s.Nodes {
+		sm.ins[n] = &index{next: sm.log.lastIndex() + 1}
+		if n == sm.id {
+			sm.ins[n].match = sm.log.lastIndex()
+		}
+	}
+	sm.pendingConf = false
+	sm.snapshoter.Restore(s)
+}
+
+func (sm *stateMachine) nodes() []int {
+	nodes := make([]int, 0, len(sm.ins))
+	for k := range sm.ins {
+		nodes = append(nodes, k)
+	}
+	return nodes
+}

+ 113 - 0
raft/raft_test.go

@@ -2,7 +2,10 @@ package raft
 
 import (
 	"bytes"
+	"fmt"
 	"math/rand"
+	"reflect"
+	"sort"
 	"testing"
 )
 
@@ -708,6 +711,97 @@ func TestRecvMsgBeat(t *testing.T) {
 	}
 }
 
+func TestMaybeCompact(t *testing.T) {
+	tests := []struct {
+		snapshoter Snapshoter
+		applied    int
+		wCompact   bool
+	}{
+		{nil, defaultCompactThreshold + 1, false},
+		{new(logSnapshoter), defaultCompactThreshold - 1, false},
+		{new(logSnapshoter), defaultCompactThreshold + 1, true},
+	}
+
+	for i, tt := range tests {
+		sm := newStateMachine(0, []int{0, 1, 2})
+		sm.setSnapshoter(tt.snapshoter)
+		for i := 0; i < defaultCompactThreshold*2; i++ {
+			sm.log.append(i, Entry{Term: i + 1})
+		}
+		sm.log.applied = tt.applied
+		sm.log.committed = tt.applied
+
+		if g := sm.maybeCompact(); g != tt.wCompact {
+			t.Errorf("#%d: compact = %v, want %v", i, g, tt.wCompact)
+		}
+
+		if tt.wCompact {
+			s := sm.snapshoter.GetSnap()
+			if s.Index != tt.applied {
+				t.Errorf("#%d: snap.Index = %v, want %v", i, s.Index, tt.applied)
+			}
+			if s.Term != tt.applied {
+				t.Errorf("#%d: snap.Term = %v, want %v", i, s.Index, tt.applied)
+			}
+
+			w := sm.nodes()
+			sort.Ints(w)
+			sort.Ints(s.Nodes)
+			if !reflect.DeepEqual(s.Nodes, w) {
+				t.Errorf("#%d: snap.Nodes = %+v, want %+v", i, s.Nodes, w)
+			}
+		}
+	}
+}
+
+func TestRestore(t *testing.T) {
+	s := Snapshot{
+		Index: defaultCompactThreshold + 1,
+		Term:  defaultCompactThreshold + 1,
+		Nodes: []int{0, 1, 2},
+	}
+
+	tests := []struct {
+		snapshoter Snapshoter
+		wallow     bool
+	}{
+		{nil, false},
+		{new(logSnapshoter), true},
+	}
+
+	for i, tt := range tests {
+		func() {
+			defer func() {
+				if r := recover(); r != nil {
+					if tt.wallow == true {
+						t.Errorf("%d: allow = %v, want %v", i, false, true)
+					}
+				}
+			}()
+
+			sm := newStateMachine(0, []int{0, 1})
+			sm.setSnapshoter(tt.snapshoter)
+			sm.restore(s)
+
+			if sm.log.lastIndex() != s.Index {
+				t.Errorf("#%d: log.lastIndex = %d, want %d", i, sm.log.lastIndex(), s.Index)
+			}
+			if sm.log.term(s.Index) != s.Term {
+				t.Errorf("#%d: log.lastTerm = %d, want %d", i, sm.log.term(s.Index), s.Term)
+			}
+			g := sm.nodes()
+			sort.Ints(g)
+			sort.Ints(s.Nodes)
+			if !reflect.DeepEqual(g, s.Nodes) {
+				t.Errorf("#%d: sm.Nodes = %+v, want %+v", i, g, s.Nodes)
+			}
+			if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) {
+				t.Errorf("%d: snapshoter.getSnap = %+v, want %+v", sm.snapshoter.GetSnap(), s)
+			}
+		}()
+	}
+}
+
 func ents(terms ...int) *stateMachine {
 	ents := []Entry{{}}
 	for _, term := range terms {
@@ -831,3 +925,22 @@ func (blackHole) Step(Message) bool { return true }
 func (blackHole) Msgs() []Message   { return nil }
 
 var nopStepper = &blackHole{}
+
+type logSnapshoter struct {
+	snapshot Snapshot
+}
+
+func (s *logSnapshoter) Snap(index, term int, nodes []int) {
+	s.snapshot = Snapshot{
+		Index: index,
+		Term:  term,
+		Nodes: nodes,
+		Data:  []byte(fmt.Sprintf("%d:%d", term, index)),
+	}
+}
+func (s *logSnapshoter) Restore(ss Snapshot) {
+	s.snapshot = ss
+}
+func (s *logSnapshoter) GetSnap() Snapshot {
+	return s.snapshot
+}

+ 20 - 0
raft/snapshot.go

@@ -0,0 +1,20 @@
+package raft
+
+type Snapshot struct {
+	Data []byte
+
+	// the configuration
+	Nodes []int
+	// the index at which the snapshot was taken.
+	Index int
+	// the log term of the index
+	Term int
+}
+
+// A snapshoter can make a snapshot of its current state atomically.
+// It can restore from a snapshot and get the latest snapshot it took.
+type Snapshoter interface {
+	Snap(index, term int, nodes []int)
+	Restore(snap Snapshot)
+	GetSnap() Snapshot
+}