Browse Source

Merge pull request #6807 from xiang90/fix_raft_test

rafttest: make raft test reliable
Xiang Li 9 years ago
parent
commit
476ff67047
3 changed files with 101 additions and 36 deletions
  1. 13 1
      raft/rafttest/network.go
  2. 7 2
      raft/rafttest/node.go
  3. 81 33
      raft/rafttest/node_test.go

+ 13 - 1
raft/rafttest/network.go

@@ -100,8 +100,20 @@ func (rn *raftNetwork) send(m raftpb.Message) {
 		time.Sleep(time.Duration(rd))
 		time.Sleep(time.Duration(rd))
 	}
 	}
 
 
+	// use marshal/unmarshal to copy message to avoid data race.
+	b, err := m.Marshal()
+	if err != nil {
+		panic(err)
+	}
+
+	var cm raftpb.Message
+	err = cm.Unmarshal(b)
+	if err != nil {
+		panic(err)
+	}
+
 	select {
 	select {
-	case to <- m:
+	case to <- cm:
 	default:
 	default:
 		// drop messages when the receiver queue is full.
 		// drop messages when the receiver queue is full.
 	}
 	}

+ 7 - 2
raft/rafttest/node.go

@@ -16,6 +16,7 @@ package rafttest
 
 
 import (
 import (
 	"log"
 	"log"
+	"sync"
 	"time"
 	"time"
 
 
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft"
@@ -32,7 +33,9 @@ type node struct {
 
 
 	// stable
 	// stable
 	storage *raft.MemoryStorage
 	storage *raft.MemoryStorage
-	state   raftpb.HardState
+
+	mu    sync.Mutex // guards state
+	state raftpb.HardState
 }
 }
 
 
 func startNode(id uint64, peers []raft.Peer, iface iface) *node {
 func startNode(id uint64, peers []raft.Peer, iface iface) *node {
@@ -68,7 +71,9 @@ func (n *node) start() {
 				n.Tick()
 				n.Tick()
 			case rd := <-n.Ready():
 			case rd := <-n.Ready():
 				if !raft.IsEmptyHardState(rd.HardState) {
 				if !raft.IsEmptyHardState(rd.HardState) {
+					n.mu.Lock()
 					n.state = rd.HardState
 					n.state = rd.HardState
+					n.mu.Unlock()
 					n.storage.SetHardState(n.state)
 					n.storage.SetHardState(n.state)
 				}
 				}
 				n.storage.Append(rd.Entries)
 				n.storage.Append(rd.Entries)
@@ -79,7 +84,7 @@ func (n *node) start() {
 				}
 				}
 				n.Advance()
 				n.Advance()
 			case m := <-n.iface.recv():
 			case m := <-n.iface.recv():
-				n.Step(context.TODO(), m)
+				go n.Step(context.TODO(), m)
 			case <-n.stopc:
 			case <-n.stopc:
 				n.Stop()
 				n.Stop()
 				log.Printf("raft.%d: stop", n.id)
 				log.Printf("raft.%d: stop", n.id)

+ 81 - 33
raft/rafttest/node_test.go

@@ -33,18 +33,18 @@ func TestBasicProgress(t *testing.T) {
 		nodes = append(nodes, n)
 		nodes = append(nodes, n)
 	}
 	}
 
 
-	time.Sleep(10 * time.Millisecond)
+	waitLeader(nodes)
 
 
-	for i := 0; i < 10000; i++ {
+	for i := 0; i < 100; i++ {
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
 
 
-	time.Sleep(500 * time.Millisecond)
+	if !waitCommitConverge(nodes, 100) {
+		t.Errorf("commits failed to converge!")
+	}
+
 	for _, n := range nodes {
 	for _, n := range nodes {
 		n.stop()
 		n.stop()
-		if n.state.Commit != 10006 {
-			t.Errorf("commit = %d, want = 10006", n.state.Commit)
-		}
 	}
 	}
 }
 }
 
 
@@ -59,31 +59,32 @@ func TestRestart(t *testing.T) {
 		nodes = append(nodes, n)
 		nodes = append(nodes, n)
 	}
 	}
 
 
-	time.Sleep(50 * time.Millisecond)
-	for i := 0; i < 300; i++ {
-		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	l := waitLeader(nodes)
+	k1, k2 := (l+1)%5, (l+2)%5
+
+	for i := 0; i < 30; i++ {
+		nodes[l].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
-	nodes[1].stop()
-	for i := 0; i < 300; i++ {
-		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	nodes[k1].stop()
+	for i := 0; i < 30; i++ {
+		nodes[(l+3)%5].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
-	nodes[2].stop()
-	for i := 0; i < 300; i++ {
-		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	nodes[k2].stop()
+	for i := 0; i < 30; i++ {
+		nodes[(l+4)%5].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
-	nodes[2].restart()
-	for i := 0; i < 300; i++ {
-		nodes[0].Propose(context.TODO(), []byte("somedata"))
+	nodes[k2].restart()
+	for i := 0; i < 30; i++ {
+		nodes[l].Propose(context.TODO(), []byte("somedata"))
+	}
+	nodes[k1].restart()
+
+	if !waitCommitConverge(nodes, 120) {
+		t.Errorf("commits failed to converge!")
 	}
 	}
-	nodes[1].restart()
 
 
-	// give some time for nodes to catch up with the raft leader
-	time.Sleep(500 * time.Millisecond)
 	for _, n := range nodes {
 	for _, n := range nodes {
 		n.stop()
 		n.stop()
-		if n.state.Commit != 1206 {
-			t.Errorf("commit = %d, want = 1206", n.state.Commit)
-		}
 	}
 	}
 }
 }
 
 
@@ -98,30 +99,77 @@ func TestPause(t *testing.T) {
 		nodes = append(nodes, n)
 		nodes = append(nodes, n)
 	}
 	}
 
 
-	time.Sleep(50 * time.Millisecond)
-	for i := 0; i < 300; i++ {
+	waitLeader(nodes)
+
+	for i := 0; i < 30; i++ {
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
 	nodes[1].pause()
 	nodes[1].pause()
-	for i := 0; i < 300; i++ {
+	for i := 0; i < 30; i++ {
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
 	nodes[2].pause()
 	nodes[2].pause()
-	for i := 0; i < 300; i++ {
+	for i := 0; i < 30; i++ {
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
 	nodes[2].resume()
 	nodes[2].resume()
-	for i := 0; i < 300; i++ {
+	for i := 0; i < 30; i++ {
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 		nodes[0].Propose(context.TODO(), []byte("somedata"))
 	}
 	}
 	nodes[1].resume()
 	nodes[1].resume()
 
 
-	// give some time for nodes to catch up with the raft leader
-	time.Sleep(300 * time.Millisecond)
+	if !waitCommitConverge(nodes, 120) {
+		t.Errorf("commits failed to converge!")
+	}
+
 	for _, n := range nodes {
 	for _, n := range nodes {
 		n.stop()
 		n.stop()
-		if n.state.Commit != 1206 {
-			t.Errorf("commit = %d, want = 1206", n.state.Commit)
+	}
+}
+
+func waitLeader(ns []*node) int {
+	var l map[uint64]struct{}
+	var lindex int
+
+	for {
+		l = make(map[uint64]struct{})
+
+		for i, n := range ns {
+			lead := n.Status().SoftState.Lead
+			if lead != 0 {
+				l[lead] = struct{}{}
+				if n.id == lead {
+					lindex = i
+				}
+			}
+		}
+
+		if len(l) == 1 {
+			return lindex
+		}
+	}
+}
+
+func waitCommitConverge(ns []*node, target uint64) bool {
+	var c map[uint64]struct{}
+
+	for i := 0; i < 50; i++ {
+		c = make(map[uint64]struct{})
+		var good int
+
+		for _, n := range ns {
+			commit := n.Node.Status().HardState.Commit
+			c[commit] = struct{}{}
+			if commit > target {
+				good++
+			}
 		}
 		}
+
+		if len(c) == 1 && good == len(ns) {
+			return true
+		}
+		time.Sleep(100 * time.Millisecond)
 	}
 	}
+
+	return false
 }
 }