Browse Source

raft: make Propose in the raft node wait for the proposal result so we can fail fast when a proposal is dropped.

Vincent Lee 8 years ago
parent
commit
f0dffb4163
2 changed files with 101 additions and 16 deletions
  1. 55 14
      raft/node.go
  2. 46 2
      raft/node_test.go

+ 55 - 14
raft/node.go

@@ -224,9 +224,14 @@ func RestartNode(c *Config) Node {
 	return &n
 }
 
+// msgWithResult pairs a message with an optional channel used to report the
+// outcome of stepping it. When result is non-nil, node.run sends the error
+// returned by Step on it and then closes it.
+type msgWithResult struct {
+	m      pb.Message
+	result chan error
+}
+
 // node is the canonical implementation of the Node interface
 type node struct {
-	propc      chan pb.Message
+	propc      chan msgWithResult
 	recvc      chan pb.Message
 	confc      chan pb.ConfChange
 	confstatec chan pb.ConfState
@@ -242,7 +247,7 @@ type node struct {
 
 func newNode() node {
 	return node{
-		propc:      make(chan pb.Message),
+		propc:      make(chan msgWithResult),
 		recvc:      make(chan pb.Message),
 		confc:      make(chan pb.ConfChange),
 		confstatec: make(chan pb.ConfState),
@@ -271,7 +276,7 @@ func (n *node) Stop() {
 }
 
 func (n *node) run(r *raft) {
-	var propc chan pb.Message
+	var propc chan msgWithResult
 	var readyc chan Ready
 	var advancec chan struct{}
 	var prevLastUnstablei, prevLastUnstablet uint64
@@ -314,13 +319,18 @@ func (n *node) run(r *raft) {
 		// TODO: maybe buffer the config propose if there exists one (the way
 		// described in raft dissertation)
 		// Currently it is dropped in Step silently.
-		case m := <-propc:
+		case pm := <-propc:
+			m := pm.m
 			m.From = r.id
-			r.Step(m)
+			err := r.Step(m)
+			if pm.result != nil {
+				pm.result <- err
+				close(pm.result)
+			}
 		case m := <-n.recvc:
 			// filter out response message from unknown From.
 			if pr := r.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) {
-				r.Step(m) // raft never returns an error
+				r.Step(m)
 			}
 		case cc := <-n.confc:
 			if cc.NodeID == None {
@@ -408,7 +418,7 @@ func (n *node) Tick() {
 func (n *node) Campaign(ctx context.Context) error { return n.step(ctx, pb.Message{Type: pb.MsgHup}) }
 
 func (n *node) Propose(ctx context.Context, data []byte) error {
-	return n.step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}})
+	return n.stepWait(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}})
 }
 
 func (n *node) Step(ctx context.Context, m pb.Message) error {
@@ -428,22 +438,53 @@ func (n *node) ProposeConfChange(ctx context.Context, cc pb.ConfChange) error {
 	return n.Step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Type: pb.EntryConfChange, Data: data}}})
 }
 
+// step hands m to the node without waiting for the result of processing it.
+func (n *node) step(ctx context.Context, m pb.Message) error {
+	return n.stepWithWaitOption(ctx, m, false)
+}
+
+// stepWait hands m to the node and, when m is a proposal, blocks until the
+// proposal has been stepped, returning the resulting error.
+func (n *node) stepWait(ctx context.Context, m pb.Message) error {
+	return n.stepWithWaitOption(ctx, m, true)
+}
+
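 // stepWithWaitOption advances the state machine using m. If m is a proposal
 // and wait is true, it blocks until the proposal has been stepped and returns
 // the resulting error. ctx.Err() is returned if ctx is done first, and
 // ErrStopped is returned if n.done fires first.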
-func (n *node) step(ctx context.Context, m pb.Message) error {
-	ch := n.recvc
-	if m.Type == pb.MsgProp {
-		ch = n.propc
+func (n *node) stepWithWaitOption(ctx context.Context, m pb.Message, wait bool) error {
+	if m.Type != pb.MsgProp {
+		select {
+		case n.recvc <- m:
+			return nil
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-n.done:
+			return ErrStopped
+		}
+	}
+	ch := n.propc
+	pm := msgWithResult{m: m}
+	if wait {
+		pm.result = make(chan error, 1)
 	}
-
 	select {
-	case ch <- m:
-		return nil
+	case ch <- pm:
+		if !wait {
+			return nil
+		}
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-n.done:
+		return ErrStopped
+	}
+	select {
+	case rsp := <-pm.result:
+		if rsp != nil {
+			return rsp
+		}
 	case <-ctx.Done():
 		return ctx.Err()
 	case <-n.done:
 		return ErrStopped
 	}
+	return nil
 }
 
 func (n *node) Ready() <-chan Ready { return n.readyc }

+ 46 - 2
raft/node_test.go

@@ -18,6 +18,7 @@ import (
 	"bytes"
 	"context"
 	"reflect"
+	"strings"
 	"testing"
 	"time"
 
@@ -30,7 +31,7 @@ import (
 func TestNodeStep(t *testing.T) {
 	for i, msgn := range raftpb.MessageType_name {
 		n := &node{
-			propc: make(chan raftpb.Message, 1),
+			propc: make(chan msgWithResult, 1),
 			recvc: make(chan raftpb.Message, 1),
 		}
 		msgt := raftpb.MessageType(i)
@@ -64,7 +65,7 @@ func TestNodeStep(t *testing.T) {
 func TestNodeStepUnblock(t *testing.T) {
 	// a node without buffer to block step
 	n := &node{
-		propc: make(chan raftpb.Message),
+		propc: make(chan msgWithResult),
 		done:  make(chan struct{}),
 	}
 
@@ -433,6 +434,49 @@ func TestBlockProposal(t *testing.T) {
 	}
 }
 
+// TestNodeProposeWaitDropped ensures that Propose reports ErrProposalDropped
+// when the step function drops the proposal, instead of blocking until the
+// context times out.
+func TestNodeProposeWaitDropped(t *testing.T) {
+	var msgs []raftpb.Message
+	droppingMsg := []byte("test_dropping")
+	// dropStep rejects any proposal carrying droppingMsg and records every
+	// other message it is asked to step.
+	dropStep := func(r *raft, m raftpb.Message) error {
+		if m.Type == raftpb.MsgProp && strings.Contains(m.String(), string(droppingMsg)) {
+			t.Logf("dropping message: %v", m.String())
+			return ErrProposalDropped
+		}
+		msgs = append(msgs, m)
+		return nil
+	}
+
+	n := newNode()
+	s := NewMemoryStorage()
+	r := newTestRaft(1, []uint64{1}, 10, 1, s)
+	go n.run(r)
+	n.Campaign(context.TODO())
+	for {
+		rd := <-n.Ready()
+		s.Append(rd.Entries)
+		// switch the step function to dropStep once this raft becomes leader
+		if rd.SoftState.Lead == r.id {
+			r.step = dropStep
+			n.Advance()
+			break
+		}
+		n.Advance()
+	}
+	proposalTimeout := time.Millisecond * 100
+	ctx, cancel := context.WithTimeout(context.Background(), proposalTimeout)
+	// a proposal that is dropped should return early with ErrProposalDropped
+	// rather than waiting for the context deadline
+	err := n.Propose(ctx, droppingMsg)
+	if err != ErrProposalDropped {
+		t.Errorf("should drop proposal : %v", err)
+	}
+	cancel()
+
+	n.Stop()
+	// the only message stepped after the swap was dropped, so nothing
+	// should have been recorded
+	if len(msgs) != 0 {
+		t.Fatalf("len(msgs) = %d, want %d", len(msgs), 0)
+	}
+}
+
 // TestNodeTick ensures that node.Tick() will increase the
 // elapsed of the underlying raft state machine.
 func TestNodeTick(t *testing.T) {