Browse Source

raft: do not resend snapshot if not necessary

raft relies on the link layer to report the status of the sent snapshot.
While the snapshot is still being sent, replication to that remote peer is
paused. Once the snapshot finishes sending, replication resumes
optimistically after electionTimeout. If the snapshot fails, raft will
try to resend it.
Xiang Li 11 years ago
parent
commit
9b4d52ee73
8 changed files with 225 additions and 15 deletions
  1. 2 0
      etcdserver/server_test.go
  2. 20 0
      raft/node.go
  3. 1 1
      raft/node_test.go
  4. 56 2
      raft/raft.go
  5. 128 0
      raft/raft_snap_test.go
  6. 3 0
      raft/raftpb/raft.pb.go
  7. 12 11
      raft/raftpb/raft.proto
  8. 3 1
      raft/util.go

+ 2 - 0
etcdserver/server_test.go

@@ -1305,6 +1305,8 @@ func (n *nodeRecorder) Stop() {
 
 func (n *nodeRecorder) ReportUnreachable(id uint64) {}
 
+func (n *nodeRecorder) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
+
 func (n *nodeRecorder) Compact(index uint64, nodes []uint64, d []byte) {
 	n.Record(testutil.Action{Name: "Compact"})
 }

+ 20 - 0
raft/node.go

@@ -22,6 +22,13 @@ import (
 	pb "github.com/coreos/etcd/raft/raftpb"
 )
 
+type SnapshotStatus int
+
+const (
+	SnapshotFinish  SnapshotStatus = 1
+	SnapshotFailure SnapshotStatus = 2
+)
+
 var (
 	emptyState = pb.HardState{}
 
@@ -68,6 +75,8 @@ type Ready struct {
 
 	// Messages specifies outbound messages to be sent AFTER Entries are
 	// committed to stable storage.
+	// If it contains a MsgSnap message, the application MUST report back to raft
+	// when the snapshot has been received or has failed by calling ReportSnapshot.
 	Messages []pb.Message
 }
 
@@ -121,6 +130,8 @@ type Node interface {
 	Status() Status
 	// ReportUnreachable reports that the given node was not reachable for the last send.
 	ReportUnreachable(id uint64)
+	// ReportSnapshot reports the status of the sent snapshot.
+	ReportSnapshot(id uint64, status SnapshotStatus)
 	// Stop performs any necessary termination of the Node
 	Stop()
 }
@@ -427,6 +438,15 @@ func (n *node) ReportUnreachable(id uint64) {
 	}
 }
 
+func (n *node) ReportSnapshot(id uint64, status SnapshotStatus) {
+	rej := status == SnapshotFailure
+
+	select {
+	case n.recvc <- pb.Message{Type: pb.MsgSnapStatus, From: id, Reject: rej}:
+	case <-n.done:
+	}
+}
+
 func newReady(r *raft, prevSoftSt *SoftState, prevHardSt pb.HardState) Ready {
 	rd := Ready{
 		Entries:          r.raftLog.unstableEntries(),

+ 1 - 1
raft/node_test.go

@@ -42,7 +42,7 @@ func TestNodeStep(t *testing.T) {
 				t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn)
 			}
 		} else {
-			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup {
+			if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup || msgt == raftpb.MsgUnreachable || msgt == raftpb.MsgSnapStatus {
 				select {
 				case <-n.recvc:
 					t.Errorf("%d: step should ignore %s", msgt, msgn)

+ 56 - 2
raft/raft.go

@@ -60,6 +60,15 @@ type Progress struct {
 	// Unreachable will be unset if raft starts to receive message (msgAppResp,
 	// msgHeartbeatResp) from the remote peer of the Progress.
 	Unreachable bool
+	// If there is a pending snapshot, PendingSnapshot is set to the index of
+	// that snapshot. While PendingSnapshot is set, the replication process of
+	// this Progress is paused. raft will not resend the snapshot until the
+	// pending one is reported to have failed.
+	//
+	// PendingSnapshot is set when raft sends out a snapshot to this Progress.
+	// PendingSnapshot is unset when the snapshot is reported to have succeeded
+	// or failed, or when raft updates an equal or higher Match for this Progress.
+	PendingSnapshot uint64
 }
 
 func (pr *Progress) update(n uint64) {
@@ -114,6 +123,33 @@ func (pr *Progress) reachable()       { pr.Unreachable = false }
 func (pr *Progress) unreachable()     { pr.Unreachable = true }
 func (pr *Progress) shouldWait() bool { return (pr.Unreachable || pr.Match == 0) && pr.Wait > 0 }
 
+func (pr *Progress) hasPendingSnapshot() bool    { return pr.PendingSnapshot != 0 }
+func (pr *Progress) setPendingSnapshot(i uint64) { pr.PendingSnapshot = i }
+
+// snapshotFinish unsets the pending snapshot and optimistically increases Next
+// to the index of the pending snapshot + 1. The next replication message is
+// expected to be msgApp.
+func (pr *Progress) snapshotFinish() {
+	pr.Next = pr.PendingSnapshot + 1
+	pr.PendingSnapshot = 0
+}
+
+// snapshotFail unsets the pending snapshot. The next replication message is expected
+// to be another msgSnap.
+func (pr *Progress) snapshotFail() {
+	pr.PendingSnapshot = 0
+}
+
+// maybeSnapshotAbort unsets PendingSnapshot if Match is equal to or higher
+// than PendingSnapshot. It reports whether the pending snapshot was aborted.
+func (pr *Progress) maybeSnapshotAbort() bool {
+	if pr.hasPendingSnapshot() && pr.Match >= pr.PendingSnapshot {
+		pr.PendingSnapshot = 0
+		return true
+	}
+	return false
+}
+
 func (pr *Progress) String() string {
 	return fmt.Sprintf("next = %d, match = %d, wait = %v", pr.Next, pr.Match, pr.Wait)
 }
@@ -227,7 +263,7 @@ func (r *raft) send(m pb.Message) {
 // sendAppend sends an RPC, with entries, to the given peer.
 func (r *raft) sendAppend(to uint64) {
 	pr := r.prs[to]
-	if pr.shouldWait() {
+	if pr.shouldWait() || pr.hasPendingSnapshot() {
 		return
 	}
 	m := pb.Message{}
@@ -251,7 +287,8 @@ func (r *raft) sendAppend(to uint64) {
 		sindex, sterm := snapshot.Metadata.Index, snapshot.Metadata.Term
 		log.Printf("raft: %x [firstindex: %d, commit: %d] sent snapshot[index: %d, term: %d] to %x [%s]",
 			r.id, r.raftLog.firstIndex(), r.Commit, sindex, sterm, to, pr)
-		pr.waitSet(r.electionTimeout)
+		pr.setPendingSnapshot(sindex)
+		log.Printf("raft: %x paused sending replication messages to %x [%s]", r.id, to, pr)
 	} else {
 		m.Type = pb.MsgApp
 		m.Index = pr.Next - 1
@@ -509,6 +546,9 @@ func stepLeader(r *raft, m pb.Message) {
 		} else {
 			oldWait := pr.shouldWait()
 			pr.update(m.Index)
+			if r.prs[m.From].maybeSnapshotAbort() {
+				log.Printf("raft: %x snapshot aborted, resumed sending replication messages to %x [%s]", r.id, m.From, pr)
+			}
 			if r.maybeCommit() {
 				r.bcastAppend()
 			} else if oldWait {
@@ -526,6 +566,20 @@ func stepLeader(r *raft, m pb.Message) {
 		log.Printf("raft: %x [logterm: %d, index: %d, vote: %x] rejected vote from %x [logterm: %d, index: %d] at term %d",
 			r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), r.Vote, m.From, m.LogTerm, m.Index, r.Term)
 		r.send(pb.Message{To: m.From, Type: pb.MsgVoteResp, Reject: true})
+	case pb.MsgSnapStatus:
+		if !pr.hasPendingSnapshot() {
+			return
+		}
+		if m.Reject {
+			pr.snapshotFail()
+			log.Printf("raft: %x snapshot failed, resumed sending replication messages to %x [%s]", r.id, m.From, pr)
+		} else {
+			pr.snapshotFinish()
+			log.Printf("raft: %x snapshot succeeded resumed sending replication messages to %x [%s]", r.id, m.From, pr)
+			// wait for the msgAppResp from the remote node before sending
+			// out the next msgApp
+			pr.waitSet(r.electionTimeout)
+		}
 	case pb.MsgUnreachable:
 		r.prs[m.From].unreachable()
 	}

+ 128 - 0
raft/raft_snap_test.go

@@ -0,0 +1,128 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raft
+
+import (
+	"testing"
+
+	pb "github.com/coreos/etcd/raft/raftpb"
+)
+
+var (
+	testingSnap = pb.Snapshot{
+		Metadata: pb.SnapshotMetadata{
+			Index:     11, // magic number
+			Term:      11, // magic number
+			ConfState: pb.ConfState{Nodes: []uint64{1, 2}},
+		},
+	}
+)
+
+func TestSendingSnapshotSetPendingSnapshot(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	// force set the next of node 2, so that
+	// node 2 needs a snapshot
+	sm.prs[2].Next = sm.raftLog.firstIndex()
+
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: sm.prs[2].Next - 1, Reject: true})
+	if sm.prs[2].PendingSnapshot != 11 {
+		t.Fatalf("PendingSnapshot = %d, want 11", sm.prs[2].PendingSnapshot)
+	}
+}
+
+func TestPendingSnapshotPauseReplication(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].setPendingSnapshot(11)
+
+	sm.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+	msgs := sm.readMessages()
+	if len(msgs) != 0 {
+		t.Fatalf("len(msgs) = %d, want 0", len(msgs))
+	}
+}
+
+func TestSnapshotFailure(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].Next = 1
+	sm.prs[2].setPendingSnapshot(11)
+
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgSnapStatus, Reject: true})
+	if sm.prs[2].PendingSnapshot != 0 {
+		t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot)
+	}
+	if sm.prs[2].Next != 1 {
+		t.Fatalf("Next = %d, want 1", sm.prs[2].Next)
+	}
+}
+
+func TestSnapshotSucceed(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].Next = 1
+	sm.prs[2].setPendingSnapshot(11)
+
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgSnapStatus, Reject: false})
+	if sm.prs[2].PendingSnapshot != 0 {
+		t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot)
+	}
+	if sm.prs[2].Next != 12 {
+		t.Fatalf("Next = %d, want 12", sm.prs[2].Next)
+	}
+}
+
+func TestSnapshotAbort(t *testing.T) {
+	storage := NewMemoryStorage()
+	sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0)
+	sm.restore(testingSnap)
+
+	sm.becomeCandidate()
+	sm.becomeLeader()
+
+	sm.prs[2].Next = 1
+	sm.prs[2].setPendingSnapshot(11)
+
+	// A successful msgAppResp that has a higher/equal index than the
+	// pending snapshot should abort the pending snapshot.
+	sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: 11})
+	if sm.prs[2].PendingSnapshot != 0 {
+		t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot)
+	}
+	if sm.prs[2].Next != 12 {
+		t.Fatalf("Next = %d, want 12", sm.prs[2].Next)
+	}
+}

+ 3 - 0
raft/raftpb/raft.pb.go

@@ -80,6 +80,7 @@ const (
 	MsgHeartbeat     MessageType = 8
 	MsgHeartbeatResp MessageType = 9
 	MsgUnreachable   MessageType = 10
+	MsgSnapStatus    MessageType = 11
 )
 
 var MessageType_name = map[int32]string{
@@ -94,6 +95,7 @@ var MessageType_name = map[int32]string{
 	8:  "MsgHeartbeat",
 	9:  "MsgHeartbeatResp",
 	10: "MsgUnreachable",
+	11: "MsgSnapStatus",
 }
 var MessageType_value = map[string]int32{
 	"MsgHup":           0,
@@ -107,6 +109,7 @@ var MessageType_value = map[string]int32{
 	"MsgHeartbeat":     8,
 	"MsgHeartbeatResp": 9,
 	"MsgUnreachable":   10,
+	"MsgSnapStatus":    11,
 }
 
 func (x MessageType) Enum() *MessageType {

+ 12 - 11
raft/raftpb/raft.proto

@@ -32,17 +32,18 @@ message Snapshot {
 }
 
 enum MessageType {
-	MsgHup           = 0;
-	MsgBeat          = 1;
-	MsgProp          = 2;
-	MsgApp           = 3;
-	MsgAppResp       = 4;
-	MsgVote          = 5;
-	MsgVoteResp      = 6;
-	MsgSnap          = 7;
-	MsgHeartbeat     = 8;
-	MsgHeartbeatResp = 9;
-	MsgUnreachable   = 10;
+	MsgHup             = 0;
+	MsgBeat            = 1;
+	MsgProp            = 2;
+	MsgApp             = 3;
+	MsgAppResp         = 4;
+	MsgVote            = 5;
+	MsgVoteResp        = 6;
+	MsgSnap            = 7;
+	MsgHeartbeat       = 8;
+	MsgHeartbeatResp   = 9;
+	MsgUnreachable     = 10;
+	MsgSnapStatus      = 11;
 }
 
 message Message {

+ 3 - 1
raft/util.go

@@ -46,7 +46,9 @@ func max(a, b uint64) uint64 {
 	return b
 }
 
-func IsLocalMsg(m pb.Message) bool { return m.Type == pb.MsgHup || m.Type == pb.MsgBeat }
+func IsLocalMsg(m pb.Message) bool {
+	return m.Type == pb.MsgHup || m.Type == pb.MsgBeat || m.Type == pb.MsgUnreachable || m.Type == pb.MsgSnapStatus
+}
 
 func IsResponseMsg(m pb.Message) bool {
 	return m.Type == pb.MsgAppResp || m.Type == pb.MsgVoteResp || m.Type == pb.MsgHeartbeatResp || m.Type == pb.MsgUnreachable