Browse Source

rafttest: add restart and a related simple test

Xiang Li 11 years ago
parent
commit
b147a6328d
3 changed files with 105 additions and 17 deletions
  1. 33 5
      raft/rafttest/network.go
  2. 31 10
      raft/rafttest/node.go
  3. 41 2
      raft/rafttest/node_test.go

+ 33 - 5
raft/rafttest/network.go

@@ -1,6 +1,7 @@
 package rafttest
 
 import (
+	"sync"
 	"time"
 
 	"github.com/coreos/etcd/raft/raftpb"
@@ -14,15 +15,21 @@ type network interface {
 	// delay message for (0, d] randomly at given rate (1.0 delay all messages)
 	// do we need rate here?
 	delay(from, to uint64, d time.Duration, rate float64)
+
+	disconnect(id uint64)
+	connect(id uint64)
 }
 
 type raftNetwork struct {
-	recvQueues map[uint64]chan raftpb.Message
+	mu           sync.Mutex
+	disconnected map[uint64]bool
+	recvQueues   map[uint64]chan raftpb.Message
 }
 
 func newRaftNetwork(nodes ...uint64) *raftNetwork {
 	pn := &raftNetwork{
-		recvQueues: make(map[uint64]chan raftpb.Message, 0),
+		recvQueues:   make(map[uint64]chan raftpb.Message),
+		disconnected: make(map[uint64]bool),
 	}
 
 	for _, n := range nodes {
@@ -36,18 +43,27 @@ func (rn *raftNetwork) nodeNetwork(id uint64) *nodeNetwork {
 }
 
 func (rn *raftNetwork) send(m raftpb.Message) {
+	rn.mu.Lock()
 	to := rn.recvQueues[m.To]
+	if rn.disconnected[m.To] {
+		to = nil
+	}
+	rn.mu.Unlock()
+
 	if to == nil {
-		panic("sent to nil")
+		return
 	}
 	to <- m
 }
 
 func (rn *raftNetwork) recvFrom(from uint64) chan raftpb.Message {
+	rn.mu.Lock()
 	fromc := rn.recvQueues[from]
-	if fromc == nil {
-		panic("recv from nil")
+	if rn.disconnected[from] {
+		fromc = nil
 	}
+	rn.mu.Unlock()
+
 	return fromc
 }
 
@@ -59,6 +75,18 @@ func (rn *raftNetwork) delay(from, to uint64, d time.Duration, rate float64) {
 	panic("unimplemented")
 }
 
+func (rn *raftNetwork) disconnect(id uint64) {
+	rn.mu.Lock()
+	defer rn.mu.Unlock()
+	rn.disconnected[id] = true
+}
+
+func (rn *raftNetwork) connect(id uint64) {
+	rn.mu.Lock()
+	defer rn.mu.Unlock()
+	rn.disconnected[id] = false
+}
+
 type nodeNetwork struct {
 	id uint64
 	*raftNetwork

+ 31 - 10
raft/rafttest/node.go

@@ -11,6 +11,7 @@ import (
 
 type node struct {
 	raft.Node
+	id     uint64
 	paused bool
 	nt     network
 	stopc  chan struct{}
@@ -25,12 +26,18 @@ func startNode(id uint64, peers []raft.Peer, nt network) *node {
 	rn := raft.StartNode(id, peers, 10, 1, st)
 	n := &node{
 		Node:    rn,
+		id:      id,
 		storage: st,
 		nt:      nt,
-		stopc:   make(chan struct{}),
 	}
+	n.start()
+	return n
+}
 
+func (n *node) start() {
+	n.stopc = make(chan struct{})
 	ticker := time.Tick(5 * time.Millisecond)
+
 	go func() {
 		for {
 			select {
@@ -39,32 +46,46 @@ func startNode(id uint64, peers []raft.Peer, nt network) *node {
 			case rd := <-n.Ready():
 				if !raft.IsEmptyHardState(rd.HardState) {
 					n.state = rd.HardState
+					n.storage.SetHardState(n.state)
 				}
 				n.storage.Append(rd.Entries)
 				go func() {
 					for _, m := range rd.Messages {
-						nt.send(m)
+						n.nt.send(m)
 					}
 				}()
 				n.Advance()
 			case m := <-n.nt.recv():
 				n.Step(context.TODO(), m)
 			case <-n.stopc:
-				log.Printf("raft.%d: stop", id)
+				n.Stop()
+				log.Printf("raft.%d: stop", n.id)
+				n.Node = nil
+				close(n.stopc)
 				return
 			}
 		}
 	}()
-	return n
 }
 
-func (n *node) stop() { close(n.stopc) }
-
-// restart restarts the node with the given delay.
-// All in memory state of node is reset to initialized state.
+// stop stops the node. Stopping an already-stopped node might panic.
+// All in-memory state of the node is discarded.
 // All stable MUST be unchanged.
-func (n *node) restart(delay time.Duration) {
-	panic("unimplemented")
+func (n *node) stop() {
+	n.nt.disconnect(n.id)
+	n.stopc <- struct{}{}
+	// wait for the shutdown
+	<-n.stopc
+}
+
+// restart restarts the node. Restarting a node that is still running
+// blocks and might interfere with a later stop operation.
+func (n *node) restart() {
+	// wait for the shutdown
+	<-n.stopc
+	n.Node = raft.RestartNode(n.id, 10, 1, n.storage, 0)
+	n.start()
+	n.nt.connect(n.id)
 }
 
 // pause pauses the node.

+ 41 - 2
raft/rafttest/node_test.go

@@ -27,8 +27,47 @@ func TestBasicProgress(t *testing.T) {
 	time.Sleep(100 * time.Millisecond)
 	for _, n := range nodes {
 		n.stop()
-		if n.state.Commit < 1000 {
-			t.Errorf("commit = %d, want > 1000", n.state.Commit)
+		if n.state.Commit != 1006 {
+			t.Errorf("commit = %d, want = 1006", n.state.Commit)
+		}
+	}
+}
+
+func TestRestart(t *testing.T) {
+	peers := []raft.Peer{{1, nil}, {2, nil}, {3, nil}, {4, nil}, {5, nil}}
+	nt := newRaftNetwork(1, 2, 3, 4, 5)
+
+	nodes := make([]*node, 0)
+
+	for i := 1; i <= 5; i++ {
+		n := startNode(uint64(i), peers, nt.nodeNetwork(uint64(i)))
+		nodes = append(nodes, n)
+	}
+
+	time.Sleep(50 * time.Millisecond)
+	for i := 0; i < 300; i++ {
+		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	}
+	nodes[1].stop()
+	for i := 0; i < 300; i++ {
+		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	}
+	nodes[2].stop()
+	for i := 0; i < 300; i++ {
+		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	}
+	nodes[2].restart()
+	for i := 0; i < 300; i++ {
+		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	}
+	nodes[1].restart()
+
+	// give some time for nodes to catch up with the raft leader
+	time.Sleep(300 * time.Millisecond)
+	for _, n := range nodes {
+		n.stop()
+		if n.state.Commit != 1206 {
+			t.Errorf("commit = %d, want = 1206", n.state.Commit)
 		}
 	}
 }