Browse Source

raft: add lastIndex as rejectHint

Add the lastindex of the raft log as reject hint, so the leader can
bypass the greater index probing and decrease the next index directly
to last + 1.
Xiang Li 11 years ago
parent
commit
35b907ac58
5 changed files with 83 additions and 51 deletions
  1. 16 15
      raft/raft.go
  2. 10 9
      raft/raft_paper_test.go
  3. 26 17
      raft/raft_test.go
  4. 20 0
      raft/raftpb/raft.pb.go
  5. 11 10
      raft/raftpb/raft.proto

+ 16 - 15
raft/raft.go

@@ -70,13 +70,13 @@ func (pr *progress) update(n uint64) {
 func (pr *progress) optimisticUpdate(n uint64) { pr.next = n + 1 }
 
 // maybeDecrTo returns false if the given to index comes from an out of order message.
-// Otherwise it decreases the progress next index and returns true.
-func (pr *progress) maybeDecrTo(to uint64) bool {
+// Otherwise it decreases the progress next index to min(rejected, last) and returns true.
+func (pr *progress) maybeDecrTo(rejected, last uint64) bool {
 	pr.waitReset()
 	if pr.match != 0 {
-		// the rejection must be stale if the progress has matched and "to"
+		// the rejection must be stale if the progress has matched and "rejected"
 		// is smaller than "match".
-		if to <= pr.match {
+		if rejected <= pr.match {
 			return false
 		}
 		// directly decrease next to match + 1
@@ -84,12 +84,12 @@ func (pr *progress) maybeDecrTo(to uint64) bool {
 		return true
 	}
 
-	// the rejection must be stale if "to" does not match next - 1
-	if pr.next-1 != to {
+	// the rejection must be stale if "rejected" does not match next - 1
+	if pr.next-1 != rejected {
 		return false
 	}
 
-	if pr.next--; pr.next < 1 {
+	if pr.next = min(rejected, last+1); pr.next < 1 {
 		pr.next = 1
 	}
 	return true
@@ -245,8 +245,8 @@ func (r *raft) sendAppend(to uint64) {
 		if n := len(m.Entries); pr.match != 0 && n != 0 {
 			pr.optimisticUpdate(m.Entries[n-1].Index)
 		} else if pr.match == 0 {
-			// TODO (xiangli): better way to find out if the follwer is in good path or not
-			// a follower might be in bad path even if match != 0, since we optmistically
+			// TODO (xiangli): better way to find out if the follower is in good path or not
+			// a follower might be in bad path even if match != 0, since we optimistically
 			// increase the next.
 			pr.waitSet(r.heartbeatTimeout)
 		}
@@ -482,9 +482,10 @@ func stepLeader(r *raft, m pb.Message) {
 		r.bcastAppend()
 	case pb.MsgAppResp:
 		if m.Reject {
-			log.Printf("raft: %x received msgApp rejection from %x for index %d",
-				r.id, m.From, m.Index)
-			if r.prs[m.From].maybeDecrTo(m.Index) {
+			log.Printf("raft: %x received msgApp rejection(lastindex: %d) from %x for index %d",
+				r.id, m.RejectHint, m.From, m.Index)
+			if r.prs[m.From].maybeDecrTo(m.Index, m.RejectHint) {
+				log.Printf("raft: %x decreased progress of %x to [%s]", r.id, m.From, r.prs[m.From])
 				r.sendAppend(m.From)
 			}
 		} else {
@@ -572,7 +573,7 @@ func (r *raft) handleAppendEntries(m pb.Message) {
 	} else {
 		log.Printf("raft: %x [logterm: %d, index: %d] rejected msgApp [logterm: %d, index: %d] from %x",
 			r.id, r.raftLog.term(m.Index), m.Index, m.LogTerm, m.Index, m.From)
-		r.send(pb.Message{To: m.From, Type: pb.MsgAppResp, Index: m.Index, Reject: true})
+		r.send(pb.Message{To: m.From, Type: pb.MsgAppResp, Index: m.Index, Reject: true, RejectHint: r.raftLog.lastIndex()})
 	}
 }
 
@@ -593,8 +594,8 @@ func (r *raft) handleSnapshot(m pb.Message) {
 	}
 }
 
-// restore recovers the statemachine from a snapshot. It restores the log and the
-// configuration of statemachine.
+// restore recovers the state machine from a snapshot. It restores the log and the
+// configuration of state machine.
 func (r *raft) restore(s pb.Snapshot) bool {
 	if s.Metadata.Index <= r.raftLog.committed {
 		return false

+ 10 - 9
raft/raft_paper_test.go

@@ -605,15 +605,16 @@ func TestFollowerCommitEntry(t *testing.T) {
 func TestFollowerCheckMsgApp(t *testing.T) {
 	ents := []pb.Entry{{Term: 1, Index: 1}, {Term: 2, Index: 2}}
 	tests := []struct {
-		term    uint64
-		index   uint64
-		wreject bool
+		term        uint64
+		index       uint64
+		wreject     bool
+		wrejectHint uint64
 	}{
-		{ents[0].Term, ents[0].Index, false},
-		{ents[0].Term, ents[0].Index + 1, true},
-		{ents[0].Term + 1, ents[0].Index, true},
-		{ents[1].Term, ents[1].Index, false},
-		{3, 3, true},
+		{ents[0].Term, ents[0].Index, false, 0},
+		{ents[0].Term, ents[0].Index + 1, true, 2},
+		{ents[0].Term + 1, ents[0].Index, true, 2},
+		{ents[1].Term, ents[1].Index, false, 0},
+		{3, 3, true, 2},
 	}
 	for i, tt := range tests {
 		storage := NewMemoryStorage()
@@ -626,7 +627,7 @@ func TestFollowerCheckMsgApp(t *testing.T) {
 
 		msgs := r.readMessages()
 		wmsgs := []pb.Message{
-			{From: 1, To: 2, Type: pb.MsgAppResp, Term: 2, Index: tt.index, Reject: tt.wreject},
+			{From: 1, To: 2, Type: pb.MsgAppResp, Term: 2, Index: tt.index, Reject: tt.wreject, RejectHint: tt.wrejectHint},
 		}
 		if !reflect.DeepEqual(msgs, wmsgs) {
 			t.Errorf("#%d: msgs = %+v, want %+v", i, msgs, wmsgs)

+ 26 - 17
raft/raft_test.go

@@ -80,50 +80,59 @@ func TestProgressUpdate(t *testing.T) {
 
 func TestProgressMaybeDecr(t *testing.T) {
 	tests := []struct {
-		m  uint64
-		n  uint64
-		to uint64
+		m        uint64
+		n        uint64
+		rejected uint64
+		last     uint64
 
 		w  bool
 		wn uint64
 	}{
 		{
 			// match != 0 is always false
-			1, 0, 0, false, 0,
+			1, 0, 0, 0, false, 0,
 		},
 		{
 			// match != 0 and to is greater than match
 			// directly decrease to match+1
-			5, 10, 5, false, 10,
+			5, 10, 5, 5, false, 10,
 		},
 		{
 			// match != 0 and to is greater than match
 			// directly decrease to match+1
-			5, 10, 4, false, 10,
+			5, 10, 4, 4, false, 10,
 		},
 		{
 			// match != 0 and to is not greater than match
-			5, 10, 9, true, 6,
+			5, 10, 9, 9, true, 6,
 		},
 		{
-			// next-1 != to is always false
-			0, 0, 0, false, 0,
+			// next-1 != rejected is always false
+			0, 0, 0, 0, false, 0,
 		},
 		{
-			// next-1 != to is always false
-			0, 10, 5, false, 10,
+			// next-1 != rejected is always false
+			0, 10, 5, 5, false, 10,
 		},
 		{
 			// next>1 = decremented by 1
-			0, 10, 9, true, 9,
+			0, 10, 9, 9, true, 9,
 		},
 		{
 			// next>1 = decremented by 1
-			0, 2, 1, true, 1,
+			0, 2, 1, 1, true, 1,
 		},
 		{
 			// next<=1 = reset to 1
-			0, 1, 0, true, 1,
+			0, 1, 0, 0, true, 1,
+		},
+		{
+			// decrease to min(rejected, last+1)
+			0, 10, 9, 2, true, 3,
+		},
+		{
+			// rejected < 1, reset to 1
+			0, 10, 9, 0, true, 1,
 		},
 	}
 	for i, tt := range tests {
@@ -131,7 +140,7 @@ func TestProgressMaybeDecr(t *testing.T) {
 			match: tt.m,
 			next:  tt.n,
 		}
-		if g := p.maybeDecrTo(tt.to); g != tt.w {
+		if g := p.maybeDecrTo(tt.rejected, tt.last); g != tt.w {
 			t.Errorf("#%d: maybeDecrTo= %t, want %t", i, g, tt.w)
 		}
 		if gm := p.match; gm != tt.m {
@@ -173,7 +182,7 @@ func TestProgressWaitReset(t *testing.T) {
 	p := &progress{
 		wait: 1,
 	}
-	p.maybeDecrTo(1)
+	p.maybeDecrTo(1, 1)
 	if p.wait != 0 {
 		t.Errorf("wait= %d, want 0", p.wait)
 	}
@@ -1001,7 +1010,7 @@ func TestLeaderAppResp(t *testing.T) {
 		sm.becomeCandidate()
 		sm.becomeLeader()
 		sm.readMessages()
-		sm.Step(pb.Message{From: 2, Type: pb.MsgAppResp, Index: tt.index, Term: sm.Term, Reject: tt.reject})
+		sm.Step(pb.Message{From: 2, Type: pb.MsgAppResp, Index: tt.index, Term: sm.Term, Reject: tt.reject, RejectHint: tt.index})
 
 		p := sm.prs[2]
 		if p.match != tt.wmatch {

+ 20 - 0
raft/raftpb/raft.pb.go

@@ -200,6 +200,7 @@ type Message struct {
 	Commit           uint64      `protobuf:"varint,8,req,name=commit" json:"commit"`
 	Snapshot         Snapshot    `protobuf:"bytes,9,req,name=snapshot" json:"snapshot"`
 	Reject           bool        `protobuf:"varint,10,req,name=reject" json:"reject"`
+	RejectHint       uint64      `protobuf:"varint,11,req,name=rejectHint" json:"rejectHint"`
 	XXX_unrecognized []byte      `json:"-"`
 }
 
@@ -725,6 +726,21 @@ func (m *Message) Unmarshal(data []byte) error {
 				}
 			}
 			m.Reject = bool(v != 0)
+		case 11:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				m.RejectHint |= (uint64(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
 		default:
 			var sizeOfWire int
 			for {
@@ -1059,6 +1075,7 @@ func (m *Message) Size() (n int) {
 	l = m.Snapshot.Size()
 	n += 1 + l + sovRaft(uint64(l))
 	n += 2
+	n += 1 + sovRaft(uint64(m.RejectHint))
 	if m.XXX_unrecognized != nil {
 		n += len(m.XXX_unrecognized)
 	}
@@ -1278,6 +1295,9 @@ func (m *Message) MarshalTo(data []byte) (n int, err error) {
 		data[i] = 0
 	}
 	i++
+	data[i] = 0x58
+	i++
+	i = encodeVarintRaft(data, i, uint64(m.RejectHint))
 	if m.XXX_unrecognized != nil {
 		i += copy(data[i:], m.XXX_unrecognized)
 	}

+ 11 - 10
raft/raftpb/raft.proto

@@ -44,16 +44,17 @@ enum MessageType {
 }
 
 message Message {
-	required MessageType type     = 1  [(gogoproto.nullable) = false];
-	required uint64      to       = 2  [(gogoproto.nullable) = false];
-	required uint64      from     = 3  [(gogoproto.nullable) = false];
-	required uint64      term     = 4  [(gogoproto.nullable) = false];
-	required uint64      logTerm  = 5  [(gogoproto.nullable) = false];
-	required uint64      index    = 6  [(gogoproto.nullable) = false];
-	repeated Entry       entries  = 7  [(gogoproto.nullable) = false];
-	required uint64      commit   = 8  [(gogoproto.nullable) = false];
-	required Snapshot    snapshot = 9  [(gogoproto.nullable) = false];
-	required bool        reject   = 10 [(gogoproto.nullable) = false];
+	required MessageType type        = 1  [(gogoproto.nullable) = false];
+	required uint64      to          = 2  [(gogoproto.nullable) = false];
+	required uint64      from        = 3  [(gogoproto.nullable) = false];
+	required uint64      term        = 4  [(gogoproto.nullable) = false];
+	required uint64      logTerm     = 5  [(gogoproto.nullable) = false];
+	required uint64      index       = 6  [(gogoproto.nullable) = false];
+	repeated Entry       entries     = 7  [(gogoproto.nullable) = false];
+	required uint64      commit      = 8  [(gogoproto.nullable) = false];
+	required Snapshot    snapshot    = 9  [(gogoproto.nullable) = false];
+	required bool        reject      = 10 [(gogoproto.nullable) = false];
+	required uint64      rejectHint  = 11 [(gogoproto.nullable) = false];
 }
 
 message HardState {