Browse Source

Merge pull request #79 from etcd-team/use_msg

raft: make tick send out messages.
Xiang Li 11 years ago
parent
commit
70580de197
3 changed files with 35 additions and 30 deletions
  1. 4 8
      etcdserver/server_test.go
  2. 12 18
      raft/raft.go
  3. 19 4
      raft/raft_test.go

+ 4 - 8
etcdserver/server_test.go

@@ -36,7 +36,7 @@ func testServer(t *testing.T, ns int64) {
 	}
 
 	for i := int64(0); i < ns; i++ {
-		n := raft.Start(i, peers, 1, 10)
+		n := raft.Start(i, peers, 10, 1)
 		tk := time.NewTicker(10 * time.Millisecond)
 		defer tk.Stop()
 		srv := &Server{
@@ -47,16 +47,12 @@ func testServer(t *testing.T, ns int64) {
 			Ticker: tk.C,
 		}
 		Start(srv)
-
+		// TODO(xiangli): randomize election timeout
+		// then remove this sleep.
+		time.Sleep(1 * time.Millisecond)
 		ss[i] = srv
 	}
 
-	for i := int64(0); i < ns; i++ {
-		if err := ss[i].Node.Campaign(ctx); err != nil {
-			t.Fatal(err)
-		}
-	}
-
 	for i := 1; i <= 10; i++ {
 		r := pb.Request{
 			Method: "PUT",

+ 12 - 18
raft/raft.go

@@ -56,12 +56,6 @@ var stmap = [...]string{
 	stateLeader:    "stateLeader",
 }
 
-var stepmap = [...]stepFunc{
-	stateFollower:  stepFollower,
-	stateCandidate: stepCandidate,
-	stateLeader:    stepLeader,
-}
-
 func (st stateType) String() string {
 	return stmap[int64(st)]
 }
@@ -126,6 +120,7 @@ type raft struct {
 	heartbeatTimeout int
 	electionTimeout  int
 	tick             func()
+	step             stepFunc
 }
 
 func newRaft(id int64, peers []int64, election, heartbeat int) *raft {
@@ -249,6 +244,7 @@ func (r *raft) reset(term int64) {
 	r.Term = term
 	r.lead = none
 	r.Vote = none
+	r.elapsed = 0
 	r.votes = make(map[int64]bool)
 	for i := range r.prs {
 		r.prs[i] = &progress{next: r.raftLog.lastIndex() + 1}
@@ -272,9 +268,10 @@ func (r *raft) appendEntry(e pb.Entry) {
 
 func (r *raft) tickElection() {
 	r.elapsed++
+	// TODO (xiangli): elctionTimeout should be randomized.
 	if r.elapsed > r.electionTimeout {
 		r.elapsed = 0
-		r.campaign()
+		r.Step(pb.Message{From: r.id, Type: msgHup})
 	}
 }
 
@@ -282,41 +279,39 @@ func (r *raft) tickHeartbeat() {
 	r.elapsed++
 	if r.elapsed > r.heartbeatTimeout {
 		r.elapsed = 0
-		r.bcastHeartbeat()
+		r.Step(pb.Message{From: r.id, Type: msgBeat})
 	}
 }
 
-func (r *raft) setTick(f func()) {
-	r.elapsed = 0
-	r.tick = f
-}
-
 func (r *raft) becomeFollower(term int64, lead int64) {
-	r.setTick(r.tickElection)
+	r.step = stepFollower
 	r.reset(term)
+	r.tick = r.tickElection
 	r.lead = lead
 	r.state = stateFollower
 	r.configuring = false
 }
 
 func (r *raft) becomeCandidate() {
-	r.setTick(r.tickElection)
 	// TODO(xiangli) remove the panic when the raft implementation is stable
 	if r.state == stateLeader {
 		panic("invalid transition [leader -> candidate]")
 	}
+	r.step = stepCandidate
 	r.reset(r.Term + 1)
+	r.tick = r.tickElection
 	r.Vote = r.id
 	r.state = stateCandidate
 }
 
 func (r *raft) becomeLeader() {
-	r.setTick(r.tickHeartbeat)
 	// TODO(xiangli) remove the panic when the raft implementation is stable
 	if r.state == stateFollower {
 		panic("invalid transition [follower -> leader]")
 	}
+	r.step = stepLeader
 	r.reset(r.Term)
+	r.tick = r.tickElection
 	r.lead = r.id
 	r.state = stateLeader
 
@@ -370,8 +365,7 @@ func (r *raft) Step(m pb.Message) error {
 	case m.Term < r.Term:
 		// ignore
 	}
-
-	stepmap[r.state](r, m)
+	r.step(r, m)
 	return nil
 }
 

+ 19 - 4
raft/raft_test.go

@@ -549,11 +549,18 @@ func TestRecvMsgVote(t *testing.T) {
 	}
 
 	for i, tt := range tests {
-		sm := &raft{
-			state:   tt.state,
-			State:   pb.State{Vote: tt.voteFor},
-			raftLog: &raftLog{ents: []pb.Entry{{}, {Term: 2}, {Term: 2}}},
+		sm := newRaft(0, []int64{0}, 0, 0)
+		sm.state = tt.state
+		switch tt.state {
+		case stateFollower:
+			sm.step = stepFollower
+		case stateCandidate:
+			sm.step = stepCandidate
+		case stateLeader:
+			sm.step = stepLeader
 		}
+		sm.State = pb.State{Vote: tt.voteFor}
+		sm.raftLog = &raftLog{ents: []pb.Entry{{}, {Term: 2}, {Term: 2}}}
 
 		sm.Step(pb.Message{Type: msgVote, From: 1, Index: tt.i, LogTerm: tt.term})
 
@@ -778,6 +785,14 @@ func TestRecvMsgBeat(t *testing.T) {
 		sm.raftLog = &raftLog{ents: []pb.Entry{{}, {Term: 0}, {Term: 1}}}
 		sm.Term = 1
 		sm.state = tt.state
+		switch tt.state {
+		case stateFollower:
+			sm.step = stepFollower
+		case stateCandidate:
+			sm.step = stepCandidate
+		case stateLeader:
+			sm.step = stepLeader
+		}
 		sm.Step(pb.Message{From: 0, To: 0, Type: msgBeat})
 
 		msgs := sm.ReadMessages()