Browse Source

raft: test dualing proposers

Blake Mizerany 11 years ago
parent
commit
895d80d0e1
2 changed files with 64 additions and 19 deletions
  1. 2 1
      raft.go
  2. 62 18
      raft_test.go

+ 2 - 1
raft.go

@@ -34,7 +34,7 @@ func (mt messageType) String() string {
 var errNoLeader = errors.New("no leader")
 
 const (
-	stateFollower = iota
+	stateFollower stateType = iota
 	stateCandidate
 	stateLeader
 )
@@ -244,6 +244,7 @@ func (sm *stateMachine) step(m Message) {
 		sm.term++
 		sm.reset()
 		sm.state = stateCandidate
+		sm.vote = sm.addr
 		sm.poll(sm.addr, true)
 		for i := 0; i < sm.k; i++ {
 			if i == sm.addr {

+ 62 - 18
raft_test.go

@@ -10,7 +10,7 @@ var defaultLog = []Entry{{}}
 
 func TestLeaderElection(t *testing.T) {
 	tests := []struct {
-		network
+		*network
 		state stateType
 	}{
 		{newNetwork(nil, nil, nil), stateLeader},
@@ -22,7 +22,7 @@ func TestLeaderElection(t *testing.T) {
 
 	for i, tt := range tests {
 		tt.step(Message{To: 0, Type: msgHup})
-		sm := tt.network[0].(*stateMachine)
+		sm := tt.network.ss[0].(*stateMachine)
 		if sm.state != tt.state {
 			t.Errorf("#%d: state = %s, want %s", i, sm.state, tt.state)
 		}
@@ -32,12 +32,44 @@ func TestLeaderElection(t *testing.T) {
 	}
 }
 
+func TestDualingCandidates(t *testing.T) {
+	a := &stateMachine{
+		log:  []Entry{{}},
+		next: nopStepper, // field next is nil (partitioned)
+	}
+	c := &stateMachine{
+		log:  []Entry{{}},
+		next: nopStepper, // field next is nil (partitioned)
+	}
+	tt := newNetwork(a, nil, c)
+	tt.tee = stepperFunc(func(m Message) {
+		t.Logf("m = %+v", m)
+	})
+	tt.step(Message{To: 0, Type: msgHup})
+	tt.step(Message{To: 2, Type: msgHup})
+
+	t.Log("healing")
+	tt.heal()
+	tt.step(Message{To: 2, Type: msgHup})
+	if c.state != stateLeader {
+		t.Errorf("state = %s, want %s", c.state, stateLeader)
+	}
+	if g := c.term; g != 2 {
+		t.Errorf("term = %d, want %d", g, 2)
+	}
+	if g := diffLogs(tt.logs(defaultLog)); g != nil {
+		for _, diff := range g {
+			t.Errorf("bag log:\n%s", diff)
+		}
+	}
+}
+
 func TestProposal(t *testing.T) {
 	data := []byte("somedata")
 	successLog := []Entry{{}, {Term: 1, Data: data}}
 
 	tests := []struct {
-		network
+		*network
 		log       []Entry
 		willpanic bool
 	}{
@@ -73,7 +105,7 @@ func TestProposal(t *testing.T) {
 				t.Errorf("#%d: bag log:\n%s", i, diff)
 			}
 		}
-		sm := tt.network[0].(*stateMachine)
+		sm := tt.network.ss[0].(*stateMachine)
 		if g := sm.term; g != 1 {
 			t.Errorf("#%d: term = %d, want %d", i, g, 1)
 		}
@@ -85,7 +117,7 @@ func TestProposalByProxy(t *testing.T) {
 	successLog := []Entry{{}, {Term: 1, Data: data}}
 
 	tests := []struct {
-		network
+		*network
 		log []Entry
 	}{
 		{newNetwork(nil, nil, nil), successLog},
@@ -93,59 +125,71 @@ func TestProposalByProxy(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		step := stepperFunc(func(m Message) {
+		tt.tee = stepperFunc(func(m Message) {
 			t.Logf("#%d: m = %+v", i, m)
-			tt.step(m)
 		})
 
 		// promote 0 the leader
-		step(Message{To: 0, Type: msgHup})
+		tt.step(Message{To: 0, Type: msgHup})
 
 		// propose via follower
-		step(Message{To: 1, Type: msgProp, Data: []byte("somedata")})
+		tt.step(Message{To: 1, Type: msgProp, Data: []byte("somedata")})
 
 		if g := diffLogs(tt.logs(tt.log)); g != nil {
 			for _, diff := range g {
 				t.Errorf("#%d: bag log:\n%s", i, diff)
 			}
 		}
-		sm := tt.network[0].(*stateMachine)
+		sm := tt.network.ss[0].(*stateMachine)
 		if g := sm.term; g != 1 {
 			t.Errorf("#%d: term = %d, want %d", i, g, 1)
 		}
 	}
 }
 
-type network []stepper
+type network struct {
+	tee stepper
+	ss  []stepper
+}
 
 // newNetwork initializes a network from nodes. A nil node will be replaced
 // with a new *stateMachine. A *stateMachine will get its k, addr, and next
 // fields set.
-func newNetwork(nodes ...stepper) network {
-	nt := network(nodes)
+func newNetwork(nodes ...stepper) *network {
+	nt := &network{ss: nodes}
 	for i, n := range nodes {
 		switch v := n.(type) {
 		case nil:
-			nt[i] = newStateMachine(len(nodes), i, &nt)
+			nt.ss[i] = newStateMachine(len(nodes), i, nt)
 		case *stateMachine:
 			v.k = len(nodes)
 			v.addr = i
-			v.next = &nt
 		}
 	}
 	return nt
 }
 
 func (nt network) step(m Message) {
-	nt[m.To].step(m)
+	if nt.tee != nil {
+		nt.tee.step(m)
+	}
+	nt.ss[m.To].step(m)
+}
+
+func (nt network) heal() {
+	for _, s := range nt.ss {
+		if sm, ok := s.(*stateMachine); ok {
+			sm.next = nt
+		}
+	}
 }
 
 // logs returns all logs in nt prepended with want. If a node is not a
 // *stateMachine, its log will be nil.
 func (nt network) logs(want []Entry) [][]Entry {
-	ls := make([][]Entry, len(nt)+1)
+	ls := make([][]Entry, len(nt.ss)+1)
 	ls[0] = want
-	for i, node := range nt {
+	for i, node := range nt.ss {
 		if sm, ok := node.(*stateMachine); ok {
 			ls[i] = sm.log
 		}