Browse Source

raft: move more methods onto the progress tracker

Continues what was initiated in the last commit.
Tobias Schottdorf 6 years ago
parent
commit
ea82b2b758
7 changed files with 131 additions and 123 deletions
  1. 5 5
      raft/node.go
  2. 63 6
      raft/progress.go
  3. 29 81
      raft/raft.go
  4. 4 4
      raft/raft_flow_control_test.go
  5. 16 16
      raft/raft_test.go
  6. 3 3
      raft/rawnode.go
  7. 11 8
      raft/status.go

+ 5 - 5
raft/node.go

@@ -353,15 +353,15 @@ func (n *node) run(r *raft) {
 			}
 			}
 		case m := <-n.recvc:
 		case m := <-n.recvc:
 			// filter out response message from unknown From.
 			// filter out response message from unknown From.
-			if pr := r.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) {
+			if pr := r.prs.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) {
 				r.Step(m)
 				r.Step(m)
 			}
 			}
 		case cc := <-n.confc:
 		case cc := <-n.confc:
 			if cc.NodeID == None {
 			if cc.NodeID == None {
 				select {
 				select {
 				case n.confstatec <- pb.ConfState{
 				case n.confstatec <- pb.ConfState{
-					Nodes:    r.nodes(),
-					Learners: r.learnerNodes()}:
+					Nodes:    r.prs.voterNodes(),
+					Learners: r.prs.learnerNodes()}:
 				case <-n.done:
 				case <-n.done:
 				}
 				}
 				break
 				break
@@ -384,8 +384,8 @@ func (n *node) run(r *raft) {
 			}
 			}
 			select {
 			select {
 			case n.confstatec <- pb.ConfState{
 			case n.confstatec <- pb.ConfState{
-				Nodes:    r.nodes(),
-				Learners: r.learnerNodes()}:
+				Nodes:    r.prs.voterNodes(),
+				Learners: r.prs.learnerNodes()}:
 			case <-n.done:
 			case <-n.done:
 			}
 			}
 		case <-n.tickc:
 		case <-n.tickc:

+ 63 - 6
raft/progress.go

@@ -291,15 +291,17 @@ func (in *inflights) reset() {
 // the nodes and learners in it. In particular, it tracks the match index for
 // the nodes and learners in it. In particular, it tracks the match index for
 // each peer which in turn allows reasoning about the committed index.
 // each peer which in turn allows reasoning about the committed index.
 type prs struct {
 type prs struct {
-	nodes    map[uint64]*Progress
-	learners map[uint64]*Progress
-	matchBuf uint64Slice
+	nodes       map[uint64]*Progress
+	learners    map[uint64]*Progress
+	maxInflight int
+	matchBuf    uint64Slice
 }
 }
 
 
-func makePRS() prs {
+func makePRS(maxInflight int) prs {
 	return prs{
 	return prs{
-		nodes:    map[uint64]*Progress{},
-		learners: map[uint64]*Progress{},
+		nodes:       map[uint64]*Progress{},
+		learners:    map[uint64]*Progress{},
+		maxInflight: maxInflight,
 	}
 	}
 }
 }
 
 
@@ -307,6 +309,8 @@ func (p *prs) quorum() int {
 	return len(p.nodes)/2 + 1
 	return len(p.nodes)/2 + 1
 }
 }
 
 
+// committed returns the largest log index known to be committed based on what
+// the voting members of the group have acknowledged.
 func (p *prs) committed() uint64 {
 func (p *prs) committed() uint64 {
 	// Preserving matchBuf across calls is an optimization
 	// Preserving matchBuf across calls is an optimization
 	// used to avoid allocating a new slice on each call.
 	// used to avoid allocating a new slice on each call.
@@ -327,3 +331,56 @@ func (p *prs) removeAny(id uint64) {
 	delete(p.nodes, id)
 	delete(p.nodes, id)
 	delete(p.learners, id)
 	delete(p.learners, id)
 }
 }
+
+func (p *prs) getProgress(id uint64) *Progress {
+	if pr, ok := p.nodes[id]; ok {
+		return pr
+	}
+
+	return p.learners[id]
+}
+
+// initProgress initializes a new progress for the given node, replacing any that
+// may exist. It is invalid to replace a voter by a learner and attempts to do so
+// will result in a panic.
+func (p *prs) initProgress(id, match, next uint64, isLearner bool) {
+	if !isLearner {
+		delete(p.learners, id)
+		p.nodes[id] = &Progress{Next: next, Match: match, ins: newInflights(p.maxInflight)}
+		return
+	}
+
+	if _, ok := p.nodes[id]; ok {
+		panic(fmt.Sprintf("changing from voter to learner for %x", id))
+	}
+	p.learners[id] = &Progress{Next: next, Match: match, ins: newInflights(p.maxInflight), IsLearner: true}
+}
+
+func (p *prs) voterNodes() []uint64 {
+	nodes := make([]uint64, 0, len(p.nodes))
+	for id := range p.nodes {
+		nodes = append(nodes, id)
+	}
+	sort.Sort(uint64Slice(nodes))
+	return nodes
+}
+
+func (p *prs) learnerNodes() []uint64 {
+	nodes := make([]uint64, 0, len(p.learners))
+	for id := range p.learners {
+		nodes = append(nodes, id)
+	}
+	sort.Sort(uint64Slice(nodes))
+	return nodes
+}
+
+// visit invokes the supplied closure for all tracked progresses.
+func (p *prs) visit(f func(id uint64, pr *Progress)) {
+	for id, pr := range p.nodes {
+		f(id, pr)
+	}
+
+	for id, pr := range p.learners {
+		f(id, pr)
+	}
+}

+ 29 - 81
raft/raft.go

@@ -20,7 +20,6 @@ import (
 	"fmt"
 	"fmt"
 	"math"
 	"math"
 	"math/rand"
 	"math/rand"
-	"sort"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -261,7 +260,6 @@ type raft struct {
 
 
 	maxMsgSize         uint64
 	maxMsgSize         uint64
 	maxUncommittedSize uint64
 	maxUncommittedSize uint64
-	maxInflight        int
 	prs                prs
 	prs                prs
 
 
 	state StateType
 	state StateType
@@ -346,9 +344,8 @@ func newRaft(c *Config) *raft {
 		isLearner:                 false,
 		isLearner:                 false,
 		raftLog:                   raftlog,
 		raftLog:                   raftlog,
 		maxMsgSize:                c.MaxSizePerMsg,
 		maxMsgSize:                c.MaxSizePerMsg,
-		maxInflight:               c.MaxInflightMsgs,
 		maxUncommittedSize:        c.MaxUncommittedEntriesSize,
 		maxUncommittedSize:        c.MaxUncommittedEntriesSize,
-		prs:                       makePRS(),
+		prs:                       makePRS(c.MaxInflightMsgs),
 		electionTimeout:           c.ElectionTick,
 		electionTimeout:           c.ElectionTick,
 		heartbeatTimeout:          c.HeartbeatTick,
 		heartbeatTimeout:          c.HeartbeatTick,
 		logger:                    c.Logger,
 		logger:                    c.Logger,
@@ -358,13 +355,13 @@ func newRaft(c *Config) *raft {
 		disableProposalForwarding: c.DisableProposalForwarding,
 		disableProposalForwarding: c.DisableProposalForwarding,
 	}
 	}
 	for _, p := range peers {
 	for _, p := range peers {
-		r.prs.nodes[p] = &Progress{Next: 1, ins: newInflights(r.maxInflight)}
+		// Add node to active config.
+		r.prs.initProgress(p, 0 /* match */, 1 /* next */, false /* isLearner */)
 	}
 	}
 	for _, p := range learners {
 	for _, p := range learners {
-		if _, ok := r.prs.nodes[p]; ok {
-			panic(fmt.Sprintf("node %x is in both learner and peer list", p))
-		}
-		r.prs.learners[p] = &Progress{Next: 1, ins: newInflights(r.maxInflight), IsLearner: true}
+		// Add learner to active config.
+		r.prs.initProgress(p, 0 /* match */, 1 /* next */, true /* isLearner */)
+
 		if r.id == p {
 		if r.id == p {
 			r.isLearner = true
 			r.isLearner = true
 		}
 		}
@@ -379,7 +376,7 @@ func newRaft(c *Config) *raft {
 	r.becomeFollower(r.Term, None)
 	r.becomeFollower(r.Term, None)
 
 
 	var nodesStrs []string
 	var nodesStrs []string
-	for _, n := range r.nodes() {
+	for _, n := range r.prs.voterNodes() {
 		nodesStrs = append(nodesStrs, fmt.Sprintf("%x", n))
 		nodesStrs = append(nodesStrs, fmt.Sprintf("%x", n))
 	}
 	}
 
 
@@ -400,24 +397,6 @@ func (r *raft) hardState() pb.HardState {
 	}
 	}
 }
 }
 
 
-func (r *raft) nodes() []uint64 {
-	nodes := make([]uint64, 0, len(r.prs.nodes))
-	for id := range r.prs.nodes {
-		nodes = append(nodes, id)
-	}
-	sort.Sort(uint64Slice(nodes))
-	return nodes
-}
-
-func (r *raft) learnerNodes() []uint64 {
-	nodes := make([]uint64, 0, len(r.prs.learners))
-	for id := range r.prs.learners {
-		nodes = append(nodes, id)
-	}
-	sort.Sort(uint64Slice(nodes))
-	return nodes
-}
-
 // send persists state to stable storage and then sends to its mailbox.
 // send persists state to stable storage and then sends to its mailbox.
 func (r *raft) send(m pb.Message) {
 func (r *raft) send(m pb.Message) {
 	m.From = r.id
 	m.From = r.id
@@ -452,14 +431,6 @@ func (r *raft) send(m pb.Message) {
 	r.msgs = append(r.msgs, m)
 	r.msgs = append(r.msgs, m)
 }
 }
 
 
-func (r *raft) getProgress(id uint64) *Progress {
-	if pr, ok := r.prs.nodes[id]; ok {
-		return pr
-	}
-
-	return r.prs.learners[id]
-}
-
 // sendAppend sends an append RPC with new entries (if any) and the
 // sendAppend sends an append RPC with new entries (if any) and the
 // current commit index to the given peer.
 // current commit index to the given peer.
 func (r *raft) sendAppend(to uint64) {
 func (r *raft) sendAppend(to uint64) {
@@ -472,7 +443,7 @@ func (r *raft) sendAppend(to uint64) {
 // ("empty" messages are useful to convey updated Commit indexes, but
 // ("empty" messages are useful to convey updated Commit indexes, but
 // are undesirable when we're sending multiple messages in a batch).
 // are undesirable when we're sending multiple messages in a batch).
 func (r *raft) maybeSendAppend(to uint64, sendIfEmpty bool) bool {
 func (r *raft) maybeSendAppend(to uint64, sendIfEmpty bool) bool {
-	pr := r.getProgress(to)
+	pr := r.prs.getProgress(to)
 	if pr.IsPaused() {
 	if pr.IsPaused() {
 		return false
 		return false
 	}
 	}
@@ -541,7 +512,7 @@ func (r *raft) sendHeartbeat(to uint64, ctx []byte) {
 	// or it might not have all the committed entries.
 	// or it might not have all the committed entries.
 	// The leader MUST NOT forward the follower's commit to
 	// The leader MUST NOT forward the follower's commit to
 	// an unmatched index.
 	// an unmatched index.
-	commit := min(r.getProgress(to).Match, r.raftLog.committed)
+	commit := min(r.prs.getProgress(to).Match, r.raftLog.committed)
 	m := pb.Message{
 	m := pb.Message{
 		To:      to,
 		To:      to,
 		Type:    pb.MsgHeartbeat,
 		Type:    pb.MsgHeartbeat,
@@ -552,20 +523,10 @@ func (r *raft) sendHeartbeat(to uint64, ctx []byte) {
 	r.send(m)
 	r.send(m)
 }
 }
 
 
-func (r *raft) forEachProgress(f func(id uint64, pr *Progress)) {
-	for id, pr := range r.prs.nodes {
-		f(id, pr)
-	}
-
-	for id, pr := range r.prs.learners {
-		f(id, pr)
-	}
-}
-
 // bcastAppend sends RPC, with entries to all peers that are not up-to-date
 // bcastAppend sends RPC, with entries to all peers that are not up-to-date
 // according to the progress recorded in r.prs.
 // according to the progress recorded in r.prs.
 func (r *raft) bcastAppend() {
 func (r *raft) bcastAppend() {
-	r.forEachProgress(func(id uint64, _ *Progress) {
+	r.prs.visit(func(id uint64, _ *Progress) {
 		if id == r.id {
 		if id == r.id {
 			return
 			return
 		}
 		}
@@ -585,7 +546,7 @@ func (r *raft) bcastHeartbeat() {
 }
 }
 
 
 func (r *raft) bcastHeartbeatWithCtx(ctx []byte) {
 func (r *raft) bcastHeartbeatWithCtx(ctx []byte) {
-	r.forEachProgress(func(id uint64, _ *Progress) {
+	r.prs.visit(func(id uint64, _ *Progress) {
 		if id == r.id {
 		if id == r.id {
 			return
 			return
 		}
 		}
@@ -615,8 +576,8 @@ func (r *raft) reset(term uint64) {
 	r.abortLeaderTransfer()
 	r.abortLeaderTransfer()
 
 
 	r.votes = make(map[uint64]bool)
 	r.votes = make(map[uint64]bool)
-	r.forEachProgress(func(id uint64, pr *Progress) {
-		*pr = Progress{Next: r.raftLog.lastIndex() + 1, ins: newInflights(r.maxInflight), IsLearner: pr.IsLearner}
+	r.prs.visit(func(id uint64, pr *Progress) {
+		*pr = Progress{Next: r.raftLog.lastIndex() + 1, ins: newInflights(r.prs.maxInflight), IsLearner: pr.IsLearner}
 		if id == r.id {
 		if id == r.id {
 			pr.Match = r.raftLog.lastIndex()
 			pr.Match = r.raftLog.lastIndex()
 		}
 		}
@@ -644,7 +605,7 @@ func (r *raft) appendEntry(es ...pb.Entry) (accepted bool) {
 	}
 	}
 	// use latest "last" index after truncate/append
 	// use latest "last" index after truncate/append
 	li = r.raftLog.append(es...)
 	li = r.raftLog.append(es...)
-	r.getProgress(r.id).maybeUpdate(li)
+	r.prs.getProgress(r.id).maybeUpdate(li)
 	// Regardless of maybeCommit's return, our caller will call bcastAppend.
 	// Regardless of maybeCommit's return, our caller will call bcastAppend.
 	r.maybeCommit()
 	r.maybeCommit()
 	return true
 	return true
@@ -738,7 +699,7 @@ func (r *raft) becomeLeader() {
 	// (perhaps after having received a snapshot as a result). The leader is
 	// (perhaps after having received a snapshot as a result). The leader is
 	// trivially in this state. Note that r.reset() has initialized this
 	// trivially in this state. Note that r.reset() has initialized this
 	// progress with the last index already.
 	// progress with the last index already.
-	r.prs.nodes[r.id].becomeReplicate()
+	r.prs.getProgress(r.id).becomeReplicate()
 
 
 	// Conservatively set the pendingConfIndex to the last index in the
 	// Conservatively set the pendingConfIndex to the last index in the
 	// log. There may or may not be a pending config change, but it's
 	// log. There may or may not be a pending config change, but it's
@@ -1040,7 +1001,7 @@ func stepLeader(r *raft, m pb.Message) error {
 	}
 	}
 
 
 	// All other message types require a progress for m.From (pr).
 	// All other message types require a progress for m.From (pr).
-	pr := r.getProgress(m.From)
+	pr := r.prs.getProgress(m.From)
 	if pr == nil {
 	if pr == nil {
 		r.logger.Debugf("%x no progress available for %x", r.id, m.From)
 		r.logger.Debugf("%x no progress available for %x", r.id, m.From)
 		return nil
 		return nil
@@ -1367,16 +1328,16 @@ func (r *raft) restoreNode(nodes []uint64, isLearner bool) {
 			match = next - 1
 			match = next - 1
 			r.isLearner = isLearner
 			r.isLearner = isLearner
 		}
 		}
-		r.setProgress(n, match, next, isLearner)
-		r.logger.Infof("%x restored progress of %x [%s]", r.id, n, r.getProgress(n))
+		r.prs.initProgress(n, match, next, isLearner)
+		r.logger.Infof("%x restored progress of %x [%s]", r.id, n, r.prs.getProgress(n))
 	}
 	}
 }
 }
 
 
 // promotable indicates whether state machine can be promoted to leader,
 // promotable indicates whether state machine can be promoted to leader,
 // which is true when its own id is in progress list.
 // which is true when its own id is in progress list.
 func (r *raft) promotable() bool {
 func (r *raft) promotable() bool {
-	_, ok := r.prs.nodes[r.id]
-	return ok
+	pr := r.prs.getProgress(r.id)
+	return pr != nil && !pr.IsLearner
 }
 }
 
 
 func (r *raft) addNode(id uint64) {
 func (r *raft) addNode(id uint64) {
@@ -1388,12 +1349,12 @@ func (r *raft) addLearner(id uint64) {
 }
 }
 
 
 func (r *raft) addNodeOrLearnerNode(id uint64, isLearner bool) {
 func (r *raft) addNodeOrLearnerNode(id uint64, isLearner bool) {
-	pr := r.getProgress(id)
+	pr := r.prs.getProgress(id)
 	if pr == nil {
 	if pr == nil {
-		r.setProgress(id, 0, r.raftLog.lastIndex()+1, isLearner)
+		r.prs.initProgress(id, 0, r.raftLog.lastIndex()+1, isLearner)
 	} else {
 	} else {
 		if isLearner && !pr.IsLearner {
 		if isLearner && !pr.IsLearner {
-			// can only change Learner to Voter
+			// Can only change Learner to Voter.
 			r.logger.Infof("%x ignored addLearner: do not support changing %x from raft peer to learner.", r.id, id)
 			r.logger.Infof("%x ignored addLearner: do not support changing %x from raft peer to learner.", r.id, id)
 			return
 			return
 		}
 		}
@@ -1404,10 +1365,11 @@ func (r *raft) addNodeOrLearnerNode(id uint64, isLearner bool) {
 			return
 			return
 		}
 		}
 
 
-		// change Learner to Voter, use origin Learner progress
-		delete(r.prs.learners, id)
+		// Change Learner to Voter, use the original Learner progress.
+		r.prs.removeAny(id)
+		r.prs.initProgress(id, 0 /* match */, 1 /* next */, false /* isLearner */)
 		pr.IsLearner = false
 		pr.IsLearner = false
-		r.prs.nodes[id] = pr
+		*r.prs.getProgress(id) = *pr
 	}
 	}
 
 
 	if r.id == id {
 	if r.id == id {
@@ -1417,8 +1379,7 @@ func (r *raft) addNodeOrLearnerNode(id uint64, isLearner bool) {
 	// When a node is first added, we should mark it as recently active.
 	// When a node is first added, we should mark it as recently active.
 	// Otherwise, CheckQuorum may cause us to step down if it is invoked
 	// Otherwise, CheckQuorum may cause us to step down if it is invoked
 	// before the added node has a chance to communicate with us.
 	// before the added node has a chance to communicate with us.
-	pr = r.getProgress(id)
-	pr.RecentActive = true
+	r.prs.getProgress(id).RecentActive = true
 }
 }
 
 
 func (r *raft) removeNode(id uint64) {
 func (r *raft) removeNode(id uint64) {
@@ -1440,19 +1401,6 @@ func (r *raft) removeNode(id uint64) {
 	}
 	}
 }
 }
 
 
-func (r *raft) setProgress(id, match, next uint64, isLearner bool) {
-	if !isLearner {
-		delete(r.prs.learners, id)
-		r.prs.nodes[id] = &Progress{Next: next, Match: match, ins: newInflights(r.maxInflight)}
-		return
-	}
-
-	if _, ok := r.prs.nodes[id]; ok {
-		panic(fmt.Sprintf("%x unexpected changing from voter to learner for %x", r.id, id))
-	}
-	r.prs.learners[id] = &Progress{Next: next, Match: match, ins: newInflights(r.maxInflight), IsLearner: true}
-}
-
 func (r *raft) loadState(state pb.HardState) {
 func (r *raft) loadState(state pb.HardState) {
 	if state.Commit < r.raftLog.committed || state.Commit > r.raftLog.lastIndex() {
 	if state.Commit < r.raftLog.committed || state.Commit > r.raftLog.lastIndex() {
 		r.logger.Panicf("%x state.commit %d is out of range [%d, %d]", r.id, state.Commit, r.raftLog.committed, r.raftLog.lastIndex())
 		r.logger.Panicf("%x state.commit %d is out of range [%d, %d]", r.id, state.Commit, r.raftLog.committed, r.raftLog.lastIndex())
@@ -1480,7 +1428,7 @@ func (r *raft) resetRandomizedElectionTimeout() {
 func (r *raft) checkQuorumActive() bool {
 func (r *raft) checkQuorumActive() bool {
 	var act int
 	var act int
 
 
-	r.forEachProgress(func(id uint64, pr *Progress) {
+	r.prs.visit(func(id uint64, pr *Progress) {
 		if id == r.id { // self is always active
 		if id == r.id { // self is always active
 			act++
 			act++
 			return
 			return

+ 4 - 4
raft/raft_flow_control_test.go

@@ -33,7 +33,7 @@ func TestMsgAppFlowControlFull(t *testing.T) {
 	// force the progress to be in replicate state
 	// force the progress to be in replicate state
 	pr2.becomeReplicate()
 	pr2.becomeReplicate()
 	// fill in the inflights window
 	// fill in the inflights window
-	for i := 0; i < r.maxInflight; i++ {
+	for i := 0; i < r.prs.maxInflight; i++ {
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
 		ms := r.readMessages()
 		ms := r.readMessages()
 		if len(ms) != 1 {
 		if len(ms) != 1 {
@@ -69,14 +69,14 @@ func TestMsgAppFlowControlMoveForward(t *testing.T) {
 	// force the progress to be in replicate state
 	// force the progress to be in replicate state
 	pr2.becomeReplicate()
 	pr2.becomeReplicate()
 	// fill in the inflights window
 	// fill in the inflights window
-	for i := 0; i < r.maxInflight; i++ {
+	for i := 0; i < r.prs.maxInflight; i++ {
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
 		r.readMessages()
 		r.readMessages()
 	}
 	}
 
 
 	// 1 is noop, 2 is the first proposal we just sent.
 	// 1 is noop, 2 is the first proposal we just sent.
 	// so we start with 2.
 	// so we start with 2.
-	for tt := 2; tt < r.maxInflight; tt++ {
+	for tt := 2; tt < r.prs.maxInflight; tt++ {
 		// move forward the window
 		// move forward the window
 		r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: uint64(tt)})
 		r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: uint64(tt)})
 		r.readMessages()
 		r.readMessages()
@@ -114,7 +114,7 @@ func TestMsgAppFlowControlRecvHeartbeat(t *testing.T) {
 	// force the progress to be in replicate state
 	// force the progress to be in replicate state
 	pr2.becomeReplicate()
 	pr2.becomeReplicate()
 	// fill in the inflights window
 	// fill in the inflights window
-	for i := 0; i < r.maxInflight; i++ {
+	for i := 0; i < r.prs.maxInflight; i++ {
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
 		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
 		r.readMessages()
 		r.readMessages()
 	}
 	}

+ 16 - 16
raft/raft_test.go

@@ -889,7 +889,7 @@ func TestLearnerLogReplication(t *testing.T) {
 		t.Errorf("peer 2 wants committed to %d, but still %d", n1.raftLog.committed, n2.raftLog.committed)
 		t.Errorf("peer 2 wants committed to %d, but still %d", n1.raftLog.committed, n2.raftLog.committed)
 	}
 	}
 
 
-	match := n1.getProgress(2).Match
+	match := n1.prs.getProgress(2).Match
 	if match != n2.raftLog.committed {
 	if match != n2.raftLog.committed {
 		t.Errorf("progress 2 of leader 1 wants match %d, but got %d", n2.raftLog.committed, match)
 		t.Errorf("progress 2 of leader 1 wants match %d, but got %d", n2.raftLog.committed, match)
 	}
 	}
@@ -1352,7 +1352,7 @@ func TestCommit(t *testing.T) {
 
 
 		sm := newTestRaft(1, []uint64{1}, 10, 2, storage)
 		sm := newTestRaft(1, []uint64{1}, 10, 2, storage)
 		for j := 0; j < len(tt.matches); j++ {
 		for j := 0; j < len(tt.matches); j++ {
-			sm.setProgress(uint64(j)+1, tt.matches[j], tt.matches[j]+1, false)
+			sm.prs.initProgress(uint64(j)+1, tt.matches[j], tt.matches[j]+1, false)
 		}
 		}
 		sm.maybeCommit()
 		sm.maybeCommit()
 		if g := sm.raftLog.committed; g != tt.w {
 		if g := sm.raftLog.committed; g != tt.w {
@@ -2931,7 +2931,7 @@ func TestRestore(t *testing.T) {
 	if mustTerm(sm.raftLog.term(s.Metadata.Index)) != s.Metadata.Term {
 	if mustTerm(sm.raftLog.term(s.Metadata.Index)) != s.Metadata.Term {
 		t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term)
 		t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term)
 	}
 	}
-	sg := sm.nodes()
+	sg := sm.prs.voterNodes()
 	if !reflect.DeepEqual(sg, s.Metadata.ConfState.Nodes) {
 	if !reflect.DeepEqual(sg, s.Metadata.ConfState.Nodes) {
 		t.Errorf("sm.Nodes = %+v, want %+v", sg, s.Metadata.ConfState.Nodes)
 		t.Errorf("sm.Nodes = %+v, want %+v", sg, s.Metadata.ConfState.Nodes)
 	}
 	}
@@ -2963,11 +2963,11 @@ func TestRestoreWithLearner(t *testing.T) {
 	if mustTerm(sm.raftLog.term(s.Metadata.Index)) != s.Metadata.Term {
 	if mustTerm(sm.raftLog.term(s.Metadata.Index)) != s.Metadata.Term {
 		t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term)
 		t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term)
 	}
 	}
-	sg := sm.nodes()
+	sg := sm.prs.voterNodes()
 	if len(sg) != len(s.Metadata.ConfState.Nodes) {
 	if len(sg) != len(s.Metadata.ConfState.Nodes) {
 		t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Nodes)
 		t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Nodes)
 	}
 	}
-	lns := sm.learnerNodes()
+	lns := sm.prs.learnerNodes()
 	if len(lns) != len(s.Metadata.ConfState.Learners) {
 	if len(lns) != len(s.Metadata.ConfState.Learners) {
 		t.Errorf("sm.LearnerNodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Learners)
 		t.Errorf("sm.LearnerNodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Learners)
 	}
 	}
@@ -3192,7 +3192,7 @@ func TestSlowNodeRestore(t *testing.T) {
 	}
 	}
 	lead := nt.peers[1].(*raft)
 	lead := nt.peers[1].(*raft)
 	nextEnts(lead, nt.storage[1])
 	nextEnts(lead, nt.storage[1])
-	nt.storage[1].CreateSnapshot(lead.raftLog.applied, &pb.ConfState{Nodes: lead.nodes()}, nil)
+	nt.storage[1].CreateSnapshot(lead.raftLog.applied, &pb.ConfState{Nodes: lead.prs.voterNodes()}, nil)
 	nt.storage[1].Compact(lead.raftLog.applied)
 	nt.storage[1].Compact(lead.raftLog.applied)
 
 
 	nt.recover()
 	nt.recover()
@@ -3287,7 +3287,7 @@ func TestNewLeaderPendingConfig(t *testing.T) {
 func TestAddNode(t *testing.T) {
 func TestAddNode(t *testing.T) {
 	r := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	r := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	r.addNode(2)
 	r.addNode(2)
-	nodes := r.nodes()
+	nodes := r.prs.voterNodes()
 	wnodes := []uint64{1, 2}
 	wnodes := []uint64{1, 2}
 	if !reflect.DeepEqual(nodes, wnodes) {
 	if !reflect.DeepEqual(nodes, wnodes) {
 		t.Errorf("nodes = %v, want %v", nodes, wnodes)
 		t.Errorf("nodes = %v, want %v", nodes, wnodes)
@@ -3298,7 +3298,7 @@ func TestAddNode(t *testing.T) {
 func TestAddLearner(t *testing.T) {
 func TestAddLearner(t *testing.T) {
 	r := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	r := newTestRaft(1, []uint64{1}, 10, 1, NewMemoryStorage())
 	r.addLearner(2)
 	r.addLearner(2)
-	nodes := r.learnerNodes()
+	nodes := r.prs.learnerNodes()
 	wnodes := []uint64{2}
 	wnodes := []uint64{2}
 	if !reflect.DeepEqual(nodes, wnodes) {
 	if !reflect.DeepEqual(nodes, wnodes) {
 		t.Errorf("nodes = %v, want %v", nodes, wnodes)
 		t.Errorf("nodes = %v, want %v", nodes, wnodes)
@@ -3348,14 +3348,14 @@ func TestRemoveNode(t *testing.T) {
 	r := newTestRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 	r := newTestRaft(1, []uint64{1, 2}, 10, 1, NewMemoryStorage())
 	r.removeNode(2)
 	r.removeNode(2)
 	w := []uint64{1}
 	w := []uint64{1}
-	if g := r.nodes(); !reflect.DeepEqual(g, w) {
+	if g := r.prs.voterNodes(); !reflect.DeepEqual(g, w) {
 		t.Errorf("nodes = %v, want %v", g, w)
 		t.Errorf("nodes = %v, want %v", g, w)
 	}
 	}
 
 
 	// remove all nodes from cluster
 	// remove all nodes from cluster
 	r.removeNode(1)
 	r.removeNode(1)
 	w = []uint64{}
 	w = []uint64{}
-	if g := r.nodes(); !reflect.DeepEqual(g, w) {
+	if g := r.prs.voterNodes(); !reflect.DeepEqual(g, w) {
 		t.Errorf("nodes = %v, want %v", g, w)
 		t.Errorf("nodes = %v, want %v", g, w)
 	}
 	}
 }
 }
@@ -3366,18 +3366,18 @@ func TestRemoveLearner(t *testing.T) {
 	r := newTestLearnerRaft(1, []uint64{1}, []uint64{2}, 10, 1, NewMemoryStorage())
 	r := newTestLearnerRaft(1, []uint64{1}, []uint64{2}, 10, 1, NewMemoryStorage())
 	r.removeNode(2)
 	r.removeNode(2)
 	w := []uint64{1}
 	w := []uint64{1}
-	if g := r.nodes(); !reflect.DeepEqual(g, w) {
+	if g := r.prs.voterNodes(); !reflect.DeepEqual(g, w) {
 		t.Errorf("nodes = %v, want %v", g, w)
 		t.Errorf("nodes = %v, want %v", g, w)
 	}
 	}
 
 
 	w = []uint64{}
 	w = []uint64{}
-	if g := r.learnerNodes(); !reflect.DeepEqual(g, w) {
+	if g := r.prs.learnerNodes(); !reflect.DeepEqual(g, w) {
 		t.Errorf("nodes = %v, want %v", g, w)
 		t.Errorf("nodes = %v, want %v", g, w)
 	}
 	}
 
 
 	// remove all nodes from cluster
 	// remove all nodes from cluster
 	r.removeNode(1)
 	r.removeNode(1)
-	if g := r.nodes(); !reflect.DeepEqual(g, w) {
+	if g := r.prs.voterNodes(); !reflect.DeepEqual(g, w) {
 		t.Errorf("nodes = %v, want %v", g, w)
 		t.Errorf("nodes = %v, want %v", g, w)
 	}
 	}
 }
 }
@@ -3416,8 +3416,8 @@ func TestRaftNodes(t *testing.T) {
 	}
 	}
 	for i, tt := range tests {
 	for i, tt := range tests {
 		r := newTestRaft(1, tt.ids, 10, 1, NewMemoryStorage())
 		r := newTestRaft(1, tt.ids, 10, 1, NewMemoryStorage())
-		if !reflect.DeepEqual(r.nodes(), tt.wids) {
-			t.Errorf("#%d: nodes = %+v, want %+v", i, r.nodes(), tt.wids)
+		if !reflect.DeepEqual(r.prs.voterNodes(), tt.wids) {
+			t.Errorf("#%d: nodes = %+v, want %+v", i, r.prs.voterNodes(), tt.wids)
 		}
 		}
 	}
 	}
 }
 }
@@ -3637,7 +3637,7 @@ func TestLeaderTransferAfterSnapshot(t *testing.T) {
 	nt.send(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{}}})
 	nt.send(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{}}})
 	lead := nt.peers[1].(*raft)
 	lead := nt.peers[1].(*raft)
 	nextEnts(lead, nt.storage[1])
 	nextEnts(lead, nt.storage[1])
-	nt.storage[1].CreateSnapshot(lead.raftLog.applied, &pb.ConfState{Nodes: lead.nodes()}, nil)
+	nt.storage[1].CreateSnapshot(lead.raftLog.applied, &pb.ConfState{Nodes: lead.prs.voterNodes()}, nil)
 	nt.storage[1].Compact(lead.raftLog.applied)
 	nt.storage[1].Compact(lead.raftLog.applied)
 
 
 	nt.recover()
 	nt.recover()

+ 3 - 3
raft/rawnode.go

@@ -166,7 +166,7 @@ func (rn *RawNode) ProposeConfChange(cc pb.ConfChange) error {
 // ApplyConfChange applies a config change to the local node.
 // ApplyConfChange applies a config change to the local node.
 func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState {
 func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState {
 	if cc.NodeID == None {
 	if cc.NodeID == None {
-		return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()}
+		return &pb.ConfState{Nodes: rn.raft.prs.voterNodes(), Learners: rn.raft.prs.learnerNodes()}
 	}
 	}
 	switch cc.Type {
 	switch cc.Type {
 	case pb.ConfChangeAddNode:
 	case pb.ConfChangeAddNode:
@@ -179,7 +179,7 @@ func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState {
 	default:
 	default:
 		panic("unexpected conf type")
 		panic("unexpected conf type")
 	}
 	}
-	return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()}
+	return &pb.ConfState{Nodes: rn.raft.prs.voterNodes(), Learners: rn.raft.prs.learnerNodes()}
 }
 }
 
 
 // Step advances the state machine using the given message.
 // Step advances the state machine using the given message.
@@ -188,7 +188,7 @@ func (rn *RawNode) Step(m pb.Message) error {
 	if IsLocalMsg(m.Type) {
 	if IsLocalMsg(m.Type) {
 		return ErrStepLocalMsg
 		return ErrStepLocalMsg
 	}
 	}
-	if pr := rn.raft.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) {
+	if pr := rn.raft.prs.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) {
 		return rn.raft.Step(m)
 		return rn.raft.Step(m)
 	}
 	}
 	return ErrStepPeerNotFound
 	return ErrStepPeerNotFound

+ 11 - 8
raft/status.go

@@ -33,15 +33,18 @@ type Status struct {
 }
 }
 
 
 func getProgressCopy(r *raft) map[uint64]Progress {
 func getProgressCopy(r *raft) map[uint64]Progress {
-	prs := make(map[uint64]Progress)
-	for id, p := range r.prs.nodes {
-		prs[id] = *p
-	}
+	m := make(map[uint64]Progress)
+	r.prs.visit(func(id uint64, pr *Progress) {
+		var p Progress
+		p, pr = *pr, nil /* avoid accidental reuse below */
 
 
-	for id, p := range r.prs.learners {
-		prs[id] = *p
-	}
-	return prs
+		// The inflight buffer is tricky to copy and besides, it isn't exposed
+		// to the client, so pretend it's nil.
+		p.ins = nil
+
+		m[id] = p
+	})
+	return m
 }
 }
 
 
 func getStatusWithoutProgress(r *raft) Status {
 func getStatusWithoutProgress(r *raft) Status {