Browse Source

Merge pull request #2401 from yichengq/331

rafthttp: add unit tests and SendMsgApp benchmark
Yicheng Qin 10 years ago
parent
commit
4dd3be0f05

+ 12 - 8
rafthttp/http.go

@@ -42,11 +42,15 @@ func NewHandler(r Raft, cid types.ID) http.Handler {
 	}
 }
 
-func newStreamHandler(tr *transport, id, cid types.ID) http.Handler {
+type peerGetter interface {
+	Get(id types.ID) Peer
+}
+
+func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
 	return &streamHandler{
-		tr:  tr,
-		id:  id,
-		cid: cid,
+		peerGetter: peerGetter,
+		id:         id,
+		cid:        cid,
 	}
 }
 
@@ -107,9 +111,9 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 }
 
 type streamHandler struct {
-	tr  *transport
-	id  types.ID
-	cid types.ID
+	peerGetter peerGetter
+	id         types.ID
+	cid        types.ID
 }
 
 func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -141,7 +145,7 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "invalid from", http.StatusNotFound)
 		return
 	}
-	p := h.tr.Peer(from)
+	p := h.peerGetter.Get(from)
 	if p == nil {
 		log.Printf("rafthttp: fail to find sender %s", from)
 		http.Error(w, "error sender not found", http.StatusNotFound)

+ 197 - 20
rafthttp/http_test.go

@@ -22,6 +22,7 @@ import (
 	"net/http/httptest"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 	"github.com/coreos/etcd/pkg/pbutil"
@@ -45,7 +46,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&nopProcessor{},
+			&fakeRaft{},
 			"0",
 			http.StatusMethodNotAllowed,
 		},
@@ -55,7 +56,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&nopProcessor{},
+			&fakeRaft{},
 			"0",
 			http.StatusMethodNotAllowed,
 		},
@@ -65,7 +66,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&nopProcessor{},
+			&fakeRaft{},
 			"0",
 			http.StatusMethodNotAllowed,
 		},
@@ -73,7 +74,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			// bad request body
 			"POST",
 			&errReader{},
-			&nopProcessor{},
+			&fakeRaft{},
 			"0",
 			http.StatusBadRequest,
 		},
@@ -81,7 +82,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			// bad request protobuf
 			"POST",
 			strings.NewReader("malformed garbage"),
-			&nopProcessor{},
+			&fakeRaft{},
 			"0",
 			http.StatusBadRequest,
 		},
@@ -91,7 +92,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&nopProcessor{},
+			&fakeRaft{},
 			"1",
 			http.StatusPreconditionFailed,
 		},
@@ -101,7 +102,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&errProcessor{
+			&fakeRaft{
 				err: &resWriterToError{code: http.StatusForbidden},
 			},
 			"0",
@@ -113,7 +114,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&errProcessor{
+			&fakeRaft{
 				err: &resWriterToError{code: http.StatusInternalServerError},
 			},
 			"0",
@@ -125,7 +126,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&errProcessor{err: errors.New("blah")},
+			&fakeRaft{err: errors.New("blah")},
 			"0",
 			http.StatusInternalServerError,
 		},
@@ -135,7 +136,7 @@ func TestServeRaftPrefix(t *testing.T) {
 			bytes.NewReader(
 				pbutil.MustMarshal(&raftpb.Message{}),
 			),
-			&nopProcessor{},
+			&fakeRaft{},
 			"0",
 			http.StatusNoContent,
 		},
@@ -155,24 +156,177 @@ func TestServeRaftPrefix(t *testing.T) {
 	}
 }
 
+func TestServeRaftStreamPrefix(t *testing.T) {
+	tests := []struct {
+		path  string
+		wtype streamType
+	}{
+		{
+			RaftStreamPrefix + "/message/1",
+			streamTypeMessage,
+		},
+		{
+			RaftStreamPrefix + "/msgapp/1",
+			streamTypeMsgApp,
+		},
+		// backward compatibility
+		{
+			RaftStreamPrefix + "/1",
+			streamTypeMsgApp,
+		},
+	}
+	for i, tt := range tests {
+		req, err := http.NewRequest("GET", "http://localhost:7001"+tt.path, nil)
+		if err != nil {
+			t.Fatalf("#%d: could not create request: %#v", i, err)
+		}
+		req.Header.Set("X-Etcd-Cluster-ID", "1")
+		req.Header.Set("X-Raft-To", "2")
+		wterm := "1"
+		req.Header.Set("X-Raft-Term", wterm)
+
+		peer := newFakePeer()
+		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
+		h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
+
+		rw := httptest.NewRecorder()
+		go h.ServeHTTP(rw, req)
+
+		var conn *outgoingConn
+		select {
+		case conn = <-peer.connc:
+		case <-time.After(time.Second):
+			t.Fatalf("#%d: failed to attach outgoingConn", i)
+		}
+		if conn.t != tt.wtype {
+			t.Errorf("$%d: type = %s, want %s", i, conn.t, tt.wtype)
+		}
+		if conn.termStr != wterm {
+			t.Errorf("$%d: term = %s, want %s", i, conn.termStr, wterm)
+		}
+		conn.Close()
+	}
+}
+
+func TestServeRaftStreamPrefixBad(t *testing.T) {
+	tests := []struct {
+		method    string
+		path      string
+		clusterID string
+		remote    string
+
+		wcode int
+	}{
+		// bad method
+		{
+			"PUT",
+			RaftStreamPrefix + "/message/1",
+			"1",
+			"1",
+			http.StatusMethodNotAllowed,
+		},
+		// bad method
+		{
+			"POST",
+			RaftStreamPrefix + "/message/1",
+			"1",
+			"1",
+			http.StatusMethodNotAllowed,
+		},
+		// bad method
+		{
+			"DELETE",
+			RaftStreamPrefix + "/message/1",
+			"1",
+			"1",
+			http.StatusMethodNotAllowed,
+		},
+		// bad path
+		{
+			"GET",
+			RaftStreamPrefix + "/strange/1",
+			"1",
+			"1",
+			http.StatusNotFound,
+		},
+		// bad path
+		{
+			"GET",
+			RaftStreamPrefix + "/strange",
+			"1",
+			"1",
+			http.StatusNotFound,
+		},
+		// non-existant peer
+		{
+			"GET",
+			RaftStreamPrefix + "/message/2",
+			"1",
+			"1",
+			http.StatusNotFound,
+		},
+		// wrong cluster ID
+		{
+			"GET",
+			RaftStreamPrefix + "/message/1",
+			"2",
+			"1",
+			http.StatusPreconditionFailed,
+		},
+		// wrong remote id
+		{
+			"GET",
+			RaftStreamPrefix + "/message/1",
+			"1",
+			"2",
+			http.StatusPreconditionFailed,
+		},
+	}
+	for i, tt := range tests {
+		req, err := http.NewRequest(tt.method, "http://localhost:7001"+tt.path, nil)
+		if err != nil {
+			t.Fatalf("#%d: could not create request: %#v", i, err)
+		}
+		req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID)
+		req.Header.Set("X-Raft-To", tt.remote)
+		rw := httptest.NewRecorder()
+		peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
+		h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
+		h.ServeHTTP(rw, req)
+
+		if rw.Code != tt.wcode {
+			t.Errorf("#%d: code = %d, want %d", i, rw.Code, tt.wcode)
+		}
+	}
+}
+
+func TestCloseNotifier(t *testing.T) {
+	c := newCloseNotifier()
+	select {
+	case <-c.closeNotify():
+		t.Fatalf("received unexpected close notification")
+	default:
+	}
+	c.Close()
+	select {
+	case <-c.closeNotify():
+	default:
+		t.Fatalf("failed to get close notification")
+	}
+}
+
 // errReader implements io.Reader to facilitate a broken request.
 type errReader struct{}
 
 func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some error") }
 
-type nopProcessor struct{}
-
-func (p *nopProcessor) Process(ctx context.Context, m raftpb.Message) error  { return nil }
-func (p *nopProcessor) ReportUnreachable(id uint64)                          {}
-func (p *nopProcessor) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
-
-type errProcessor struct {
+type fakeRaft struct {
 	err error
 }
 
-func (p *errProcessor) Process(ctx context.Context, m raftpb.Message) error  { return p.err }
-func (p *errProcessor) ReportUnreachable(id uint64)                          {}
-func (p *errProcessor) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
+func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error  { return p.err }
+func (p *fakeRaft) ReportUnreachable(id uint64)                          {}
+func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
 
 type resWriterToError struct {
 	code int
@@ -180,3 +334,26 @@ type resWriterToError struct {
 
 func (e *resWriterToError) Error() string                 { return "" }
 func (e *resWriterToError) WriteTo(w http.ResponseWriter) { w.WriteHeader(e.code) }
+
+type fakePeerGetter struct {
+	peers map[types.ID]Peer
+}
+
+func (pg *fakePeerGetter) Get(id types.ID) Peer { return pg.peers[id] }
+
+type fakePeer struct {
+	msgs  []raftpb.Message
+	u     string
+	connc chan *outgoingConn
+}
+
+func newFakePeer() *fakePeer {
+	return &fakePeer{
+		connc: make(chan *outgoingConn, 1),
+	}
+}
+
+func (pr *fakePeer) Send(m raftpb.Message)                 { pr.msgs = append(pr.msgs, m) }
+func (pr *fakePeer) Update(u string)                       { pr.u = u }
+func (pr *fakePeer) attachOutgoingConn(conn *outgoingConn) { pr.connc <- conn }
+func (pr *fakePeer) Stop()                                 {}

+ 18 - 2
rafthttp/peer.go

@@ -33,6 +33,24 @@ const (
 	recvBufSize = 4096
 )
 
+type Peer interface {
+	// Send sends the message to the remote peer. The function is non-blocking
+	// and has no promise that the message will be received by the remote.
+	// When it fails to send message out, it will report the status to underlying
+	// raft.
+	Send(m raftpb.Message)
+	// Update updates the urls of remote peer.
+	Update(u string)
+	// attachOutgoingConn attachs the outgoing connection to the peer for
+	// stream usage. After the call, the ownership of the outgoing
+	// connection hands over to the peer. The peer will close the connection
+	// when it is no longer used.
+	attachOutgoingConn(conn *outgoingConn)
+	// Stop performs any necessary finalization and terminates the peer
+	// elegantly.
+	Stop()
+}
+
 // peer is the representative of a remote raft node. Local raft node sends
 // messages to the remote through peer.
 // Each peer has two underlying mechanisms to send out a message: stream and
@@ -171,8 +189,6 @@ func (p *peer) Resume() {
 	}
 }
 
-// Stop performs any necessary finalization and terminates the peer
-// elegantly.
 func (p *peer) Stop() {
 	close(p.stopc)
 	<-p.done

+ 87 - 0
rafthttp/peer_test.go

@@ -0,0 +1,87 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rafthttp
+
+import (
+	"testing"
+
+	"github.com/coreos/etcd/raft/raftpb"
+)
+
+func TestPeerPick(t *testing.T) {
+	tests := []struct {
+		msgappWorking  bool
+		messageWorking bool
+		m              raftpb.Message
+		wpicked        string
+	}{
+		{
+			true, true,
+			raftpb.Message{Type: raftpb.MsgSnap},
+			"pipeline",
+		},
+		{
+			true, true,
+			raftpb.Message{Type: raftpb.MsgApp, Term: 1, LogTerm: 1},
+			"msgapp stream",
+		},
+		{
+			true, true,
+			raftpb.Message{Type: raftpb.MsgProp},
+			"general stream",
+		},
+		{
+			true, true,
+			raftpb.Message{Type: raftpb.MsgHeartbeat},
+			"general stream",
+		},
+		{
+			false, true,
+			raftpb.Message{Type: raftpb.MsgApp, Term: 1, LogTerm: 1},
+			"general stream",
+		},
+		{
+			false, false,
+			raftpb.Message{Type: raftpb.MsgApp, Term: 1, LogTerm: 1},
+			"pipeline",
+		},
+		{
+			false, false,
+			raftpb.Message{Type: raftpb.MsgProp},
+			"pipeline",
+		},
+		{
+			false, false,
+			raftpb.Message{Type: raftpb.MsgSnap},
+			"pipeline",
+		},
+		{
+			false, false,
+			raftpb.Message{Type: raftpb.MsgHeartbeat},
+			"pipeline",
+		},
+	}
+	for i, tt := range tests {
+		peer := &peer{
+			msgAppWriter: &streamWriter{working: tt.msgappWorking},
+			writer:       &streamWriter{working: tt.messageWorking},
+			pipeline:     &pipeline{},
+		}
+		_, picked, _ := peer.pick(tt.m)
+		if picked != tt.wpicked {
+			t.Errorf("#%d: picked = %v, want %v", i, picked, tt.wpicked)
+		}
+	}
+}

+ 6 - 6
rafthttp/pipeline_test.go

@@ -32,7 +32,7 @@ import (
 func TestPipelineSend(t *testing.T) {
 	tr := &roundTripperRecorder{}
 	fs := &stats.FollowerStats{}
-	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, &nopProcessor{}, nil)
+	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
 	p.stop()
@@ -50,7 +50,7 @@ func TestPipelineSend(t *testing.T) {
 func TestPipelineExceedMaximalServing(t *testing.T) {
 	tr := newRoundTripperBlocker()
 	fs := &stats.FollowerStats{}
-	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, &nopProcessor{}, nil)
+	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
 
 	// keep the sender busy and make the buffer full
 	// nothing can go out as we block the sender
@@ -89,7 +89,7 @@ func TestPipelineExceedMaximalServing(t *testing.T) {
 // it increases fail count in stats.
 func TestPipelineSendFailed(t *testing.T) {
 	fs := &stats.FollowerStats{}
-	p := newPipeline(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), types.ID(1), fs, &nopProcessor{}, nil)
+	p := newPipeline(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
 	p.stop()
@@ -103,7 +103,7 @@ func TestPipelineSendFailed(t *testing.T) {
 
 func TestPipelinePost(t *testing.T) {
 	tr := &roundTripperRecorder{}
-	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), nil, &nopProcessor{}, nil)
+	p := newPipeline(tr, "http://10.0.0.1", types.ID(1), types.ID(1), nil, &fakeRaft{}, nil)
 	if err := p.post([]byte("some data")); err != nil {
 		t.Fatalf("unexpect post error: %v", err)
 	}
@@ -145,7 +145,7 @@ func TestPipelinePostBad(t *testing.T) {
 		{"http://10.0.0.1", http.StatusCreated, nil},
 	}
 	for i, tt := range tests {
-		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, &nopProcessor{}, make(chan error))
+		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, &fakeRaft{}, make(chan error))
 		err := p.post([]byte("some data"))
 		p.stop()
 
@@ -166,7 +166,7 @@ func TestPipelinePostErrorc(t *testing.T) {
 	}
 	for i, tt := range tests {
 		errorc := make(chan error, 1)
-		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, &nopProcessor{}, errorc)
+		p := newPipeline(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), types.ID(1), nil, &fakeRaft{}, errorc)
 		p.post([]byte("some data"))
 		p.stop()
 		select {

+ 2 - 2
rafthttp/stream.go

@@ -219,7 +219,7 @@ func startStreamReader(tr http.RoundTripper, u string, t streamType, from, to, c
 
 func (cr *streamReader) run() {
 	for {
-		rc, err := cr.roundtrip()
+		rc, err := cr.dial()
 		if err != nil {
 			log.Printf("rafthttp: roundtripping error: %v", err)
 		} else {
@@ -307,7 +307,7 @@ func (cr *streamReader) isWorking() bool {
 	return cr.closer != nil
 }
 
-func (cr *streamReader) roundtrip() (io.ReadCloser, error) {
+func (cr *streamReader) dial() (io.ReadCloser, error) {
 	cr.mu.Lock()
 	u := cr.u
 	term := cr.msgAppTerm

+ 237 - 0
rafthttp/stream_test.go

@@ -0,0 +1,237 @@
+package rafthttp
+
+import (
+	"errors"
+	"net/http"
+	"net/http/httptest"
+	"reflect"
+	"testing"
+
+	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/testutil"
+	"github.com/coreos/etcd/pkg/types"
+	"github.com/coreos/etcd/raft/raftpb"
+)
+
+// TestStreamWriterAttachOutgoingConn tests that outgoingConn can be attached
+// to streamWriter. After that, streamWriter can use it to send messages
+// continuously, and closes it when stopped.
+func TestStreamWriterAttachOutgoingConn(t *testing.T) {
+	sw := startStreamWriter(&stats.FollowerStats{}, &fakeRaft{})
+	// the expected initial state of streamWrite is not working
+	if g := sw.isWorking(); g != false {
+		t.Errorf("initial working status = %v, want false", g)
+	}
+
+	// repeatitive tests to ensure it can use latest connection
+	var wfc *fakeWriteFlushCloser
+	for i := 0; i < 3; i++ {
+		prevwfc := wfc
+		wfc = &fakeWriteFlushCloser{}
+		sw.attach(&outgoingConn{t: streamTypeMessage, Writer: wfc, Flusher: wfc, Closer: wfc})
+		testutil.ForceGosched()
+		// previous attached connection should be closed
+		if prevwfc != nil && prevwfc.closed != true {
+			t.Errorf("#%d: close of previous connection = %v, want true", i, prevwfc.closed)
+		}
+		// starts working
+		if g := sw.isWorking(); g != true {
+			t.Errorf("#%d: working status = %v, want true", i, g)
+		}
+
+		sw.msgc <- raftpb.Message{}
+		testutil.ForceGosched()
+		// still working
+		if g := sw.isWorking(); g != true {
+			t.Errorf("#%d: working status = %v, want true", i, g)
+		}
+		if wfc.written == 0 {
+			t.Errorf("#%d: failed to write to the underlying connection", i)
+		}
+	}
+
+	sw.stop()
+	// no longer in working status now
+	if g := sw.isWorking(); g != false {
+		t.Errorf("working status after stop = %v, want false", g)
+	}
+	if wfc.closed != true {
+		t.Errorf("failed to close the underlying connection")
+	}
+}
+
+// TestStreamWriterAttachBadOutgoingConn tests that streamWriter with bad
+// outgoingConn will close the outgoingConn and fall back to non-working status.
+func TestStreamWriterAttachBadOutgoingConn(t *testing.T) {
+	sw := startStreamWriter(&stats.FollowerStats{}, &fakeRaft{})
+	defer sw.stop()
+	wfc := &fakeWriteFlushCloser{err: errors.New("blah")}
+	sw.attach(&outgoingConn{t: streamTypeMessage, Writer: wfc, Flusher: wfc, Closer: wfc})
+
+	sw.msgc <- raftpb.Message{}
+	testutil.ForceGosched()
+	// no longer working
+	if g := sw.isWorking(); g != false {
+		t.Errorf("working = %v, want false", g)
+	}
+	if wfc.closed != true {
+		t.Errorf("failed to close the underlying connection")
+	}
+}
+
+func TestStreamReaderDialRequest(t *testing.T) {
+	for i, tt := range []streamType{streamTypeMsgApp, streamTypeMessage} {
+		tr := &roundTripperRecorder{}
+		sr := &streamReader{
+			tr:         tr,
+			u:          "http://localhost:7001",
+			t:          tt,
+			from:       types.ID(1),
+			to:         types.ID(2),
+			cid:        types.ID(1),
+			msgAppTerm: 1,
+		}
+		sr.dial()
+
+		req := tr.Request()
+		var wurl string
+		switch tt {
+		case streamTypeMsgApp:
+			wurl = "http://localhost:7001/raft/stream/1"
+		case streamTypeMessage:
+			wurl = "http://localhost:7001/raft/stream/message/1"
+		}
+		if req.URL.String() != wurl {
+			t.Errorf("#%d: url = %s, want %s", i, req.URL.String(), wurl)
+		}
+		if w := "GET"; req.Method != w {
+			t.Errorf("#%d: method = %s, want %s", i, req.Method, w)
+		}
+		if g := req.Header.Get("X-Etcd-Cluster-ID"); g != "1" {
+			t.Errorf("#%d: header X-Etcd-Cluster-ID = %s, want 1", i, g)
+		}
+		if g := req.Header.Get("X-Raft-To"); g != "2" {
+			t.Errorf("#%d: header X-Raft-To = %s, want 2", i, g)
+		}
+		if g := req.Header.Get("X-Raft-Term"); tt == streamTypeMsgApp && g != "1" {
+			t.Errorf("#%d: header X-Raft-Term = %s, want 1", i, g)
+		}
+	}
+}
+
+// TestStreamReaderDialResult tests the result of the dial func call meets the
+// HTTP response received.
+func TestStreamReaderDialResult(t *testing.T) {
+	tests := []struct {
+		code int
+		err  error
+		wok  bool
+	}{
+		{0, errors.New("blah"), false},
+		{http.StatusOK, nil, true},
+		{http.StatusMethodNotAllowed, nil, false},
+		{http.StatusNotFound, nil, false},
+		{http.StatusPreconditionFailed, nil, false},
+	}
+	for i, tt := range tests {
+		tr := newRespRoundTripper(tt.code, tt.err)
+		sr := &streamReader{
+			tr:   tr,
+			u:    "http://localhost:7001",
+			t:    streamTypeMessage,
+			from: types.ID(1),
+			to:   types.ID(2),
+			cid:  types.ID(1),
+		}
+
+		_, err := sr.dial()
+		if ok := err == nil; ok != tt.wok {
+			t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
+		}
+	}
+}
+
+// TestStream tests that streamReader and streamWriter can build stream to
+// send messages between each other.
+func TestStream(t *testing.T) {
+	tests := []struct {
+		t    streamType
+		term uint64
+		m    raftpb.Message
+	}{
+		{
+			streamTypeMessage,
+			0,
+			raftpb.Message{Type: raftpb.MsgProp, To: 2},
+		},
+		{
+			streamTypeMsgApp,
+			1,
+			raftpb.Message{
+				Type:    raftpb.MsgApp,
+				From:    2,
+				To:      1,
+				Term:    1,
+				LogTerm: 1,
+				Index:   3,
+				Entries: []raftpb.Entry{{Term: 1, Index: 4}},
+			},
+		},
+	}
+	for i, tt := range tests {
+		h := &fakeStreamHandler{t: tt.t}
+		srv := httptest.NewServer(h)
+		defer srv.Close()
+
+		sw := startStreamWriter(&stats.FollowerStats{}, &fakeRaft{})
+		defer sw.stop()
+		h.sw = sw
+
+		recvc := make(chan raftpb.Message)
+		sr := startStreamReader(&http.Transport{}, srv.URL, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc)
+		defer sr.stop()
+		if tt.t == streamTypeMsgApp {
+			sr.updateMsgAppTerm(tt.term)
+		}
+
+		sw.msgc <- tt.m
+		m := <-recvc
+		if !reflect.DeepEqual(m, tt.m) {
+			t.Errorf("#%d: message = %+v, want %+v", i, m, tt.m)
+		}
+	}
+}
+
+type fakeWriteFlushCloser struct {
+	err     error
+	written int
+	closed  bool
+}
+
+func (wfc *fakeWriteFlushCloser) Write(p []byte) (n int, err error) {
+	wfc.written += len(p)
+	return len(p), wfc.err
+}
+func (wfc *fakeWriteFlushCloser) Flush() {}
+func (wfc *fakeWriteFlushCloser) Close() error {
+	wfc.closed = true
+	return wfc.err
+}
+
+type fakeStreamHandler struct {
+	t  streamType
+	sw *streamWriter
+}
+
+func (h *fakeStreamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	w.(http.Flusher).Flush()
+	c := newCloseNotifier()
+	h.sw.attach(&outgoingConn{
+		t:       h.t,
+		termStr: r.Header.Get("X-Raft-Term"),
+		Writer:  w,
+		Flusher: w.(http.Flusher),
+		Closer:  c,
+	})
+	<-c.closeNotify()
+}

+ 6 - 6
rafthttp/transport.go

@@ -52,8 +52,8 @@ type transport struct {
 	serverStats  *stats.ServerStats
 	leaderStats  *stats.LeaderStats
 
-	mu     sync.RWMutex       // protect the peer map
-	peers  map[types.ID]*peer // remote peers
+	mu     sync.RWMutex      // protect the peer map
+	peers  map[types.ID]Peer // remote peers
 	errorc chan error
 }
 
@@ -65,7 +65,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
 		raft:         r,
 		serverStats:  ss,
 		leaderStats:  ls,
-		peers:        make(map[types.ID]*peer),
+		peers:        make(map[types.ID]Peer),
 		errorc:       errorc,
 	}
 }
@@ -79,7 +79,7 @@ func (t *transport) Handler() http.Handler {
 	return mux
 }
 
-func (t *transport) Peer(id types.ID) *peer {
+func (t *transport) Get(id types.ID) Peer {
 	t.mu.RLock()
 	defer t.mu.RUnlock()
 	return t.peers[id]
@@ -181,12 +181,12 @@ type Pausable interface {
 // for testing
 func (t *transport) Pause() {
 	for _, p := range t.peers {
-		p.Pause()
+		p.(Pausable).Pause()
 	}
 }
 
 func (t *transport) Resume() {
 	for _, p := range t.peers {
-		p.Resume()
+		p.(Pausable).Resume()
 	}
 }

+ 78 - 0
rafthttp/transport_bench_test.go

@@ -0,0 +1,78 @@
+// Copyright 2015 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rafthttp
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
+	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/types"
+	"github.com/coreos/etcd/raft"
+	"github.com/coreos/etcd/raft/raftpb"
+)
+
+func BenchmarkSendingMsgApp(b *testing.B) {
+	r := &countRaft{}
+	ss := &stats.ServerStats{}
+	ss.Initialize()
+	tr := NewTransporter(&http.Transport{}, types.ID(1), types.ID(1), r, nil, ss, stats.NewLeaderStats("1"))
+	srv := httptest.NewServer(tr.Handler())
+	defer srv.Close()
+	tr.AddPeer(types.ID(1), []string{srv.URL})
+	defer tr.Stop()
+	// wait for underlying stream created
+	time.Sleep(time.Second)
+
+	b.ReportAllocs()
+	b.SetBytes(64)
+
+	b.ResetTimer()
+	data := make([]byte, 64)
+	for i := 0; i < b.N; i++ {
+		tr.Send([]raftpb.Message{{Type: raftpb.MsgApp, To: 1, Entries: []raftpb.Entry{{Data: data}}}})
+	}
+	// wait until all messages are received by the target raft
+	for r.count() != b.N {
+		time.Sleep(time.Millisecond)
+	}
+	b.StopTimer()
+}
+
+type countRaft struct {
+	mu  sync.Mutex
+	cnt int
+}
+
+func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	r.cnt++
+	return nil
+}
+
+func (r *countRaft) ReportUnreachable(id uint64) {}
+
+func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
+
+func (r *countRaft) count() int {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	return r.cnt
+}

+ 55 - 3
rafthttp/transport_test.go

@@ -16,6 +16,7 @@ package rafthttp
 
 import (
 	"net/http"
+	"reflect"
 	"testing"
 	"time"
 
@@ -25,12 +26,51 @@ import (
 	"github.com/coreos/etcd/raft/raftpb"
 )
 
+// TestTransportSend tests that transport can send messages using correct
+// underlying peer, and drop local or unknown-target messages.
+func TestTransportSend(t *testing.T) {
+	ss := &stats.ServerStats{}
+	ss.Initialize()
+	peer1 := newFakePeer()
+	peer2 := newFakePeer()
+	tr := &transport{
+		serverStats: ss,
+		peers:       map[types.ID]Peer{types.ID(1): peer1, types.ID(2): peer2},
+	}
+	wmsgsIgnored := []raftpb.Message{
+		// bad local message
+		{Type: raftpb.MsgBeat},
+		// bad remote message
+		{Type: raftpb.MsgProp, To: 3},
+	}
+	wmsgsTo1 := []raftpb.Message{
+		// good message
+		{Type: raftpb.MsgProp, To: 1},
+		{Type: raftpb.MsgApp, To: 1},
+	}
+	wmsgsTo2 := []raftpb.Message{
+		// good message
+		{Type: raftpb.MsgProp, To: 2},
+		{Type: raftpb.MsgApp, To: 2},
+	}
+	tr.Send(wmsgsIgnored)
+	tr.Send(wmsgsTo1)
+	tr.Send(wmsgsTo2)
+
+	if !reflect.DeepEqual(peer1.msgs, wmsgsTo1) {
+		t.Errorf("msgs to peer 1 = %+v, want %+v", peer1.msgs, wmsgsTo1)
+	}
+	if !reflect.DeepEqual(peer2.msgs, wmsgsTo2) {
+		t.Errorf("msgs to peer 2 = %+v, want %+v", peer2.msgs, wmsgsTo2)
+	}
+}
+
 func TestTransportAdd(t *testing.T) {
 	ls := stats.NewLeaderStats("")
 	tr := &transport{
 		roundTripper: &roundTripperRecorder{},
 		leaderStats:  ls,
-		peers:        make(map[types.ID]*peer),
+		peers:        make(map[types.ID]Peer),
 	}
 	tr.AddPeer(1, []string{"http://a"})
 	defer tr.Stop()
@@ -55,7 +95,7 @@ func TestTransportRemove(t *testing.T) {
 	tr := &transport{
 		roundTripper: &roundTripperRecorder{},
 		leaderStats:  stats.NewLeaderStats(""),
-		peers:        make(map[types.ID]*peer),
+		peers:        make(map[types.ID]Peer),
 	}
 	tr.AddPeer(1, []string{"http://a"})
 	tr.RemovePeer(types.ID(1))
@@ -66,12 +106,24 @@ func TestTransportRemove(t *testing.T) {
 	}
 }
 
+func TestTransportUpdate(t *testing.T) {
+	peer := newFakePeer()
+	tr := &transport{
+		peers: map[types.ID]Peer{types.ID(1): peer},
+	}
+	u := "http://localhost:7001"
+	tr.UpdatePeer(types.ID(1), []string{u})
+	if w := "http://localhost:7001/raft"; peer.u != w {
+		t.Errorf("url = %s, want %s", peer.u, w)
+	}
+}
+
 func TestTransportErrorc(t *testing.T) {
 	errorc := make(chan error, 1)
 	tr := &transport{
 		roundTripper: newRespRoundTripper(http.StatusForbidden, nil),
 		leaderStats:  stats.NewLeaderStats(""),
-		peers:        make(map[types.ID]*peer),
+		peers:        make(map[types.ID]Peer),
 		errorc:       errorc,
 	}
 	tr.AddPeer(1, []string{"http://a"})

+ 9 - 2
test

@@ -15,12 +15,16 @@ COVER=${COVER:-"-cover"}
 source ./build
 
 # Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt.
-TESTABLE_AND_FORMATTABLE="client discovery error etcdctl/command etcdmain etcdserver etcdserver/etcdhttp etcdserver/etcdhttp/httptypes migrate pkg/fileutil pkg/flags pkg/idutil pkg/ioutil pkg/netutil pkg/osutil pkg/pbutil pkg/types pkg/transport pkg/wait proxy raft rafthttp snap store wal"
-FORMATTABLE="$TESTABLE_AND_FORMATTABLE *.go etcdctl/ integration"
+TESTABLE_AND_FORMATTABLE="client discovery error etcdctl/command etcdmain etcdserver etcdserver/etcdhttp etcdserver/etcdhttp/httptypes migrate pkg/fileutil pkg/flags pkg/idutil pkg/ioutil pkg/netutil pkg/osutil pkg/pbutil pkg/types pkg/transport pkg/wait proxy raft snap store wal"
+# TODO: add it to race testing when the issue is resolved
+# https://github.com/golang/go/issues/9946
+NO_RACE_TESTABLE="rafthttp"
+FORMATTABLE="$TESTABLE_AND_FORMATTABLE $NO_RACE_TESTABLE *.go etcdctl/ integration"
 
 # user has not provided PKG override
 if [ -z "$PKG" ]; then
 	TEST=$TESTABLE_AND_FORMATTABLE
+	NO_RACE_TEST=$NO_RACE_TESTABLE
 	FMT=$FORMATTABLE
 
 # user has provided PKG override
@@ -37,9 +41,12 @@ fi
 # split TEST into an array and prepend REPO_PATH to each local package
 split=(${TEST// / })
 TEST=${split[@]/#/${REPO_PATH}/}
+split=(${NO_RACE_TEST// / })
+NO_RACE_TEST=${split[@]/#/${REPO_PATH}/}
 
 echo "Running tests..."
 go test -timeout 3m ${COVER} $@ ${TEST} --race
+go test -timeout 3m ${COVER} $@ ${NO_RACE_TEST}
 
 if [ -n "$INTEGRATION" ]; then
 	echo "Running integration tests..."