Browse Source

raft: remove heal from network

Blake Mizerany 11 years ago
parent
commit
fcc7a42d6c
1 changed files with 19 additions and 17 deletions
  1. 19 17
      raft_test.go

+ 19 - 17
raft_test.go

@@ -18,6 +18,7 @@ func TestLeaderElection(t *testing.T) {
 		{newNetwork(nil, nopStepper, nopStepper), stateCandidate},
 		{newNetwork(nil, nopStepper, nopStepper), stateCandidate},
 		{newNetwork(nil, nopStepper, nopStepper, nil), stateCandidate},
 		{newNetwork(nil, nopStepper, nopStepper, nil), stateCandidate},
 		{newNetwork(nil, nopStepper, nopStepper, nil, nil), stateLeader},
 		{newNetwork(nil, nopStepper, nopStepper, nil, nil), stateLeader},
+		/// {newNetwork(nil, newPartNode(), falseVote()), stateFollower},
 	}
 	}
 
 
 	for i, tt := range tests {
 	for i, tt := range tests {
@@ -33,15 +34,20 @@ func TestLeaderElection(t *testing.T) {
 }
 }
 
 
 func TestDualingCandidates(t *testing.T) {
 func TestDualingCandidates(t *testing.T) {
-	a := &stateMachine{
-		log:  []Entry{{}},
-		next: nopStepper, // field next is nil (partitioned)
-	}
-	c := &stateMachine{
-		log:  []Entry{{}},
-		next: nopStepper, // field next is nil (partitioned)
-	}
+	a := &stateMachine{log: defaultLog}
+	c := &stateMachine{log: defaultLog}
+
 	tt := newNetwork(a, nil, c)
 	tt := newNetwork(a, nil, c)
+
+	heal := false
+	next := stepperFunc(func(m Message) {
+		if heal {
+			tt.step(m)
+		}
+	})
+	a.next = next
+	c.next = next
+
 	tt.tee = stepperFunc(func(m Message) {
 	tt.tee = stepperFunc(func(m Message) {
 		t.Logf("m = %+v", m)
 		t.Logf("m = %+v", m)
 	})
 	})
@@ -49,7 +55,7 @@ func TestDualingCandidates(t *testing.T) {
 	tt.step(Message{To: 2, Type: msgHup})
 	tt.step(Message{To: 2, Type: msgHup})
 
 
 	t.Log("healing")
 	t.Log("healing")
-	tt.heal()
+	heal = true
 	tt.step(Message{To: 2, Type: msgHup})
 	tt.step(Message{To: 2, Type: msgHup})
 
 
 	tests := []struct {
 	tests := []struct {
@@ -255,6 +261,8 @@ func newNetwork(nodes ...stepper) *network {
 		case *stateMachine:
 		case *stateMachine:
 			v.k = len(nodes)
 			v.k = len(nodes)
 			v.addr = i
 			v.addr = i
+		default:
+			nt.ss[i] = v
 		}
 		}
 	}
 	}
 	return nt
 	return nt
@@ -267,14 +275,6 @@ func (nt network) step(m Message) {
 	nt.ss[m.To].step(m)
 	nt.ss[m.To].step(m)
 }
 }
 
 
-func (nt network) heal() {
-	for _, s := range nt.ss {
-		if sm, ok := s.(*stateMachine); ok {
-			sm.next = nt
-		}
-	}
-}
-
 // logs returns all logs in nt prepended with want. If a node is not a
 // logs returns all logs in nt prepended with want. If a node is not a
 // *stateMachine, its log will be nil.
 // *stateMachine, its log will be nil.
 func (nt network) logs() [][]Entry {
 func (nt network) logs() [][]Entry {
@@ -367,3 +367,5 @@ type stepperFunc func(Message)
 func (f stepperFunc) step(m Message) { f(m) }
 func (f stepperFunc) step(m Message) { f(m) }
 
 
 var nopStepper = stepperFunc(func(Message) {})
 var nopStepper = stepperFunc(func(Message) {})
+
+type nextStepperFunc func(Message, stepper)