Browse Source

raft: clean stateMachine

Xiang Li 11 years ago
parent
commit
c223eca938
3 changed files with 116 additions and 31 deletions
  1. 1 1
      raft/node.go
  2. 44 12
      raft/raft.go
  3. 71 18
      raft/raft_test.go

+ 1 - 1
raft/node.go

@@ -13,7 +13,7 @@ type Node struct {
 
 func New(k, addr int, next Interface) *Node {
 	n := &Node{
-		sm: newStateMachine(k, addr, next),
+		sm: newStateMachine(k, addr),
 	}
 	return n
 }

+ 44 - 12
raft/raft.go

@@ -108,15 +108,15 @@ type stateMachine struct {
 
 	votes map[int]bool
 
-	next Interface
+	msgs []Message
 
 	// the leader addr
 	lead int
 }
 
-func newStateMachine(k, addr int, next Interface) *stateMachine {
+func newStateMachine(k, addr int) *stateMachine {
 	log := make([]Entry, 1, 1024)
-	sm := &stateMachine{k: k, addr: addr, next: next, log: log}
+	sm := &stateMachine{k: k, addr: addr, log: log}
 	sm.reset()
 	return sm
 }
@@ -145,6 +145,14 @@ func (sm *stateMachine) append(after int, ents ...Entry) int {
 	return len(sm.log) - 1
 }
 
+func (sm *stateMachine) maybeAppend(index, logTerm int, ents ...Entry) bool {
+	if sm.isLogOk(index, logTerm) {
+		sm.append(index, ents...)
+		return true
+	}
+	return false
+}
+
 func (sm *stateMachine) isLogOk(i, term int) bool {
 	if i > sm.li() {
 		return false
@@ -152,11 +160,11 @@ func (sm *stateMachine) isLogOk(i, term int) bool {
 	return sm.log[i].Term == term
 }
 
-// send persists state to stable storage and then sends m over the network to m.To
+// send persists state to stable storage and then sends to its mailbox
 func (sm *stateMachine) send(m Message) {
 	m.From = sm.addr
 	m.Term = sm.term
-	sm.next.Step(m)
+	sm.msgs = append(sm.msgs, m)
 }
 
 // sendAppend sends RRPC, with entries to all peers that are not up-to-date according to sm.mis.
@@ -233,14 +241,39 @@ func (sm *stateMachine) becomeFollower(term, lead int) {
 	sm.state = stateFollower
 }
 
+func (sm *stateMachine) becomeCandidate() {
+	// TODO(xiangli) remove the panic when the raft implementation is stable
+	if sm.state == stateLeader {
+		panic("invalid transition [leader -> candidate]")
+	}
+	sm.reset()
+	sm.term++
+	sm.vote = sm.addr
+	sm.state = stateCandidate
+	sm.poll(sm.addr, true)
+}
+
+func (sm *stateMachine) becomeLeader() {
+	// TODO(xiangli) remove the panic when the raft implementation is stable
+	if sm.state == stateFollower {
+		panic("invalid transition [follower -> leader]")
+	}
+	sm.reset()
+	sm.lead = sm.addr
+	sm.state = stateLeader
+}
+
+func (sm *stateMachine) Msgs() []Message {
+	msgs := sm.msgs
+	sm.msgs = make([]Message, 0)
+
+	return msgs
+}
+
 func (sm *stateMachine) Step(m Message) {
 	switch m.Type {
 	case msgHup:
-		sm.term++
-		sm.reset()
-		sm.state = stateCandidate
-		sm.vote = sm.addr
-		sm.poll(sm.addr, true)
+		sm.becomeCandidate()
 		for i := 0; i < sm.k; i++ {
 			if i == sm.addr {
 				continue
@@ -301,8 +334,7 @@ func (sm *stateMachine) Step(m Message) {
 			gr := sm.poll(m.From, m.Index >= 0)
 			switch sm.q() {
 			case gr:
-				sm.state = stateLeader
-				sm.lead = sm.addr
+				sm.becomeLeader()
 				sm.sendAppend()
 			case len(sm.votes) - gr:
 				sm.becomeFollower(sm.term, none)

+ 71 - 18
raft/raft_test.go

@@ -23,9 +23,9 @@ func TestLeaderElection(t *testing.T) {
 		{
 			newNetwork(
 				nil,
-				&stateMachine{log: []Entry{{}, {Term: 1}}},
-				&stateMachine{log: []Entry{{}, {Term: 2}}},
-				&stateMachine{log: []Entry{{}, {Term: 1}, {Term: 3}}},
+				&nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil},
+				&nsm{stateMachine{log: []Entry{{}, {Term: 2}}}, nil},
+				&nsm{stateMachine{log: []Entry{{}, {Term: 1}, {Term: 3}}}, nil},
 				nil,
 			),
 			stateFollower,
@@ -34,10 +34,10 @@ func TestLeaderElection(t *testing.T) {
 		// logs converge
 		{
 			newNetwork(
-				&stateMachine{log: []Entry{{}, {Term: 1}}},
+				&nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil},
 				nil,
-				&stateMachine{log: []Entry{{}, {Term: 2}}},
-				&stateMachine{log: []Entry{{}, {Term: 1}}},
+				&nsm{stateMachine{log: []Entry{{}, {Term: 2}}}, nil},
+				&nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil},
 				nil,
 			),
 			stateLeader,
@@ -46,7 +46,7 @@ func TestLeaderElection(t *testing.T) {
 
 	for i, tt := range tests {
 		tt.Step(Message{To: 0, Type: msgHup})
-		sm := tt.network.ss[0].(*stateMachine)
+		sm := tt.network.ss[0].(*nsm)
 		if sm.state != tt.state {
 			t.Errorf("#%d: state = %s, want %s", i, sm.state, tt.state)
 		}
@@ -57,8 +57,8 @@ func TestLeaderElection(t *testing.T) {
 }
 
 func TestDualingCandidates(t *testing.T) {
-	a := &stateMachine{log: defaultLog}
-	c := &stateMachine{log: defaultLog}
+	a := &nsm{stateMachine{log: defaultLog}, nil}
+	c := &nsm{stateMachine{log: defaultLog}, nil}
 
 	tt := newNetwork(a, nil, c)
 
@@ -82,7 +82,7 @@ func TestDualingCandidates(t *testing.T) {
 	tt.Step(Message{To: 2, Type: msgHup})
 
 	tests := []struct {
-		sm    *stateMachine
+		sm    *nsm
 		state stateType
 		term  int
 	}{
@@ -106,7 +106,7 @@ func TestDualingCandidates(t *testing.T) {
 }
 
 func TestCandidateConcede(t *testing.T) {
-	a := &stateMachine{log: defaultLog}
+	a := &nsm{stateMachine{log: defaultLog}, nil}
 
 	tt := newNetwork(a, nil, nil)
 	tt.tee = stepperFunc(func(m Message) {
@@ -143,7 +143,7 @@ func TestOldMessages(t *testing.T) {
 	tt := newNetwork(nil, nil, nil)
 	// make 0 leader @ term 3
 	tt.Step(Message{To: 0, Type: msgHup})
-	tt.Step(Message{To: 0, Type: msgHup})
+	tt.Step(Message{To: 1, Type: msgHup})
 	tt.Step(Message{To: 0, Type: msgHup})
 	// pretend we're an old leader trying to make progress
 	tt.Step(Message{To: 0, Type: msgApp, Term: 1, Entries: []Entry{{Term: 1}}})
@@ -204,7 +204,7 @@ func TestProposal(t *testing.T) {
 				t.Errorf("#%d: diff:%s", i, diff)
 			}
 		}
-		sm := tt.network.ss[0].(*stateMachine)
+		sm := tt.network.ss[0].(*nsm)
 		if g := sm.term; g != 1 {
 			t.Errorf("#%d: term = %d, want %d", i, g, 1)
 		}
@@ -235,7 +235,7 @@ func TestProposalByProxy(t *testing.T) {
 				t.Errorf("#%d: bad entry: %s", i, diff)
 			}
 		}
-		sm := tt.ss[0].(*stateMachine)
+		sm := tt.ss[0].(*nsm)
 		if g := sm.term; g != 1 {
 			t.Errorf("#%d: term = %d, want %d", i, g, 1)
 		}
@@ -305,7 +305,7 @@ func TestVote(t *testing.T) {
 
 	for i, tt := range tests {
 		called := false
-		sm := &stateMachine{log: []Entry{{}, {Term: 2}, {Term: 2}}}
+		sm := &nsm{stateMachine{log: []Entry{{}, {Term: 2}, {Term: 2}}}, nil}
 		sm.next = stepperFunc(func(m Message) {
 			called = true
 			if m.Index != tt.w {
@@ -319,6 +319,46 @@ func TestVote(t *testing.T) {
 	}
 }
 
+func TestAllServerStepdown(t *testing.T) {
+	tests := []stateType{stateFollower, stateCandidate, stateLeader}
+
+	want := struct {
+		state stateType
+		term  int
+		index int
+	}{stateFollower, 3, 1}
+
+	tmsgTypes := [...]messageType{msgVote, msgApp}
+	tterm := 3
+
+	for i, tt := range tests {
+		sm := newStateMachine(3, 0)
+		switch tt {
+		case stateFollower:
+			sm.becomeFollower(1, 0)
+		case stateCandidate:
+			sm.becomeCandidate()
+		case stateLeader:
+			sm.becomeCandidate()
+			sm.becomeLeader()
+		}
+
+		for j, msgType := range tmsgTypes {
+			sm.Step(Message{Type: msgType, Term: tterm, LogTerm: tterm})
+
+			if sm.state != want.state {
+				t.Errorf("#%d.%d state = %v , want %v", i, j, sm.state, want.state)
+			}
+			if sm.term != want.term {
+				t.Errorf("#%d.%d term = %v , want %v", i, j, sm.term, want.term)
+			}
+			if len(sm.log) != want.index {
+				t.Errorf("#%d.%d index = %v , want %v", i, j, len(sm.log), want.index)
+			}
+		}
+	}
+}
+
 func TestLogDiff(t *testing.T) {
 	a := []Entry{{}, {Term: 1}, {Term: 2}}
 	b := []Entry{{}, {Term: 1}, {Term: 2}}
@@ -349,8 +389,8 @@ func newNetwork(nodes ...Interface) *network {
 	for i, n := range nodes {
 		switch v := n.(type) {
 		case nil:
-			nt.ss[i] = newStateMachine(len(nodes), i, nt)
-		case *stateMachine:
+			nt.ss[i] = &nsm{*newStateMachine(len(nodes), i), nt}
+		case *nsm:
 			v.k = len(nodes)
 			v.addr = i
 			if v.next == nil {
@@ -375,7 +415,7 @@ func (nt network) Step(m Message) {
 func (nt network) logs() [][]Entry {
 	ls := make([][]Entry, len(nt.ss))
 	for i, node := range nt.ss {
-		if sm, ok := node.(*stateMachine); ok {
+		if sm, ok := node.(*nsm); ok {
 			ls[i] = sm.log
 		}
 	}
@@ -462,3 +502,16 @@ type stepperFunc func(Message)
 func (f stepperFunc) Step(m Message) { f(m) }
 
 var nopStepper = stepperFunc(func(Message) {})
+
+type nsm struct {
+	stateMachine
+	next Interface
+}
+
+func (n *nsm) Step(m Message) {
+	(&n.stateMachine).Step(m)
+	ms := n.Msgs()
+	for _, m := range ms {
+		n.next.Step(m)
+	}
+}