Просмотр исходного кода

raft: do not resend snapshot if not necessary

raft relies on the link layer to report the status of the sent snapshot.
While the snapshot is still being sent, replication to that remote peer is
paused. If the snapshot finishes sending, replication will begin
optimistically after electionTimeout. If the snapshot fails, raft will
try to resend it.
Xiang Li 11 лет назад
Родитель
Сommit
9b4d52ee73
8 измененных файлов с 225 добавлено и 15 удалено
  1. 2 0
      etcdserver/server_test.go
  2. 20 0
      raft/node.go
  3. 1 1
      raft/node_test.go
  4. 56 2
      raft/raft.go
  5. 128 0
      raft/raft_snap_test.go
  6. 3 0
      raft/raftpb/raft.pb.go
  7. 12 11
      raft/raftpb/raft.proto
  8. 3 1
      raft/util.go

+ 2 - 0
etcdserver/server_test.go

@@ -1305,6 +1305,8 @@ func (n *nodeRecorder) Stop() {
 
 
 func (n *nodeRecorder) ReportUnreachable(id uint64) {}
 func (n *nodeRecorder) ReportUnreachable(id uint64) {}
 
 
+func (n *nodeRecorder) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
+
 func (n *nodeRecorder) Compact(index uint64, nodes []uint64, d []byte) {
 func (n *nodeRecorder) Compact(index uint64, nodes []uint64, d []byte) {
 	n.Record(testutil.Action{Name: "Compact"})
 	n.Record(testutil.Action{Name: "Compact"})
 }
 }

+ 20 - 0
raft/node.go

@@ -22,6 +22,13 @@ import (
 	pb "github.com/coreos/etcd/raft/raftpb"
 	pb "github.com/coreos/etcd/raft/raftpb"
 )
 )
 
 
+type SnapshotStatus int
+
+const (
+	SnapshotFinish  SnapshotStatus = 1
+	SnapshotFailure SnapshotStatus = 2
+)
+
 var (
 var (
 	emptyState = pb.HardState{}
 	emptyState = pb.HardState{}
 
 
@@ -68,6 +75,8 @@ type Ready struct {
 
 
 	// Messages specifies outbound messages to be sent AFTER Entries are
 	// Messages specifies outbound messages to be sent AFTER Entries are
 	// committed to stable storage.
 	// committed to stable storage.
+	// If it contains a MsgSnap message, the application MUST report back to raft
+	// when the snapshot has been received or has failed by calling ReportSnapshot.
 	Messages []pb.Message
 	Messages []pb.Message
 }
 }
 
 
@@ -121,6 +130,8 @@ type Node interface {
 	Status() Status
 	Status() Status
 	// Report reports the given node is not reachable for the last send.
 	// Report reports the given node is not reachable for the last send.
 	ReportUnreachable(id uint64)
 	ReportUnreachable(id uint64)
+	// ReportSnapshot reports the status of the sent snapshot.
+	ReportSnapshot(id uint64, status SnapshotStatus)
 	// Stop performs any necessary termination of the Node
 	// Stop performs any necessary termination of the Node
 	Stop()
 	Stop()
 }
 }
@@ -427,6 +438,15 @@ func (n *node) ReportUnreachable(id uint64) {
 	}
 	}
 }
 }
 
 
+func (n *node) ReportSnapshot(id uint64, status SnapshotStatus) {
+	rej := status == SnapshotFailure
+
+	select {
+	case n.recvc <- pb.Message{Type: pb.MsgSnapStatus, From: id, Reject: rej}:
+	case <-n.done:
+	}
+}
+
 func newReady(r *raft, prevSoftSt *SoftState, prevHardSt pb.HardState) Ready {
 func newReady(r *raft, prevSoftSt *SoftState, prevHardSt pb.HardState) Ready {
 	rd := Ready{
 	rd := Ready{
 		Entries:          r.raftLog.unstableEntries(),
 		Entries:          r.raftLog.unstableEntries(),

+ 1 - 1
raft/node_test.go

@@ -42,7 +42,7 @@ func TestNodeStep(t *testing.T) {
 				t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn)
 				t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn)
 			}
 			}
 		} else {
 		} else {
-			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup {
+			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup || msgt == raftpb.MsgUnreachable || msgt == raftpb.MsgSnapStatus {
 				select {
 				select {
 				case <-n.recvc:
 				case <-n.recvc:
 					t.Errorf("%d: step should ignore %s", msgt, msgn)
 					t.Errorf("%d: step should ignore %s", msgt, msgn)

+ 56 - 2
raft/raft.go

@@ -60,6 +60,15 @@ type Progress struct {
 	// Unreachable will be unset if raft starts to receive message (msgAppResp,
 	// Unreachable will be unset if raft starts to receive message (msgAppResp,
 	// msgHeartbeatResp) from the remote peer of the Progress.
 	// msgHeartbeatResp) from the remote peer of the Progress.
 	Unreachable bool
 	Unreachable bool
+	// If there is a pending snapshot, PendingSnapshot is set to the index of
+	// that snapshot. While PendingSnapshot is set, replication to this Progress
+	// is paused. raft will not resend the snapshot until the pending one is
+	// reported to have failed.
+	//
+	// PendingSnapshot is set when raft sends out a snapshot to this Progress.
+	// PendingSnapshot is unset when the snapshot is reported to have succeeded
+	// or failed, or when raft updates an equal or higher Match for this Progress.
+	PendingSnapshot uint64
 }
 }
 
 
 func (pr *Progress) update(n uint64) {
 func (pr *Progress) update(n uint64) {
@@ -114,6 +123,33 @@ func (pr *Progress) reachable()       { pr.Unreachable = false }
 func (pr *Progress) unreachable()     { pr.Unreachable = true }
 func (pr *Progress) unreachable()     { pr.Unreachable = true }
 func (pr *Progress) shouldWait() bool { return (pr.Unreachable || pr.Match == 0) && pr.Wait > 0 }
 func (pr *Progress) shouldWait() bool { return (pr.Unreachable || pr.Match == 0) && pr.Wait > 0 }
 
 
+func (pr *Progress) hasPendingSnapshot() bool    { return pr.PendingSnapshot != 0 }
+func (pr *Progress) setPendingSnapshot(i uint64) { pr.PendingSnapshot = i }
+
+// snapshotFinish unsets the pending snapshot and optimistically increases Next to
+// the index of the pending snapshot + 1. The next replication message is expected
+// to be msgApp.
+func (pr *Progress) snapshotFinish() {
+	pr.Next = pr.PendingSnapshot + 1
+	pr.PendingSnapshot = 0
+}
+
+// snapshotFail unsets the pending snapshot. The next replication message is expected
+// to be another msgSnap.
+func (pr *Progress) snapshotFail() {
+	pr.PendingSnapshot = 0
+}
+
+// maybeSnapshotAbort unsets PendingSnapshot if Match is equal to or higher than
+// the index of the pending snapshot. It reports whether the snapshot was aborted.
+func (pr *Progress) maybeSnapshotAbort() bool {
+	if pr.hasPendingSnapshot() && pr.Match >= pr.PendingSnapshot {
+		pr.PendingSnapshot = 0
+		return true
+	}
+	return false
+}
+
 func (pr *Progress) String() string {
 func (pr *Progress) String() string {
 	return fmt.Sprintf("next = %d, match = %d, wait = %v", pr.Next, pr.Match, pr.Wait)
 	return fmt.Sprintf("next = %d, match = %d, wait = %v", pr.Next, pr.Match, pr.Wait)
 }
 }
@@ -227,7 +263,7 @@ func (r *raft) send(m pb.Message) {
 // sendAppend sends RRPC, with entries to the given peer.
 // sendAppend sends RRPC, with entries to the given peer.
 func (r *raft) sendAppend(to uint64) {
 func (r *raft) sendAppend(to uint64) {
 	pr := r.prs[to]
 	pr := r.prs[to]
-	if pr.shouldWait() {
+	if pr.shouldWait() || pr.hasPendingSnapshot() {
 		return
 		return
 	}
 	}
 	m := pb.Message{}
 	m := pb.Message{}
@@ -251,7 +287,8 @@ func (r *raft) sendAppend(to uint64) {
 		sindex, sterm := snapshot.Metadata.Index, snapshot.Metadata.Term
 		sindex, sterm := snapshot.Metadata.Index, snapshot.Metadata.Term
 		log.Printf("raft: %x [firstindex: %d, commit: %d] sent snapshot[index: %d, term: %d] to %x [%s]",
 		log.Printf("raft: %x [firstindex: %d, commit: %d] sent snapshot[index: %d, term: %d] to %x [%s]",
 			r.id, r.raftLog.firstIndex(), r.Commit, sindex, sterm, to, pr)
 			r.id, r.raftLog.firstIndex(), r.Commit, sindex, sterm, to, pr)
-		pr.waitSet(r.electionTimeout)
+		pr.setPendingSnapshot(sindex)
+		log.Printf("raft: %x paused sending replication messages to %x [%s]", r.id, to, pr)
 	} else {
 	} else {
 		m.Type = pb.MsgApp
 		m.Type = pb.MsgApp
 		m.Index = pr.Next - 1
 		m.Index = pr.Next - 1
@@ -509,6 +546,9 @@ func stepLeader(r *raft, m pb.Message) {
 		} else {
 		} else {
 			oldWait := pr.shouldWait()
 			oldWait := pr.shouldWait()
 			pr.update(m.Index)
 			pr.update(m.Index)
+			if r.prs[m.From].maybeSnapshotAbort() {
+				log.Printf("raft: %x snapshot aborted, resumed sending replication messages to %x [%s]", r.id, m.From, pr)
+			}
 			if r.maybeCommit() {
 			if r.maybeCommit() {
 				r.bcastAppend()
 				r.bcastAppend()
 			} else if oldWait {
 			} else if oldWait {
@@ -526,6 +566,20 @@ func stepLeader(r *raft, m pb.Message) {
 		log.Printf("raft: %x [logterm: %d, index: %d, vote: %x] rejected vote from %x [logterm: %d, index: %d] at term %d",
 		log.Printf("raft: %x [logterm: %d, index: %d, vote: %x] rejected vote from %x [logterm: %d, index: %d] at term %d",
 			r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), r.Vote, m.From, m.LogTerm, m.Index, r.Term)
 			r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), r.Vote, m.From, m.LogTerm, m.Index, r.Term)
 		r.send(pb.Message{To: m.From, Type: pb.MsgVoteResp, Reject: true})
 		r.send(pb.Message{To: m.From, Type: pb.MsgVoteResp, Reject: true})
+	case pb.MsgSnapStatus:
+		if !pr.hasPendingSnapshot() {
+			return
+		}
+		if m.Reject {
+			pr.snapshotFail()
+			log.Printf("raft: %x snapshot failed, resumed sending replication messages to %x [%s]", r.id, m.From, pr)
+		} else {
+			pr.snapshotFinish()
+			log.Printf("raft: %x snapshot succeeded resumed sending replication messages to %x [%s]", r.id, m.From, pr)
+			// wait for the msgAppResp from the remote node before sending
+			// out the next msgApp
+			pr.waitSet(r.electionTimeout)
+		}
 	case pb.MsgUnreachable:
 	case pb.MsgUnreachable:
 		r.prs[m.From].unreachable()
 		r.prs[m.From].unreachable()
 	}
 	}

+ 128 - 0
raft/raft_snap_test.go

@@ -0,0 +1,128 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raft
+
+import (
+	"testing"
+
+	pb "github.com/coreos/etcd/raft/raftpb"
+)
+
+var (
+	testingSnap = pb.Snapshot{
+		Metadata: pb.SnapshotMetadata{
+			Index:     11, // magic number
+			Term:      11, // magic number
+			ConfState: pb.ConfState{Nodes: []uint64{1, 2}},
+		},
+	}
+)
+
+func TestSendingSnapshotSetPendingSnapshot(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	// force set the next of node 2, so that
+	// node 2 needs a snapshot
+	sm.prs[2].Next = sm.raftLog.firstIndex()
+
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: sm.prs[2].Next - 1, Reject: true})
+	if sm.prs[2].PendingSnapshot != 11 {
+		t.Fatalf("PendingSnapshot = %d, want 11", sm.prs[2].PendingSnapshot)
+	}
+}
+
+func TestPendingSnapshotPauseReplication(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].setPendingSnapshot(11)
+
+	sm.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+	msgs := sm.readMessages()
+	if len(msgs) != 0 {
+		t.Fatalf("len(msgs) = %d, want 0", len(msgs))
+	}
+}
+
+func TestSnapshotFailure(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].Next = 1
+	sm.prs[2].setPendingSnapshot(11)
+
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgSnapStatus, Reject: true})
+	if sm.prs[2].PendingSnapshot != 0 {
+		t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot)
+	}
+	if sm.prs[2].Next != 1 {
+		t.Fatalf("Next = %d, want 1", sm.prs[2].Next)
+	}
+}
+
+func TestSnapshotSucceed(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].Next = 1
+	sm.prs[2].setPendingSnapshot(11)
+
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgSnapStatus, Reject: false})
+	if sm.prs[2].PendingSnapshot != 0 {
+		t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot)
+	}
+	if sm.prs[2].Next != 12 {
+		t.Fatalf("Next = %d, want 12", sm.prs[2].Next)
+	}
+}
+
+func TestSnapshotAbort(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].Next = 1
+	sm.prs[2].setPendingSnapshot(11)
+
+	// A successful msgAppResp that has a higher/equal index than the
+	// pending snapshot should abort the pending snapshot.
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: 11})
+	if sm.prs[2].PendingSnapshot != 0 {
+		t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot)
+	}
+	if sm.prs[2].Next != 12 {
+		t.Fatalf("Next = %d, want 12", sm.prs[2].Next)
+	}
+}

+ 3 - 0
raft/raftpb/raft.pb.go

@@ -80,6 +80,7 @@ const (
 	MsgHeartbeat     MessageType = 8
 	MsgHeartbeat     MessageType = 8
 	MsgHeartbeatResp MessageType = 9
 	MsgHeartbeatResp MessageType = 9
 	MsgUnreachable   MessageType = 10
 	MsgUnreachable   MessageType = 10
+	MsgSnapStatus    MessageType = 11
 )
 )
 
 
 var MessageType_name = map[int32]string{
 var MessageType_name = map[int32]string{
@@ -94,6 +95,7 @@ var MessageType_name = map[int32]string{
 	8:  "MsgHeartbeat",
 	8:  "MsgHeartbeat",
 	9:  "MsgHeartbeatResp",
 	9:  "MsgHeartbeatResp",
 	10: "MsgUnreachable",
 	10: "MsgUnreachable",
+	11: "MsgSnapStatus",
 }
 }
 var MessageType_value = map[string]int32{
 var MessageType_value = map[string]int32{
 	"MsgHup":           0,
 	"MsgHup":           0,
@@ -107,6 +109,7 @@ var MessageType_value = map[string]int32{
 	"MsgHeartbeat":     8,
 	"MsgHeartbeat":     8,
 	"MsgHeartbeatResp": 9,
 	"MsgHeartbeatResp": 9,
 	"MsgUnreachable":   10,
 	"MsgUnreachable":   10,
+	"MsgSnapStatus":    11,
 }
 }
 
 
 func (x MessageType) Enum() *MessageType {
 func (x MessageType) Enum() *MessageType {

+ 12 - 11
raft/raftpb/raft.proto

@@ -32,17 +32,18 @@ message Snapshot {
 }
 }
 
 
 enum MessageType {
 enum MessageType {
-	MsgHup           = 0;
-	MsgBeat          = 1;
-	MsgProp          = 2;
-	MsgApp           = 3;
-	MsgAppResp       = 4;
-	MsgVote          = 5;
-	MsgVoteResp      = 6;
-	MsgSnap          = 7;
-	MsgHeartbeat     = 8;
-	MsgHeartbeatResp = 9;
-	MsgUnreachable   = 10;
+	MsgHup             = 0;
+	MsgBeat            = 1;
+	MsgProp            = 2;
+	MsgApp             = 3;
+	MsgAppResp         = 4;
+	MsgVote            = 5;
+	MsgVoteResp        = 6;
+	MsgSnap            = 7;
+	MsgHeartbeat       = 8;
+	MsgHeartbeatResp   = 9;
+	MsgUnreachable     = 10;
+	MsgSnapStatus      = 11;
 }
 }
 
 
 message Message {
 message Message {

+ 3 - 1
raft/util.go

@@ -46,7 +46,9 @@ func max(a, b uint64) uint64 {
 	return b
 	return b
 }
 }
 
 
-func IsLocalMsg(m pb.Message) bool { return m.Type == pb.MsgHup || m.Type == pb.MsgBeat }
+func IsLocalMsg(m pb.Message) bool {
+	return m.Type == pb.MsgHup || m.Type == pb.MsgBeat || m.Type == pb.MsgUnreachable || m.Type == pb.MsgSnapStatus
+}
 
 
 func IsResponseMsg(m pb.Message) bool {
 func IsResponseMsg(m pb.Message) bool {
 	return m.Type == pb.MsgAppResp || m.Type == pb.MsgVoteResp || m.Type == pb.MsgHeartbeatResp || m.Type == pb.MsgUnreachable
 	return m.Type == pb.MsgAppResp || m.Type == pb.MsgVoteResp || m.Type == pb.MsgHeartbeatResp || m.Type == pb.MsgUnreachable