Browse Source

Merge pull request #2544 from xiang90/raft-inflight

raft: add flow control for progress
Xiang Li 10 years ago
parent
commit
a552722f03
5 changed files with 434 additions and 15 deletions
  1. 79 1
      raft/progress.go
  2. 175 0
      raft/progress_test.go
  3. 16 6
      raft/raft.go
  4. 155 0
      raft/raft_flow_control_test.go
  5. 9 8
      raft/raft_test.go

+ 79 - 1
raft/progress.go

@@ -55,12 +55,22 @@ type Progress struct {
 	// this Progress will be paused. raft will not resend snapshot until the pending one
 	// is reported to be failed.
 	PendingSnapshot uint64
+
+	// inflights is a sliding window for the inflight messages.
+	// When inflights is full, no more message should be sent.
+	// When sends out a message, the index of the last entry should
+	// be add to inflights. The index MUST be added into inflights
+	// in order.
+	// When receives a reply, the previous inflights should be freed
+	// by calling inflights.freeTo.
+	ins *inflights
 }
 
 func (pr *Progress) resetState(state ProgressStateType) {
 	pr.Paused = false
 	pr.PendingSnapshot = 0
 	pr.State = state
+	pr.ins.reset()
 }
 
 func (pr *Progress) becomeProbe() {
@@ -135,7 +145,16 @@ func (pr *Progress) resume() { pr.Paused = false }
 
 // isPaused returns whether progress stops sending message.
 func (pr *Progress) isPaused() bool {
-	return pr.State == ProgressStateProbe && pr.Paused || pr.State == ProgressStateSnapshot
+	switch pr.State {
+	case ProgressStateProbe:
+		return pr.Paused
+	case ProgressStateReplicate:
+		return pr.ins.full()
+	case ProgressStateSnapshot:
+		return true
+	default:
+		panic("unexpected state")
+	}
 }
 
 func (pr *Progress) snapshotFailure() { pr.PendingSnapshot = 0 }
@@ -149,3 +168,62 @@ func (pr *Progress) maybeSnapshotAbort() bool {
 func (pr *Progress) String() string {
 	return fmt.Sprintf("next = %d, match = %d, state = %s, waiting = %v, pendingSnapshot = %d", pr.Next, pr.Match, pr.State, pr.isPaused(), pr.PendingSnapshot)
 }
+
+type inflights struct {
+	// the starting index in the buffer
+	start int
+	// number of inflights in the buffer
+	count int
+
+	// the size of the buffer
+	size   int
+	buffer []uint64
+}
+
+func newInflights(size int) *inflights {
+	return &inflights{
+		size:   size,
+		buffer: make([]uint64, size),
+	}
+}
+
+// add adds an inflight into inflights
+func (in *inflights) add(inflight uint64) {
+	if in.full() {
+		panic("cannot add into a full inflights")
+	}
+	next := in.start + in.count
+	if next >= in.size {
+		next -= in.size
+	}
+	in.buffer[next] = inflight
+	in.count++
+}
+
+// freeTo frees the inflights smaller or equal to the given `to` flight.
+func (in *inflights) freeTo(to uint64) {
+	for i := in.start; i < in.start+in.count; i++ {
+		idx := i
+		if i >= in.size {
+			idx -= in.size
+		}
+		if to < in.buffer[idx] {
+			in.count -= i - in.start
+			in.start = idx
+			break
+		}
+	}
+}
+
+func (in *inflights) freeFirstOne() { in.freeTo(in.buffer[in.start]) }
+
+// full returns true if the inflights is full.
+func (in *inflights) full() bool {
+	return in.count == in.size
+}
+
+// resets frees all inflights.
+func (in *inflights) reset() {
+	in.count = 0
+	in.start = 0
+}

+ 175 - 0
raft/progress_test.go

@@ -0,0 +1,175 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raft
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestInflightsAdd(t *testing.T) {
+	// no rotating case
+	in := &inflights{
+		size:   10,
+		buffer: make([]uint64, 10),
+	}
+
+	for i := 0; i < 5; i++ {
+		in.add(uint64(i))
+	}
+
+	wantIn := &inflights{
+		start: 0,
+		count: 5,
+		size:  10,
+		//               ↓------------
+		buffer: []uint64{0, 1, 2, 3, 4, 0, 0, 0, 0, 0},
+	}
+
+	if !reflect.DeepEqual(in, wantIn) {
+		t.Fatalf("in = %+v, want %+v", in, wantIn)
+	}
+
+	for i := 5; i < 10; i++ {
+		in.add(uint64(i))
+	}
+
+	wantIn2 := &inflights{
+		start: 0,
+		count: 10,
+		size:  10,
+		//               ↓---------------------------
+		buffer: []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+
+	if !reflect.DeepEqual(in, wantIn2) {
+		t.Fatalf("in = %+v, want %+v", in, wantIn2)
+	}
+
+	// rotating case
+	in2 := &inflights{
+		start:  5,
+		size:   10,
+		buffer: make([]uint64, 10),
+	}
+
+	for i := 0; i < 5; i++ {
+		in2.add(uint64(i))
+	}
+
+	wantIn21 := &inflights{
+		start: 5,
+		count: 5,
+		size:  10,
+		//                              ↓------------
+		buffer: []uint64{0, 0, 0, 0, 0, 0, 1, 2, 3, 4},
+	}
+
+	if !reflect.DeepEqual(in2, wantIn21) {
+		t.Fatalf("in = %+v, want %+v", in2, wantIn21)
+	}
+
+	for i := 5; i < 10; i++ {
+		in2.add(uint64(i))
+	}
+
+	wantIn22 := &inflights{
+		start: 5,
+		count: 10,
+		size:  10,
+		//               -------------- ↓------------
+		buffer: []uint64{5, 6, 7, 8, 9, 0, 1, 2, 3, 4},
+	}
+
+	if !reflect.DeepEqual(in2, wantIn22) {
+		t.Fatalf("in = %+v, want %+v", in2, wantIn22)
+	}
+}
+
+func TestInflightFreeTo(t *testing.T) {
+	// no rotating case
+	in := newInflights(10)
+	for i := 0; i < 10; i++ {
+		in.add(uint64(i))
+	}
+
+	in.freeTo(4)
+
+	wantIn := &inflights{
+		start: 5,
+		count: 5,
+		size:  10,
+		//                              ↓------------
+		buffer: []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+
+	if !reflect.DeepEqual(in, wantIn) {
+		t.Fatalf("in = %+v, want %+v", in, wantIn)
+	}
+
+	in.freeTo(8)
+
+	wantIn2 := &inflights{
+		start: 9,
+		count: 1,
+		size:  10,
+		//                                          ↓
+		buffer: []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+
+	if !reflect.DeepEqual(in, wantIn2) {
+		t.Fatalf("in = %+v, want %+v", in, wantIn2)
+	}
+
+	// rotating case
+	for i := 10; i < 15; i++ {
+		in.add(uint64(i))
+	}
+
+	in.freeTo(12)
+
+	wantIn3 := &inflights{
+		start: 3,
+		count: 2,
+		size:  10,
+		//                           ↓-----
+		buffer: []uint64{10, 11, 12, 13, 14, 5, 6, 7, 8, 9},
+	}
+
+	if !reflect.DeepEqual(in, wantIn3) {
+		t.Fatalf("in = %+v, want %+v", in, wantIn3)
+	}
+}
+
+func TestInflightFreeFirstOne(t *testing.T) {
+	in := newInflights(10)
+	for i := 0; i < 10; i++ {
+		in.add(uint64(i))
+	}
+
+	in.freeFirstOne()
+
+	wantIn := &inflights{
+		start: 1,
+		count: 9,
+		size:  10,
+		//                  ↓------------------------
+		buffer: []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+
+	if !reflect.DeepEqual(in, wantIn) {
+		t.Fatalf("in = %+v, want %+v", in, wantIn)
+	}
+}

+ 16 - 6
raft/raft.go

@@ -59,8 +59,9 @@ type raft struct {
 	// the log
 	raftLog *raftLog
 
-	maxMsgSize uint64
-	prs        map[uint64]*Progress
+	maxInflight int
+	maxMsgSize  uint64
+	prs         map[uint64]*Progress
 
 	state StateType
 
@@ -109,13 +110,14 @@ func newRaft(id uint64, peers []uint64, election, heartbeat int, storage Storage
 		// TODO(xiang): add a config arguement into newRaft after we add
 		// the max inflight message field.
 		maxMsgSize:       4 * 1024 * 1024,
+		maxInflight:      256,
 		prs:              make(map[uint64]*Progress),
 		electionTimeout:  election,
 		heartbeatTimeout: heartbeat,
 	}
 	r.rand = rand.New(rand.NewSource(int64(id)))
 	for _, p := range peers {
-		r.prs[p] = &Progress{Next: 1}
+		r.prs[p] = &Progress{Next: 1, ins: newInflights(r.maxInflight)}
 	}
 	if !isHardStateEqual(hs, emptyState) {
 		r.loadState(hs)
@@ -195,7 +197,9 @@ func (r *raft) sendAppend(to uint64) {
 			switch pr.State {
 			// optimistically increase the next when in ProgressStateReplicate
 			case ProgressStateReplicate:
-				pr.optimisticUpdate(m.Entries[n-1].Index)
+				last := m.Entries[n-1].Index
+				pr.optimisticUpdate(last)
+				pr.ins.add(last)
 			case ProgressStateProbe:
 				pr.pause()
 			default:
@@ -265,7 +269,7 @@ func (r *raft) reset(term uint64) {
 	r.elapsed = 0
 	r.votes = make(map[uint64]bool)
 	for i := range r.prs {
-		r.prs[i] = &Progress{Next: r.raftLog.lastIndex() + 1}
+		r.prs[i] = &Progress{Next: r.raftLog.lastIndex() + 1, ins: newInflights(r.maxInflight)}
 		if i == r.id {
 			r.prs[i].Match = r.raftLog.lastIndex()
 		}
@@ -456,6 +460,8 @@ func stepLeader(r *raft, m pb.Message) {
 				case pr.State == ProgressStateSnapshot && pr.maybeSnapshotAbort():
 					raftLogger.Infof("raft: %x snapshot aborted, resumed sending replication messages to %x [%s]", r.id, m.From, pr)
 					pr.becomeProbe()
+				case pr.State == ProgressStateReplicate:
+					pr.ins.freeTo(m.Index)
 				}
 
 				if r.maybeCommit() {
@@ -468,6 +474,10 @@ func stepLeader(r *raft, m pb.Message) {
 			}
 		}
 	case pb.MsgHeartbeatResp:
+		// free one slot for the full inflights window to allow progress.
+		if pr.State == ProgressStateReplicate && pr.ins.full() {
+			pr.ins.freeFirstOne()
+		}
 		if pr.Match < r.raftLog.lastIndex() {
 			r.sendAppend(m.From)
 		}
@@ -661,7 +671,7 @@ func (r *raft) removeNode(id uint64) {
 func (r *raft) resetPendingConf() { r.pendingConf = false }
 
 func (r *raft) setProgress(id, match, next uint64) {
-	r.prs[id] = &Progress{Next: next, Match: match}
+	r.prs[id] = &Progress{Next: next, Match: match, ins: newInflights(r.maxInflight)}
 }
 
 func (r *raft) delProgress(id uint64) {

+ 155 - 0
raft/raft_flow_control_test.go

@@ -0,0 +1,155 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raft
+
+import (
+	"testing"
+
+	pb "github.com/coreos/etcd/raft/raftpb"
+)
+
+// TestMsgAppFlowControlFull ensures:
+// 1. msgApp can fill the sending window until full
+// 2. when the window is full, no more msgApp can be sent.
+func TestMsgAppFlowControlFull(t *testing.T) {
+	r := newRaft(1, []uint64{1, 2}, 5, 1, NewMemoryStorage(), 0)
+	r.becomeCandidate()
+	r.becomeLeader()
+
+	pr2 := r.prs[2]
+	// force the progress to be in replicate state
+	pr2.becomeReplicate()
+	// fill in the inflights window
+	for i := 0; i < r.maxInflight; i++ {
+		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+		ms := r.readMessages()
+		if len(ms) != 1 {
+			t.Fatalf("#%d: len(ms) = %d, want 1", i, len(ms))
+		}
+	}
+
+	// ensure 1
+	if !pr2.ins.full() {
+		t.Fatalf("inflights.full = %t, want %t", pr2.ins.full(), true)
+	}
+
+	// ensure 2
+	for i := 0; i < 10; i++ {
+		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+		ms := r.readMessages()
+		if len(ms) != 0 {
+			t.Fatalf("#%d: len(ms) = %d, want 0", i, len(ms))
+		}
+	}
+}
+
+// TestMsgAppFlowControlMoveForward ensures msgAppResp can move
+// forward the sending window correctly:
+// 1. vaild msgAppResp.index moves the windows to pass all smaller or equal index.
+// 2. out-of-dated msgAppResp has no effect on the silding window.
+func TestMsgAppFlowControlMoveForward(t *testing.T) {
+	r := newRaft(1, []uint64{1, 2}, 5, 1, NewMemoryStorage(), 0)
+	r.becomeCandidate()
+	r.becomeLeader()
+
+	pr2 := r.prs[2]
+	// force the progress to be in replicate state
+	pr2.becomeReplicate()
+	// fill in the inflights window
+	for i := 0; i < r.maxInflight; i++ {
+		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+		r.readMessages()
+	}
+
+	// 1 is noop, 2 is the first proposal we just sent.
+	// so we start with 2.
+	for tt := 2; tt < r.maxInflight; tt++ {
+		// move forward the window
+		r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: uint64(tt)})
+		r.readMessages()
+
+		// fill in the inflights window again
+		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+		ms := r.readMessages()
+		if len(ms) != 1 {
+			t.Fatalf("#%d: len(ms) = %d, want 1", tt, len(ms))
+		}
+
+		// ensure 1
+		if !pr2.ins.full() {
+			t.Fatalf("inflights.full = %t, want %t", pr2.ins.full(), true)
+		}
+
+		// ensure 2
+		for i := 0; i < tt; i++ {
+			r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: uint64(i)})
+			if !pr2.ins.full() {
+				t.Fatalf("#%d: inflights.full = %t, want %t", tt, pr2.ins.full(), true)
+			}
+		}
+	}
+}
+
+// TestMsgAppFlowControlRecvHeartbeat ensures a heartbeat response
+// frees one slot if the window is full.
+func TestMsgAppFlowControlRecvHeartbeat(t *testing.T) {
+	r := newRaft(1, []uint64{1, 2}, 5, 1, NewMemoryStorage(), 0)
+	r.becomeCandidate()
+	r.becomeLeader()
+
+	pr2 := r.prs[2]
+	// force the progress to be in replicate state
+	pr2.becomeReplicate()
+	// fill in the inflights window
+	for i := 0; i < r.maxInflight; i++ {
+		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+		r.readMessages()
+	}
+
+	for tt := 1; tt < 5; tt++ {
+		if !pr2.ins.full() {
+			t.Fatalf("#%d: inflights.full = %t, want %t", tt, pr2.ins.full(), true)
+		}
+
+		// recv tt msgHeartbeatResp and expect one free slot
+		for i := 0; i < tt; i++ {
+			r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgHeartbeatResp})
+			r.readMessages()
+			if pr2.ins.full() {
+				t.Fatalf("#%d.%d: inflights.full = %t, want %t", tt, i, pr2.ins.full(), false)
+			}
+		}
+
+		// one slot
+		r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+		ms := r.readMessages()
+		if len(ms) != 1 {
+			t.Fatalf("#%d: free slot = 0, want 1", tt)
+		}
+
+		// and just one slot
+		for i := 0; i < 10; i++ {
+			r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
+			ms1 := r.readMessages()
+			if len(ms1) != 0 {
+				t.Fatalf("#%d.%d: len(ms) = %d, want 0", tt, i, len(ms1))
+			}
+		}
+
+		// clear all pending messages.
+		r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgHeartbeatResp})
+		r.readMessages()
+	}
+}

+ 9 - 8
raft/raft_test.go

@@ -48,24 +48,24 @@ func (r *raft) readMessages() []pb.Message {
 	return msgs
 }
 
-func TestBecomeProbe(t *testing.T) {
+func TestProgressBecomeProbe(t *testing.T) {
 	match := uint64(1)
 	tests := []struct {
 		p     *Progress
 		wnext uint64
 	}{
 		{
-			&Progress{State: ProgressStateReplicate, Match: match, Next: 5},
+			&Progress{State: ProgressStateReplicate, Match: match, Next: 5, ins: newInflights(256)},
 			2,
 		},
 		{
 			// snapshot finish
-			&Progress{State: ProgressStateSnapshot, Match: match, Next: 5, PendingSnapshot: 10},
+			&Progress{State: ProgressStateSnapshot, Match: match, Next: 5, PendingSnapshot: 10, ins: newInflights(256)},
 			11,
 		},
 		{
 			// snapshot failure
-			&Progress{State: ProgressStateSnapshot, Match: match, Next: 5, PendingSnapshot: 0},
+			&Progress{State: ProgressStateSnapshot, Match: match, Next: 5, PendingSnapshot: 0, ins: newInflights(256)},
 			2,
 		},
 	}
@@ -83,8 +83,8 @@ func TestBecomeProbe(t *testing.T) {
 	}
 }
 
-func TestBecomeReplicate(t *testing.T) {
-	p := &Progress{State: ProgressStateProbe, Match: 1, Next: 5}
+func TestProgressBecomeReplicate(t *testing.T) {
+	p := &Progress{State: ProgressStateProbe, Match: 1, Next: 5, ins: newInflights(256)}
 	p.becomeReplicate()
 
 	if p.State != ProgressStateReplicate {
@@ -98,8 +98,8 @@ func TestBecomeReplicate(t *testing.T) {
 	}
 }
 
-func TestBecomeSnapshot(t *testing.T) {
-	p := &Progress{State: ProgressStateProbe, Match: 1, Next: 5}
+func TestProgressBecomeSnapshot(t *testing.T) {
+	p := &Progress{State: ProgressStateProbe, Match: 1, Next: 5, ins: newInflights(256)}
 	p.becomeSnapshot(10)
 
 	if p.State != ProgressStateSnapshot {
@@ -234,6 +234,7 @@ func TestProgressIsPaused(t *testing.T) {
 		p := &Progress{
 			State:  tt.state,
 			Paused: tt.paused,
+			ins:    newInflights(256),
 		}
 		if g := p.isPaused(); g != tt.w {
 			t.Errorf("#%d: shouldwait = %t, want %t", i, g, tt.w)