Browse Source

raft: handle snapshot message

Xiang Li 11 years ago
parent
commit
5651272ec8
2 changed files with 120 additions and 5 deletions
  1. 33 5
      raft/raft.go
  2. 87 0
      raft/raft_test.go

+ 33 - 5
raft/raft.go

@@ -17,6 +17,7 @@ const (
 	msgAppResp
 	msgVote
 	msgVoteResp
+	msgSnap
 )
 
 var mtmap = [...]string{
@@ -27,6 +28,7 @@ var mtmap = [...]string{
 	msgAppResp:  "msgAppResp",
 	msgVote:     "msgVote",
 	msgVoteResp: "msgVoteResp",
+	msgSnap:     "msgSnap",
 }
 
 func (mt messageType) String() string {
@@ -69,6 +71,7 @@ type Message struct {
 	PrevTerm int
 	Entries  []Entry
 	Commit   int
+	Snapshot Snapshot
 }
 
 type index struct {
@@ -151,12 +154,17 @@ func (sm *stateMachine) send(m Message) {
 func (sm *stateMachine) sendAppend(to int) {
 	in := sm.ins[to]
 	m := Message{}
-	m.Type = msgApp
 	m.To = to
 	m.Index = in.next - 1
-	m.LogTerm = sm.log.term(in.next - 1)
-	m.Entries = sm.log.entries(in.next)
-	m.Commit = sm.log.committed
+	if sm.needSnapshot(m.Index) {
+		m.Type = msgSnap
+		m.Snapshot = sm.snapshoter.GetSnap()
+	} else {
+		m.Type = msgApp
+		m.LogTerm = sm.log.term(in.next - 1)
+		m.Entries = sm.log.entries(in.next)
+		m.Commit = sm.log.committed
+	}
 	sm.send(m)
 }
 
@@ -244,7 +252,7 @@ func (sm *stateMachine) becomeLeader() {
 	sm.lead = sm.id
 	sm.state = stateLeader
 
-	for _, e := range sm.log.ents[sm.log.committed:] {
+	for _, e := range sm.log.entries(sm.log.committed + 1) {
 		if e.isConfig() {
 			sm.pendingConf = true
 		}
@@ -298,6 +306,11 @@ func (sm *stateMachine) handleAppendEntries(m Message) {
 	}
 }
 
+func (sm *stateMachine) handleSnapshot(m Message) {
+	sm.restore(m.Snapshot)
+	sm.send(Message{To: m.From, Type: msgAppResp, Index: sm.log.lastIndex()})
+}
+
 func (sm *stateMachine) addNode(id int) {
 	sm.ins[id] = &index{next: sm.log.lastIndex() + 1}
 	sm.pendingConf = false
@@ -350,6 +363,9 @@ func stepCandidate(sm *stateMachine, m Message) bool {
 	case msgApp:
 		sm.becomeFollower(sm.term, m.From)
 		sm.handleAppendEntries(m)
+	case msgSnap:
+		sm.becomeFollower(m.Term, m.From)
+		sm.handleSnapshot(m)
 	case msgVote:
 		sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
 	case msgVoteResp:
@@ -375,6 +391,8 @@ func stepFollower(sm *stateMachine, m Message) bool {
 		sm.send(m)
 	case msgApp:
 		sm.handleAppendEntries(m)
+	case msgSnap:
+		sm.handleSnapshot(m)
 	case msgVote:
 		if (sm.vote == none || sm.vote == m.From) && sm.log.isUpToDate(m.Index, m.LogTerm) {
 			sm.vote = m.From
@@ -417,6 +435,16 @@ func (sm *stateMachine) restore(s Snapshot) {
 	sm.snapshoter.Restore(s)
 }
 
+func (sm *stateMachine) needSnapshot(i int) bool {
+	if i < sm.log.offset {
+		if sm.snapshoter == nil {
+			panic("need snapshot but snapshoter is nil")
+		}
+		return true
+	}
+	return false
+}
+
 func (sm *stateMachine) nodes() []int {
 	nodes := make([]int, 0, len(sm.ins))
 	for k := range sm.ins {

+ 87 - 0
raft/raft_test.go

@@ -802,6 +802,92 @@ func TestRestore(t *testing.T) {
 	}
 }
 
+func TestProvideSnap(t *testing.T) {
+	s := Snapshot{
+		Index: defaultCompactThreshold + 1,
+		Term:  defaultCompactThreshold + 1,
+		Nodes: []int{0, 1},
+	}
+	sm := newStateMachine(0, []int{0})
+	sm.setSnapshoter(new(logSnapshoter))
+	// restore the statemachin from a snapshot
+	// so it has a compacted log and a snapshot
+	sm.restore(s)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.Step(Message{Type: msgBeat})
+	msgs := sm.Msgs()
+	if len(msgs) != 1 {
+		t.Errorf("len(msgs) = %d, want 1", len(msgs))
+	}
+	m := msgs[0]
+	if m.Type != msgApp {
+		t.Errorf("m.Type = %v, want %v", m.Type, msgApp)
+	}
+
+	// force set the next of node 1, so that
+	// node 1 needs a snapshot
+	sm.ins[1].next = sm.log.offset
+
+	sm.Step(Message{Type: msgBeat})
+	msgs = sm.Msgs()
+	if len(msgs) != 1 {
+		t.Errorf("len(msgs) = %d, want 1", len(msgs))
+	}
+	m = msgs[0]
+	if m.Type != msgSnap {
+		t.Errorf("m.Type = %v, want %v", m.Type, msgSnap)
+	}
+}
+
+func TestRestoreFromSnapMsg(t *testing.T) {
+	s := Snapshot{
+		Index: defaultCompactThreshold + 1,
+		Term:  defaultCompactThreshold + 1,
+		Nodes: []int{0, 1},
+	}
+	m := Message{Type: msgSnap, From: 0, Term: 1, Snapshot: s}
+
+	sm := newStateMachine(1, []int{0, 1})
+	sm.setSnapshoter(new(logSnapshoter))
+	sm.Step(m)
+
+	if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) {
+		t.Errorf("snapshot = %+v, want %+v", sm.snapshoter.GetSnap(), s)
+	}
+}
+
+func TestSlowNodeRestore(t *testing.T) {
+	nt := newNetwork(nil, nil, nil)
+	nt.send(Message{To: 0, Type: msgHup})
+
+	nt.isolate(2)
+	for j := 0; j < defaultCompactThreshold+1; j++ {
+		nt.send(Message{To: 0, Type: msgProp, Entries: []Entry{{}}})
+	}
+	lead := nt.peers[0].(*stateMachine)
+	lead.nextEnts()
+	if !lead.maybeCompact() {
+		t.Errorf("compacted = false, want true")
+	}
+
+	nt.recover()
+	nt.send(Message{To: 0, Type: msgBeat})
+
+	follower := nt.peers[2].(*stateMachine)
+	if !reflect.DeepEqual(follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap()) {
+		t.Errorf("follower.snap = %+v, want %+v", follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap())
+	}
+
+	committed := follower.log.lastIndex()
+	nt.send(Message{To: 0, Type: msgProp, Entries: []Entry{{}}})
+	if follower.log.committed != committed+1 {
+		t.Errorf("follower.comitted = %d, want %d", follower.log.committed, committed+1)
+	}
+}
+
 func ents(terms ...int) *stateMachine {
 	ents := []Entry{{}}
 	for _, term := range terms {
@@ -836,6 +922,7 @@ func newNetwork(peers ...Interface) *network {
 		switch v := p.(type) {
 		case nil:
 			sm := newStateMachine(id, defaultPeerAddrs)
+			sm.setSnapshoter(new(logSnapshoter))
 			npeers[id] = sm
 		case *stateMachine:
 			v.id = id