Browse Source

raft: establish an interface around vote counting

This cleans up the mechanical refactor in the last commit and will
help with etcd-io/etcd#7625 as well.
Tobias Schottdorf 6 years ago
parent
commit
a6f222e62d
2 changed files with 60 additions and 27 deletions
  1. 40 8
      raft/progress.go
  2. 20 19
      raft/raft.go

+ 40 - 8
raft/progress.go

@@ -301,11 +301,13 @@ type prs struct {
 }
 
 func makePRS(maxInflight int) prs {
-	return prs{
+	p := prs{
+		maxInflight: maxInflight,
 		nodes:       map[uint64]*Progress{},
 		learners:    map[uint64]*Progress{},
-		maxInflight: maxInflight,
+		votes:       map[uint64]bool{},
 	}
+	return p
 }
 
 func (p *prs) quorum() int {
@@ -379,12 +381,6 @@ func (p *prs) visit(f func(id uint64, pr *Progress)) {
 	}
 }
 
-func (p *prs) reset() {
-	p.nodes = map[uint64]*Progress{}
-	p.learners = map[uint64]*Progress{}
-	p.matchBuf = nil
-}
-
 func (p *prs) voterNodes() []uint64 {
 	nodes := make([]uint64, 0, len(p.nodes))
 	for id := range p.nodes {
@@ -402,3 +398,39 @@ func (p *prs) learnerNodes() []uint64 {
 	sort.Sort(uint64Slice(nodes))
 	return nodes
 }
+
+// resetVotes prepares for a new round of vote counting via recordVote.
+func (p *prs) resetVotes() {
+	p.votes = map[uint64]bool{}
+}
+
+// recordVote records that the node with the given id voted for this Raft
+// instance if v == true (and declined it otherwise). Repeat votes are ignored.
+func (p *prs) recordVote(id uint64, v bool) {
+	_, ok := p.votes[id]
+	if !ok {
+		p.votes[id] = v
+	}
+}
+
+// tallyVotes returns the number of granted and rejected votes, and the election
+// outcome (won or lost once either count reaches quorum, else indeterminate).
+func (p *prs) tallyVotes() (granted int, rejected int, result electionResult) {
+	for _, v := range p.votes {
+		if v {
+			granted++
+		} else {
+			rejected++
+		}
+	}
+
+	q := p.quorum()
+
+	result = electionIndeterminate
+	if granted >= q {
+		result = electionWon
+	} else if rejected >= q {
+		result = electionLost
+	}
+	return granted, rejected, result
+}

+ 20 - 19
raft/raft.go

@@ -573,7 +573,7 @@ func (r *raft) reset(term uint64) {
 
 	r.abortLeaderTransfer()
 
-	r.prs.votes = make(map[uint64]bool)
+	r.prs.resetVotes()
 	r.prs.visit(func(id uint64, pr *Progress) {
 		*pr = Progress{
 			Match:     0,
@@ -681,7 +681,7 @@ func (r *raft) becomePreCandidate() {
 	// but doesn't change anything else. In particular it does not increase
 	// r.Term or change r.Vote.
 	r.step = stepCandidate
-	r.prs.votes = make(map[uint64]bool)
+	r.prs.resetVotes()
 	r.tick = r.tickElection
 	r.lead = None
 	r.state = StatePreCandidate
@@ -737,7 +737,7 @@ func (r *raft) campaign(t CampaignType) {
 		voteMsg = pb.MsgVote
 		term = r.Term
 	}
-	if r.prs.quorum() == r.poll(r.id, voteRespMsgType(voteMsg), true) {
+	if _, _, res := r.poll(r.id, voteRespMsgType(voteMsg), true); res == electionWon {
 		// We won the election after voting for ourselves (which must mean that
 		// this is a single-node cluster). Advance to the next state.
 		if t == campaignPreElection {
@@ -762,21 +762,22 @@ func (r *raft) campaign(t CampaignType) {
 	}
 }
 
-func (r *raft) poll(id uint64, t pb.MessageType, v bool) (granted int) {
+type electionResult byte
+
+const (
+	electionIndeterminate electionResult = iota
+	electionLost
+	electionWon
+)
+
+func (r *raft) poll(id uint64, t pb.MessageType, v bool) (granted int, rejected int, result electionResult) {
 	if v {
 		r.logger.Infof("%x received %s from %x at term %d", r.id, t, id, r.Term)
 	} else {
 		r.logger.Infof("%x received %s rejection from %x at term %d", r.id, t, id, r.Term)
 	}
-	if _, ok := r.prs.votes[id]; !ok {
-		r.prs.votes[id] = v
-	}
-	for _, vv := range r.prs.votes {
-		if vv {
-			granted++
-		}
-	}
-	return granted
+	r.prs.recordVote(id, v)
+	return r.prs.tallyVotes()
 }
 
 func (r *raft) Step(m pb.Message) error {
@@ -1178,17 +1179,17 @@ func stepCandidate(r *raft, m pb.Message) error {
 		r.becomeFollower(m.Term, m.From) // always m.Term == r.Term
 		r.handleSnapshot(m)
 	case myVoteRespType:
-		gr := r.poll(m.From, m.Type, !m.Reject)
-		r.logger.Infof("%x [quorum:%d] has received %d %s votes and %d vote rejections", r.id, r.prs.quorum(), gr, m.Type, len(r.prs.votes)-gr)
-		switch r.prs.quorum() {
-		case gr:
+		gr, rj, res := r.poll(m.From, m.Type, !m.Reject)
+		r.logger.Infof("%x has received %d %s votes and %d vote rejections", r.id, gr, m.Type, rj)
+		switch res {
+		case electionWon:
 			if r.state == StatePreCandidate {
 				r.campaign(campaignElection)
 			} else {
 				r.becomeLeader()
 				r.bcastAppend()
 			}
-		case len(r.prs.votes) - gr:
+		case electionLost:
 			// pb.MsgPreVoteResp contains future term of pre-candidate
 			// m.Term > r.Term; reuse r.Term
 			r.becomeFollower(r.Term, None)
@@ -1317,7 +1318,7 @@ func (r *raft) restore(s pb.Snapshot) bool {
 		r.id, r.raftLog.committed, r.raftLog.lastIndex(), r.raftLog.lastTerm(), s.Metadata.Index, s.Metadata.Term)
 
 	r.raftLog.restore(s)
-	r.prs.reset()
+	r.prs = makePRS(r.prs.maxInflight)
 	r.restoreNode(s.Metadata.ConfState.Nodes, false)
 	r.restoreNode(s.Metadata.ConfState.Learners, true)
 	return true