Sfoglia il codice sorgente

Merge pull request #1197 from coreos/raft_t

Add raft msg denied
Xiang Li 11 anni fa
parent
commit
1eb09acd8b
4 ha cambiato i file con 88 aggiunte e 63 eliminazioni
  1. 7 7
      raft/raft.go
  2. 44 47
      raft/raft_test.go
  3. 27 0
      raft/raftpb/raft.pb.go
  4. 10 9
      raft/raftpb/raft.proto

+ 7 - 7
raft/raft.go

@@ -378,7 +378,7 @@ func (r *raft) handleAppendEntries(m pb.Message) {
 	if r.raftLog.maybeAppend(m.Index, m.LogTerm, m.Commit, m.Entries...) {
 		r.send(pb.Message{To: m.From, Type: msgAppResp, Index: r.raftLog.lastIndex()})
 	} else {
-		r.send(pb.Message{To: m.From, Type: msgAppResp, Index: -1})
+		r.send(pb.Message{To: m.From, Type: msgAppResp, Denied: true})
 	}
 }
 
@@ -420,7 +420,7 @@ func stepLeader(r *raft, m pb.Message) {
 		r.appendEntry(e)
 		r.bcastAppend()
 	case msgAppResp:
-		if m.Index < 0 {
+		if m.Denied {
 			r.prs[m.From].decr()
 			r.sendAppend(m.From)
 		} else {
@@ -430,7 +430,7 @@ func stepLeader(r *raft, m pb.Message) {
 			}
 		}
 	case msgVote:
-		r.send(pb.Message{To: m.From, Type: msgVoteResp, Index: -1})
+		r.send(pb.Message{To: m.From, Type: msgVoteResp, Denied: true})
 	}
 }
 
@@ -445,9 +445,9 @@ func stepCandidate(r *raft, m pb.Message) {
 		r.becomeFollower(m.Term, m.From)
 		r.handleSnapshot(m)
 	case msgVote:
-		r.send(pb.Message{To: m.From, Type: msgVoteResp, Index: -1})
+		r.send(pb.Message{To: m.From, Type: msgVoteResp, Denied: true})
 	case msgVoteResp:
-		gr := r.poll(m.From, m.Index >= 0)
+		gr := r.poll(m.From, !m.Denied)
 		switch r.q() {
 		case gr:
 			r.becomeLeader()
@@ -477,9 +477,9 @@ func stepFollower(r *raft, m pb.Message) {
 		if (r.Vote == None || r.Vote == m.From) && r.raftLog.isUpToDate(m.Index, m.LogTerm) {
 			r.elapsed = 0
 			r.Vote = m.From
-			r.send(pb.Message{To: m.From, Type: msgVoteResp, Index: r.raftLog.lastIndex()})
+			r.send(pb.Message{To: m.From, Type: msgVoteResp})
 		} else {
-			r.send(pb.Message{To: m.From, Type: msgVoteResp, Index: -1})
+			r.send(pb.Message{To: m.From, Type: msgVoteResp, Denied: true})
 		}
 	}
 }

+ 44 - 47
raft/raft_test.go

@@ -484,22 +484,22 @@ func TestHandleMsgApp(t *testing.T) {
 		m       pb.Message
 		wIndex  int64
 		wCommit int64
-		wAccept bool
+		wDenied bool
 	}{
 		// Ensure 1
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 3, Index: 2, Commit: 3}, 2, 0, false}, // previous log mismatch
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 3, Index: 3, Commit: 3}, 2, 0, false}, // previous log non-exist
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 3, Index: 2, Commit: 3}, 2, 0, true}, // previous log mismatch
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 3, Index: 3, Commit: 3}, 2, 0, true}, // previous log non-exist
 
 		// Ensure 2
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 1, Index: 1, Commit: 1}, 2, 1, true},
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 0, Index: 0, Commit: 1, Entries: []pb.Entry{{Term: 2}}}, 1, 1, true},
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 3, Entries: []pb.Entry{{Term: 2}, {Term: 2}}}, 4, 3, true},
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 4, Entries: []pb.Entry{{Term: 2}}}, 3, 3, true},
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 1, Index: 1, Commit: 4, Entries: []pb.Entry{{Term: 2}}}, 2, 2, true},
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 1, Index: 1, Commit: 1}, 2, 1, false},
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 0, Index: 0, Commit: 1, Entries: []pb.Entry{{Term: 2}}}, 1, 1, false},
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 3, Entries: []pb.Entry{{Term: 2}, {Term: 2}}}, 4, 3, false},
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 4, Entries: []pb.Entry{{Term: 2}}}, 3, 3, false},
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 1, Index: 1, Commit: 4, Entries: []pb.Entry{{Term: 2}}}, 2, 2, false},
 
 		// Ensure 3
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 2}, 2, 2, true},
-		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 4}, 2, 2, true}, // commit upto min(commit, last)
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 2}, 2, 2, false},
+		{pb.Message{Type: msgApp, Term: 2, LogTerm: 2, Index: 2, Commit: 4}, 2, 2, false}, // commit upto min(commit, last)
 	}
 
 	for i, tt := range tests {
@@ -518,14 +518,10 @@ func TestHandleMsgApp(t *testing.T) {
 		}
 		m := sm.ReadMessages()
 		if len(m) != 1 {
-			t.Errorf("#%d: msg = nil, want 1", i)
+			t.Fatalf("#%d: msg = nil, want 1", i)
 		}
-		gaccept := true
-		if m[0].Index == -1 {
-			gaccept = false
-		}
-		if gaccept != tt.wAccept {
-			t.Errorf("#%d: accept = %v, want %v", i, gaccept, tt.wAccept)
+		if m[0].Denied != tt.wDenied {
+			t.Errorf("#%d: denied = %v, want %v", i, m[0].Denied, tt.wDenied)
 		}
 	}
 }
@@ -535,33 +531,33 @@ func TestRecvMsgVote(t *testing.T) {
 		state   StateType
 		i, term int64
 		voteFor int64
-		w       int64
+		wdenied bool
 	}{
-		{StateFollower, 0, 0, None, -1},
-		{StateFollower, 0, 1, None, -1},
-		{StateFollower, 0, 2, None, -1},
-		{StateFollower, 0, 3, None, 2},
+		{StateFollower, 0, 0, None, true},
+		{StateFollower, 0, 1, None, true},
+		{StateFollower, 0, 2, None, true},
+		{StateFollower, 0, 3, None, false},
 
-		{StateFollower, 1, 0, None, -1},
-		{StateFollower, 1, 1, None, -1},
-		{StateFollower, 1, 2, None, -1},
-		{StateFollower, 1, 3, None, 2},
+		{StateFollower, 1, 0, None, true},
+		{StateFollower, 1, 1, None, true},
+		{StateFollower, 1, 2, None, true},
+		{StateFollower, 1, 3, None, false},
 
-		{StateFollower, 2, 0, None, -1},
-		{StateFollower, 2, 1, None, -1},
-		{StateFollower, 2, 2, None, 2},
-		{StateFollower, 2, 3, None, 2},
+		{StateFollower, 2, 0, None, true},
+		{StateFollower, 2, 1, None, true},
+		{StateFollower, 2, 2, None, false},
+		{StateFollower, 2, 3, None, false},
 
-		{StateFollower, 3, 0, None, -1},
-		{StateFollower, 3, 1, None, -1},
-		{StateFollower, 3, 2, None, 2},
-		{StateFollower, 3, 3, None, 2},
+		{StateFollower, 3, 0, None, true},
+		{StateFollower, 3, 1, None, true},
+		{StateFollower, 3, 2, None, false},
+		{StateFollower, 3, 3, None, false},
 
-		{StateFollower, 3, 2, 2, 2},
-		{StateFollower, 3, 2, 1, -1},
+		{StateFollower, 3, 2, 2, false},
+		{StateFollower, 3, 2, 1, true},
 
-		{StateLeader, 3, 3, 1, -1},
-		{StateCandidate, 3, 3, 1, -1},
+		{StateLeader, 3, 3, 1, true},
+		{StateCandidate, 3, 3, 1, true},
 	}
 
 	for i, tt := range tests {
@@ -582,11 +578,11 @@ func TestRecvMsgVote(t *testing.T) {
 
 		msgs := sm.ReadMessages()
 		if g := len(msgs); g != 1 {
-			t.Errorf("#%d: len(msgs) = %d, want 1", i, g)
+			t.Fatalf("#%d: len(msgs) = %d, want 1", i, g)
 			continue
 		}
-		if g := msgs[0].Index; g != tt.w {
-			t.Errorf("#%d, m.Index = %d, want %d", i, g, tt.w)
+		if g := msgs[0].Denied; g != tt.wdenied {
+			t.Errorf("#%d, m.Denied = %d, want %d", i, g, tt.wdenied)
 		}
 	}
 }
@@ -698,12 +694,13 @@ func TestAllServerStepdown(t *testing.T) {
 func TestLeaderAppResp(t *testing.T) {
 	tests := []struct {
 		index      int64
+		denied     bool
 		wmsgNum    int
 		windex     int64
 		wcommitted int64
 	}{
-		{-1, 1, 1, 0}, // bad resp; leader does not commit; reply with log entries
-		{2, 2, 2, 2},  // good resp; leader commits; broadcast with commit index
+		{-1, true, 1, 1, 0}, // bad resp; leader does not commit; reply with log entries
+		{2, false, 2, 2, 2}, // good resp; leader commits; broadcast with commit index
 	}
 
 	for i, tt := range tests {
@@ -714,7 +711,7 @@ func TestLeaderAppResp(t *testing.T) {
 		sm.becomeCandidate()
 		sm.becomeLeader()
 		sm.ReadMessages()
-		sm.Step(pb.Message{From: 2, Type: msgAppResp, Index: tt.index, Term: sm.Term})
+		sm.Step(pb.Message{From: 2, Type: msgAppResp, Index: tt.index, Term: sm.Term, Denied: tt.denied})
 		msgs := sm.ReadMessages()
 
 		if len(msgs) != tt.wmsgNum {
@@ -883,7 +880,7 @@ func TestProvideSnap(t *testing.T) {
 	sm.Step(pb.Message{From: 1, To: 1, Type: msgBeat})
 	msgs := sm.ReadMessages()
 	if len(msgs) != 1 {
-		t.Errorf("len(msgs) = %d, want 1", len(msgs))
+		t.Fatalf("len(msgs) = %d, want 1", len(msgs))
 	}
 	m := msgs[0]
 	if m.Type != msgApp {
@@ -894,10 +891,10 @@ func TestProvideSnap(t *testing.T) {
 	// node 1 needs a snapshot
 	sm.prs[2].next = sm.raftLog.offset
 
-	sm.Step(pb.Message{From: 2, To: 1, Type: msgAppResp, Index: -1})
+	sm.Step(pb.Message{From: 2, To: 1, Type: msgAppResp, Index: -1, Denied: true})
 	msgs = sm.ReadMessages()
 	if len(msgs) != 1 {
-		t.Errorf("len(msgs) = %d, want 1", len(msgs))
+		t.Fatalf("len(msgs) = %d, want 1", len(msgs))
 	}
 	m = msgs[0]
 	if m.Type != msgSnap {

+ 27 - 0
raft/raftpb/raft.pb.go

@@ -141,6 +141,7 @@ type Message struct {
 	Entries          []Entry  `protobuf:"bytes,7,rep,name=entries" json:"entries"`
 	Commit           int64    `protobuf:"varint,8,req,name=commit" json:"commit"`
 	Snapshot         Snapshot `protobuf:"bytes,9,req,name=snapshot" json:"snapshot"`
+	Denied           bool     `protobuf:"varint,10,req,name=denied" json:"denied"`
 	XXX_unrecognized []byte   `json:"-"`
 }
 
@@ -623,6 +624,23 @@ func (m *Message) Unmarshal(data []byte) error {
 				return err
 			}
 			index = postIndex
+		case 10:
+			if wireType != 0 {
+				return code_google_com_p_gogoprotobuf_proto.ErrWrongType
+			}
+			var v int
+			for shift := uint(0); ; shift += 7 {
+				if index >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[index]
+				index++
+				v |= (int(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+			m.Denied = bool(v != 0)
 		default:
 			var sizeOfWire int
 			for {
@@ -899,6 +917,7 @@ func (m *Message) Size() (n int) {
 	n += 1 + sovRaft(uint64(m.Commit))
 	l = m.Snapshot.Size()
 	n += 1 + l + sovRaft(uint64(l))
+	n += 2
 	if m.XXX_unrecognized != nil {
 		n += len(m.XXX_unrecognized)
 	}
@@ -1097,6 +1116,14 @@ func (m *Message) MarshalTo(data []byte) (n int, err error) {
 		return 0, err
 	}
 	i += n1
+	data[i] = 0x50
+	i++
+	if m.Denied {
+		data[i] = 1
+	} else {
+		data[i] = 0
+	}
+	i++
 	if m.XXX_unrecognized != nil {
 		i += copy(data[i:], m.XXX_unrecognized)
 	}

+ 10 - 9
raft/raftpb/raft.proto

@@ -32,15 +32,16 @@ message Snapshot {
 }
 
 message Message {
-	required int64 type        = 1 [(gogoproto.nullable) = false];
-	required int64 to          = 2 [(gogoproto.nullable) = false];
-	required int64 from        = 3 [(gogoproto.nullable) = false];
-	required int64 term        = 4 [(gogoproto.nullable) = false];
-	required int64 logTerm     = 5 [(gogoproto.nullable) = false];
-	required int64 index       = 6 [(gogoproto.nullable) = false];
-	repeated Entry entries     = 7 [(gogoproto.nullable) = false];
-	required int64 commit      = 8 [(gogoproto.nullable) = false];
-	required Snapshot snapshot = 9 [(gogoproto.nullable) = false];
+	required int64 type        = 1  [(gogoproto.nullable) = false];
+	required int64 to          = 2  [(gogoproto.nullable) = false];
+	required int64 from        = 3  [(gogoproto.nullable) = false];
+	required int64 term        = 4  [(gogoproto.nullable) = false];
+	required int64 logTerm     = 5  [(gogoproto.nullable) = false];
+	required int64 index       = 6  [(gogoproto.nullable) = false];
+	repeated Entry entries     = 7  [(gogoproto.nullable) = false];
+	required int64 commit      = 8  [(gogoproto.nullable) = false];
+	required Snapshot snapshot = 9  [(gogoproto.nullable) = false];
+	required bool  denied      = 10 [(gogoproto.nullable) = false];
 }
 
 message HardState {