Browse Source

Merge pull request #2021 from xiang90/raft_reject_hint

raft: add lastIndex as rejectHint
Xiang Li 11 years ago
parent
commit
921ce4c25b
5 changed files with 83 additions and 51 deletions
  1. 16 15
      raft/raft.go
  2. 10 9
      raft/raft_paper_test.go
  3. 26 17
      raft/raft_test.go
  4. 20 0
      raft/raftpb/raft.pb.go
  5. 11 10
      raft/raftpb/raft.proto

+ 16 - 15
raft/raft.go

@@ -70,13 +70,13 @@ func (pr *progress) update(n uint64) {
 func (pr *progress) optimisticUpdate(n uint64) { pr.next = n + 1 }
 
 // maybeDecrTo returns false if the given to index comes from an out of order message.
-// Otherwise it decreases the progress next index and returns true.
-func (pr *progress) maybeDecrTo(to uint64) bool {
+// Otherwise it decreases the progress next index to min(rejected, last) and returns true.
+func (pr *progress) maybeDecrTo(rejected, last uint64) bool {
 	pr.waitReset()
 	if pr.match != 0 {
-		// the rejection must be stale if the progress has matched and "to"
+		// the rejection must be stale if the progress has matched and "rejected"
 		// is smaller than "match".
-		if to <= pr.match {
+		if rejected <= pr.match {
 			return false
 		}
 		// directly decrease next to match + 1
@@ -84,12 +84,12 @@ func (pr *progress) maybeDecrTo(to uint64) bool {
 		return true
 	}
 
-	// the rejection must be stale if "to" does not match next - 1
-	if pr.next-1 != to {
+	// the rejection must be stale if "rejected" does not match next - 1
+	if pr.next-1 != rejected {
 		return false
 	}
 
-	if pr.next--; pr.next < 1 {
+	if pr.next = min(rejected, last+1); pr.next < 1 {
 		pr.next = 1
 	}
 	return true
@@ -245,8 +245,8 @@ func (r *raft) sendAppend(to uint64) {
 		if n := len(m.Entries); pr.match != 0 && n != 0 {
 			pr.optimisticUpdate(m.Entries[n-1].Index)
 		} else if pr.match == 0 {
-			// TODO (xiangli): better way to find out if the follwer is in good path or not
-			// a follower might be in bad path even if match != 0, since we optmistically
+			// TODO (xiangli): better way to find out if the follower is in good path or not
+			// a follower might be in bad path even if match != 0, since we optimistically
 			// increase the next.
 			pr.waitSet(r.heartbeatTimeout)
 		}
@@ -482,9 +482,10 @@ func stepLeader(r *raft, m pb.Message) {
 		r.bcastAppend()
 	case pb.MsgAppResp:
 		if m.Reject {
-			log.Printf("raft: %x received msgApp rejection from %x for index %d",
-				r.id, m.From, m.Index)
-			if r.prs[m.From].maybeDecrTo(m.Index) {
+			log.Printf("raft: %x received msgApp rejection(lastindex: %d) from %x for index %d",
+				r.id, m.RejectHint, m.From, m.Index)
+			if r.prs[m.From].maybeDecrTo(m.Index, m.RejectHint) {
+				log.Printf("raft: %x decreased progress of %x to [%s]", r.id, m.From, r.prs[m.From])
 				r.sendAppend(m.From)
 			}
 		} else {
@@ -572,7 +573,7 @@ func (r *raft) handleAppendEntries(m pb.Message) {
 	} else {
 		log.Printf("raft: %x [logterm: %d, index: %d] rejected msgApp [logterm: %d, index: %d] from %x",
 			r.id, r.raftLog.term(m.Index), m.Index, m.LogTerm, m.Index, m.From)
-		r.send(pb.Message{To: m.From, Type: pb.MsgAppResp, Index: m.Index, Reject: true})
+		r.send(pb.Message{To: m.From, Type: pb.MsgAppResp, Index: m.Index, Reject: true, RejectHint: r.raftLog.lastIndex()})
 	}
 }
 
@@ -593,8 +594,8 @@ func (r *raft) handleSnapshot(m pb.Message) {
 	}
 }
 
-// restore recovers the statemachine from a snapshot. It restores the log and the
-// configuration of statemachine.
+// restore recovers the state machine from a snapshot. It restores the log and the
+// configuration of state machine.
 func (r *raft) restore(s pb.Snapshot) bool {
 	if s.Metadata.Index <= r.raftLog.committed {
 		return false

+ 10 - 9
raft/raft_paper_test.go

@@ -605,15 +605,16 @@ func TestFollowerCommitEntry(t *testing.T) {
 func TestFollowerCheckMsgApp(t *testing.T) {
 	ents := []pb.Entry{{Term: 1, Index: 1}, {Term: 2, Index: 2}}
 	tests := []struct {
-		term    uint64
-		index   uint64
-		wreject bool
+		term        uint64
+		index       uint64
+		wreject     bool
+		wrejectHint uint64
 	}{
-		{ents[0].Term, ents[0].Index, false},
-		{ents[0].Term, ents[0].Index + 1, true},
-		{ents[0].Term + 1, ents[0].Index, true},
-		{ents[1].Term, ents[1].Index, false},
-		{3, 3, true},
+		{ents[0].Term, ents[0].Index, false, 0},
+		{ents[0].Term, ents[0].Index + 1, true, 2},
+		{ents[0].Term + 1, ents[0].Index, true, 2},
+		{ents[1].Term, ents[1].Index, false, 0},
+		{3, 3, true, 2},
 	}
 	for i, tt := range tests {
 		storage := NewMemoryStorage()
@@ -626,7 +627,7 @@ func TestFollowerCheckMsgApp(t *testing.T) {
 
 		msgs := r.readMessages()
 		wmsgs := []pb.Message{
-			{From: 1, To: 2, Type: pb.MsgAppResp, Term: 2, Index: tt.index, Reject: tt.wreject},
+			{From: 1, To: 2, Type: pb.MsgAppResp, Term: 2, Index: tt.index, Reject: tt.wreject, RejectHint: tt.wrejectHint},
 		}
 		if !reflect.DeepEqual(msgs, wmsgs) {
 			t.Errorf("#%d: msgs = %+v, want %+v", i, msgs, wmsgs)

+ 26 - 17
raft/raft_test.go

@@ -80,50 +80,59 @@ func TestProgressUpdate(t *testing.T) {
 
 func TestProgressMaybeDecr(t *testing.T) {
 	tests := []struct {
-		m  uint64
-		n  uint64
-		to uint64
+		m        uint64
+		n        uint64
+		rejected uint64
+		last     uint64
 
 		w  bool
 		wn uint64
 	}{
 		{
 			// match != 0 is always false
-			1, 0, 0, false, 0,
+			1, 0, 0, 0, false, 0,
 		},
 		{
 			// match != 0 and to is greater than match
 			// directly decrease to match+1
-			5, 10, 5, false, 10,
+			5, 10, 5, 5, false, 10,
 		},
 		{
 			// match != 0 and to is greater than match
 			// directly decrease to match+1
-			5, 10, 4, false, 10,
+			5, 10, 4, 4, false, 10,
 		},
 		{
 			// match != 0 and to is not greater than match
-			5, 10, 9, true, 6,
+			5, 10, 9, 9, true, 6,
 		},
 		{
-			// next-1 != to is always false
-			0, 0, 0, false, 0,
+			// next-1 != rejected is always false
+			0, 0, 0, 0, false, 0,
 		},
 		{
-			// next-1 != to is always false
-			0, 10, 5, false, 10,
+			// next-1 != rejected is always false
+			0, 10, 5, 5, false, 10,
 		},
 		{
 			// next>1 = decremented by 1
-			0, 10, 9, true, 9,
+			0, 10, 9, 9, true, 9,
 		},
 		{
 			// next>1 = decremented by 1
-			0, 2, 1, true, 1,
+			0, 2, 1, 1, true, 1,
 		},
 		{
 			// next<=1 = reset to 1
-			0, 1, 0, true, 1,
+			0, 1, 0, 0, true, 1,
+		},
+		{
+			// decrease to min(rejected, last+1)
+			0, 10, 9, 2, true, 3,
+		},
+		{
+			// rejected < 1, reset to 1
+			0, 10, 9, 0, true, 1,
 		},
 	}
 	for i, tt := range tests {
@@ -131,7 +140,7 @@ func TestProgressMaybeDecr(t *testing.T) {
 			match: tt.m,
 			next:  tt.n,
 		}
-		if g := p.maybeDecrTo(tt.to); g != tt.w {
+		if g := p.maybeDecrTo(tt.rejected, tt.last); g != tt.w {
 			t.Errorf("#%d: maybeDecrTo= %t, want %t", i, g, tt.w)
 		}
 		if gm := p.match; gm != tt.m {
@@ -173,7 +182,7 @@ func TestProgressWaitReset(t *testing.T) {
 	p := &progress{
 		wait: 1,
 	}
-	p.maybeDecrTo(1)
+	p.maybeDecrTo(1, 1)
 	if p.wait != 0 {
 		t.Errorf("wait= %d, want 0", p.wait)
 	}
@@ -1001,7 +1010,7 @@ func TestLeaderAppResp(t *testing.T) {
 		sm.becomeCandidate()
 		sm.becomeLeader()
 		sm.readMessages()
-		sm.Step(pb.Message{From: 2, Type: pb.MsgAppResp, Index: tt.index, Term: sm.Term, Reject: tt.reject})
+		sm.Step(pb.Message{From: 2, Type: pb.MsgAppResp, Index: tt.index, Term: sm.Term, Reject: tt.reject, RejectHint: tt.index})
 
 		p := sm.prs[2]
 		if p.match != tt.wmatch {

+ 20 - 0
raft/raftpb/raft.pb.go

@@ -200,6 +200,7 @@ type Message struct {
 	Commit           uint64      `protobuf:"varint,8,req,name=commit" json:"commit"`
 	Snapshot         Snapshot    `protobuf:"bytes,9,req,name=snapshot" json:"snapshot"`
 	Reject           bool        `protobuf:"varint,10,req,name=reject" json:"reject"`
+	RejectHint       uint64      `protobuf:"varint,11,req,name=rejectHint" json:"rejectHint"`
 	XXX_unrecognized []byte      `json:"-"`
 }
 
@@ -725,6 +726,21 @@ func (m *Message) Unmarshal(data []byte) error {
 				}
 			}
 			m.Reject = bool(v != 0)
+		case 11:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				m.RejectHint |= (uint64(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
 		default:
 			var sizeOfWire int
 			for {
@@ -1059,6 +1075,7 @@ func (m *Message) Size() (n int) {
 	l = m.Snapshot.Size()
 	n += 1 + l + sovRaft(uint64(l))
 	n += 2
+	n += 1 + sovRaft(uint64(m.RejectHint))
 	if m.XXX_unrecognized != nil {
 		n += len(m.XXX_unrecognized)
 	}
@@ -1278,6 +1295,9 @@ func (m *Message) MarshalTo(data []byte) (n int, err error) {
 		data[i] = 0
 	}
 	i++
+	data[i] = 0x58
+	i++
+	i = encodeVarintRaft(data, i, uint64(m.RejectHint))
 	if m.XXX_unrecognized != nil {
 		i += copy(data[i:], m.XXX_unrecognized)
 	}

+ 11 - 10
raft/raftpb/raft.proto

@@ -44,16 +44,17 @@ enum MessageType {
 }
 
 message Message {
-	required MessageType type     = 1  [(gogoproto.nullable) = false];
-	required uint64      to       = 2  [(gogoproto.nullable) = false];
-	required uint64      from     = 3  [(gogoproto.nullable) = false];
-	required uint64      term     = 4  [(gogoproto.nullable) = false];
-	required uint64      logTerm  = 5  [(gogoproto.nullable) = false];
-	required uint64      index    = 6  [(gogoproto.nullable) = false];
-	repeated Entry       entries  = 7  [(gogoproto.nullable) = false];
-	required uint64      commit   = 8  [(gogoproto.nullable) = false];
-	required Snapshot    snapshot = 9  [(gogoproto.nullable) = false];
-	required bool        reject   = 10 [(gogoproto.nullable) = false];
+	required MessageType type        = 1  [(gogoproto.nullable) = false];
+	required uint64      to          = 2  [(gogoproto.nullable) = false];
+	required uint64      from        = 3  [(gogoproto.nullable) = false];
+	required uint64      term        = 4  [(gogoproto.nullable) = false];
+	required uint64      logTerm     = 5  [(gogoproto.nullable) = false];
+	required uint64      index       = 6  [(gogoproto.nullable) = false];
+	repeated Entry       entries     = 7  [(gogoproto.nullable) = false];
+	required uint64      commit      = 8  [(gogoproto.nullable) = false];
+	required Snapshot    snapshot    = 9  [(gogoproto.nullable) = false];
+	required bool        reject      = 10 [(gogoproto.nullable) = false];
+	required uint64      rejectHint  = 11 [(gogoproto.nullable) = false];
 }
 
 message HardState {