Browse Source

rafttest: add network drop

Xiang Li 11 years ago
parent
commit
d423946fa4
2 changed files with 56 additions and 2 deletions
  1. 20 2
      raft/rafttest/network.go
  2. 36 0
      raft/rafttest/network_test.go

+ 20 - 2
raft/rafttest/network.go

@@ -1,6 +1,7 @@
 package rafttest
 
 import (
+	"math/rand"
 	"sync"
 	"time"
 
@@ -31,12 +32,18 @@ type network interface {
 type raftNetwork struct {
 	mu           sync.Mutex
 	disconnected map[uint64]bool
+	dropmap      map[conn]float64
 	recvQueues   map[uint64]chan raftpb.Message
 }
 
+type conn struct {
+	from, to uint64
+}
+
 func newRaftNetwork(nodes ...uint64) *raftNetwork {
 	pn := &raftNetwork{
 		recvQueues:   make(map[uint64]chan raftpb.Message),
+		dropmap:      make(map[conn]float64),
 		disconnected: make(map[uint64]bool),
 	}
 
@@ -56,11 +63,16 @@ func (rn *raftNetwork) send(m raftpb.Message) {
 	if rn.disconnected[m.To] {
 		to = nil
 	}
+	drop := rn.dropmap[conn{m.From, m.To}]
 	rn.mu.Unlock()
 
 	if to == nil {
 		return
 	}
+	if drop != 0 && rand.Float64() < drop {
+		return
+	}
+
 	to <- m
 }
 
@@ -76,14 +88,20 @@ func (rn *raftNetwork) recvFrom(from uint64) chan raftpb.Message {
 }
 
 func (rn *raftNetwork) drop(from, to uint64, rate float64) {
-	panic("unimplemented")
+	rn.mu.Lock()
+	defer rn.mu.Unlock()
+	rn.dropmap[conn{from, to}] = rate
 }
 
 func (rn *raftNetwork) delay(from, to uint64, d time.Duration, rate float64) {
 	panic("unimplemented")
 }
 
-func (rn *raftNetwork) heal() {}
+func (rn *raftNetwork) heal() {
+	rn.mu.Lock()
+	defer rn.mu.Unlock()
+	rn.dropmap = make(map[conn]float64)
+}
 
 func (rn *raftNetwork) disconnect(id uint64) {
 	rn.mu.Lock()

+ 36 - 0
raft/rafttest/network_test.go

@@ -0,0 +1,36 @@
+package rafttest
+
+import (
+	"testing"
+
+	"github.com/coreos/etcd/raft/raftpb"
+)
+
+func TestNetworkDrop(t *testing.T) {
+	// drop around 10% messages
+	sent := 1000
+	droprate := 0.1
+	nt := newRaftNetwork(1, 2)
+	nt.drop(1, 2, droprate)
+	for i := 0; i < sent; i++ {
+		nt.send(raftpb.Message{From: 1, To: 2})
+	}
+
+	c := nt.recvFrom(2)
+
+	received := 0
+	done := false
+	for !done {
+		select {
+		case <-c:
+			received++
+		default:
+			done = true
+		}
+	}
+
+	drop := sent - received
+	if drop > int((droprate+0.1)*float64(sent)) || drop < int((droprate-0.1)*float64(sent)) {
+		t.Errorf("drop = %d, want around %d", drop, droprate*float64(sent))
+	}
+}