Browse Source

Merge pull request #4431 from heyitsanthony/rafthttp-localurl

rafthttp: plumb local peer URLs through transport
Anthony Romano 10 years ago
parent
commit
c8fc3413b7

+ 2 - 0
clientv3/integration/txn_test.go

@@ -52,6 +52,8 @@ func TestTxnWriteFail(t *testing.T) {
 	go func() {
 		// reconnect so terminate doesn't complain about double-close
 		clus.Members[0].Restart(t)
+		// wait for etcdserver to get established (CI races and get req times out)
+		time.Sleep(2 * time.Second)
 		donec <- struct{}{}
 
 		// and ensure the put didn't take

+ 1 - 0
etcdserver/server.go

@@ -374,6 +374,7 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 		TLSInfo:     cfg.PeerTLSInfo,
 		DialTimeout: cfg.peerDialTimeout(),
 		ID:          id,
+		URLs:        cfg.PeerURLs,
 		ClusterID:   cl.ID(),
 		Raft:        srv,
 		Snapshotter: ss,

+ 22 - 9
rafthttp/http.go

@@ -59,6 +59,7 @@ type writerToResponse interface {
 }
 
 type pipelineHandler struct {
+	tr  Transporter
 	r   Raft
 	cid types.ID
 }
@@ -68,8 +69,9 @@ type pipelineHandler struct {
 //
 // The handler reads out the raft message from request body,
 // and forwards it to the given raft state machine for processing.
-func newPipelineHandler(r Raft, cid types.ID) http.Handler {
+func newPipelineHandler(tr Transporter, r Raft, cid types.ID) http.Handler {
 	return &pipelineHandler{
+		tr:  tr,
 		r:   r,
 		cid: cid,
 	}
@@ -89,6 +91,12 @@ func (h *pipelineHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	if from, err := types.IDFromString(r.Header.Get("X-Server-From")); err != nil {
+		if urls := r.Header.Get("X-PeerURLs"); urls != "" {
+			h.tr.AddRemote(from, strings.Split(urls, ","))
+		}
+	}
+
 	// Limit the data size that could be read from the request body, which ensures that read from
 	// connection will not time out accidentally due to possible blocking in underlying implementation.
 	limitedr := pioutil.NewLimitedBufferReader(r.Body, connReadLimitByte)
@@ -114,19 +122,22 @@ func (h *pipelineHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 		return
 	}
+
 	// Write StatusNoContent header after the message has been processed by
 	// raft, which facilitates the client to report MsgSnap status.
 	w.WriteHeader(http.StatusNoContent)
 }
 
 type snapshotHandler struct {
+	tr          Transporter
 	r           Raft
 	snapshotter *snap.Snapshotter
 	cid         types.ID
 }
 
-func newSnapshotHandler(r Raft, snapshotter *snap.Snapshotter, cid types.ID) http.Handler {
+func newSnapshotHandler(tr Transporter, r Raft, snapshotter *snap.Snapshotter, cid types.ID) http.Handler {
 	return &snapshotHandler{
+		tr:          tr,
 		r:           r,
 		snapshotter: snapshotter,
 		cid:         cid,
@@ -156,6 +167,12 @@ func (h *snapshotHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	if from, err := types.IDFromString(r.Header.Get("X-Server-From")); err != nil {
+		if urls := r.Header.Get("X-PeerURLs"); urls != "" {
+			h.tr.AddRemote(from, strings.Split(urls, ","))
+		}
+	}
+
 	dec := &messageDecoder{r: r.Body}
 	m, err := dec.decode()
 	if err != nil {
@@ -256,19 +273,15 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	p := h.peerGetter.Get(from)
-	if p == nil {
-		if urls := r.Header.Get("X-Server-Peers"); urls != "" {
-			h.tr.AddPeer(from, strings.Split(urls, ","))
-		}
-		p = h.peerGetter.Get(from)
-	}
-
 	if p == nil {
 		// This may happen in following cases:
 		// 1. user starts a remote peer that belongs to a different cluster
 		// with the same cluster ID.
 		// 2. local etcd falls behind of the cluster, and cannot recognize
 		// the members that joined after its current progress.
+		if urls := r.Header.Get("X-PeerURLs"); urls != "" {
+			h.tr.AddRemote(from, strings.Split(urls, ","))
+		}
 		plog.Errorf("failed to find member %s in cluster %s", from, h.cid)
 		http.Error(w, "error sender not found", http.StatusNotFound)
 		return

+ 1 - 2
rafthttp/http_test.go

@@ -151,7 +151,7 @@ func TestServeRaftPrefix(t *testing.T) {
 		req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID)
 		req.Header.Set("X-Server-Version", version.Version)
 		rw := httptest.NewRecorder()
-		h := newPipelineHandler(tt.p, types.ID(0))
+		h := newPipelineHandler(NewNopTransporter(), tt.p, types.ID(0))
 		h.ServeHTTP(rw, req)
 		if rw.Code != tt.wcode {
 			t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode)
@@ -364,4 +364,3 @@ func (pr *fakePeer) update(urls types.URLs)                { pr.peerURLs = urls
 func (pr *fakePeer) attachOutgoingConn(conn *outgoingConn) { pr.connc <- conn }
 func (pr *fakePeer) activeSince() time.Time                { return time.Time{} }
 func (pr *fakePeer) stop()                                 {}
-func (pr *fakePeer) urls() types.URLs                      { return pr.peerURLs }

+ 4 - 12
rafthttp/peer.go

@@ -65,9 +65,6 @@ type Peer interface {
 	// update updates the urls of remote peer.
 	update(urls types.URLs)
 
-	// urls  retrieves the urls of the remote peer
-	urls() types.URLs
-
 	// attachOutgoingConn attaches the outgoing connection to the peer for
 	// stream usage. After the call, the ownership of the outgoing
 	// connection hands over to the peer. The peer will close the connection
@@ -124,7 +121,6 @@ type peer struct {
 func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r Raft, fs *stats.FollowerStats, errorc chan error, v3demo bool) *peer {
 	status := newPeerStatus(to)
 	picker := newURLPicker(urls)
-	pipelineRt := transport.pipelineRt
 	p := &peer{
 		id:             to,
 		r:              r,
@@ -133,8 +129,8 @@ func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r
 		picker:         picker,
 		msgAppV2Writer: startStreamWriter(to, status, fs, r),
 		writer:         startStreamWriter(to, status, fs, r),
-		pipeline:       newPipeline(pipelineRt, picker, local, to, cid, status, fs, r, errorc),
-		snapSender:     newSnapshotSender(pipelineRt, picker, local, to, cid, status, r, errorc),
+		pipeline:       newPipeline(transport, picker, local, to, cid, status, fs, r, errorc),
+		snapSender:     newSnapshotSender(transport, picker, local, to, cid, status, r, errorc),
 		sendc:          make(chan raftpb.Message),
 		recvc:          make(chan raftpb.Message, recvBufSize),
 		propc:          make(chan raftpb.Message, maxPendingProposals),
@@ -161,8 +157,8 @@ func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r
 		}
 	}()
 
-	p.msgAppV2Reader = startStreamReader(p, transport.streamRt, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc)
-	reader := startStreamReader(p, transport.streamRt, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc)
+	p.msgAppV2Reader = startStreamReader(transport, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc)
+	reader := startStreamReader(transport, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc)
 	go func() {
 		var paused bool
 		for {
@@ -229,10 +225,6 @@ func (p *peer) update(urls types.URLs) {
 	}
 }
 
-func (p *peer) urls() types.URLs {
-	return p.picker.urls
-}
-
 func (p *peer) attachOutgoingConn(conn *outgoingConn) {
 	var ok bool
 	switch conn.t {

+ 5 - 6
rafthttp/pipeline.go

@@ -18,7 +18,6 @@ import (
 	"bytes"
 	"errors"
 	"io/ioutil"
-	"net/http"
 	"sync"
 	"time"
 
@@ -45,7 +44,7 @@ type pipeline struct {
 	from, to types.ID
 	cid      types.ID
 
-	tr     http.RoundTripper
+	tr     *Transport
 	picker *urlPicker
 	status *peerStatus
 	fs     *stats.FollowerStats
@@ -58,7 +57,7 @@ type pipeline struct {
 	stopc chan struct{}
 }
 
-func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID, status *peerStatus, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline {
+func newPipeline(tr *Transport, picker *urlPicker, from, to, cid types.ID, status *peerStatus, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline {
 	p := &pipeline{
 		from:   from,
 		to:     to,
@@ -126,10 +125,10 @@ func (p *pipeline) handle() {
 // error on any failure.
 func (p *pipeline) post(data []byte) (err error) {
 	u := p.picker.pick()
-	req := createPostRequest(u, RaftPrefix, bytes.NewBuffer(data), "application/protobuf", p.from, p.cid)
+	req := createPostRequest(u, RaftPrefix, bytes.NewBuffer(data), "application/protobuf", p.tr.URLs, p.from, p.cid)
 
 	done := make(chan struct{}, 1)
-	cancel := httputil.RequestCanceler(p.tr, req)
+	cancel := httputil.RequestCanceler(p.tr.pipelineRt, req)
 	go func() {
 		select {
 		case <-done:
@@ -139,7 +138,7 @@ func (p *pipeline) post(data []byte) (err error) {
 		}
 	}()
 
-	resp, err := p.tr.RoundTrip(req)
+	resp, err := p.tr.pipelineRt.RoundTrip(req)
 	done <- struct{}{}
 	if err != nil {
 		p.picker.unreachable(u)

+ 16 - 8
rafthttp/pipeline_test.go

@@ -37,7 +37,8 @@ func TestPipelineSend(t *testing.T) {
 	tr := &roundTripperRecorder{}
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
 	fs := &stats.FollowerStats{}
-	p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
+	tp := &Transport{pipelineRt: tr}
+	p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
 	testutil.WaitSchedule()
@@ -59,7 +60,8 @@ func TestPipelineKeepSendingWhenPostError(t *testing.T) {
 	tr := &respRoundTripper{err: fmt.Errorf("roundtrip error")}
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
 	fs := &stats.FollowerStats{}
-	p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
+	tp := &Transport{pipelineRt: tr}
+	p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
 
 	for i := 0; i < 50; i++ {
 		p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
@@ -79,7 +81,8 @@ func TestPipelineExceedMaximumServing(t *testing.T) {
 	tr := newRoundTripperBlocker()
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
 	fs := &stats.FollowerStats{}
-	p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
+	tp := &Transport{pipelineRt: tr}
+	p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
 
 	// keep the sender busy and make the buffer full
 	// nothing can go out as we block the sender
@@ -119,7 +122,8 @@ func TestPipelineExceedMaximumServing(t *testing.T) {
 func TestPipelineSendFailed(t *testing.T) {
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
 	fs := &stats.FollowerStats{}
-	p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
+	tp := &Transport{pipelineRt: newRespRoundTripper(0, errors.New("blah"))}
+	p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil)
 
 	p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
 	testutil.WaitSchedule()
@@ -135,7 +139,8 @@ func TestPipelineSendFailed(t *testing.T) {
 func TestPipelinePost(t *testing.T) {
 	tr := &roundTripperRecorder{}
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
-	p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil)
+	tp := &Transport{pipelineRt: tr}
+	p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil)
 	if err := p.post([]byte("some data")); err != nil {
 		t.Fatalf("unexpected post error: %v", err)
 	}
@@ -182,7 +187,8 @@ func TestPipelinePostBad(t *testing.T) {
 	}
 	for i, tt := range tests {
 		picker := mustNewURLPicker(t, []string{tt.u})
-		p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, make(chan error))
+		tp := &Transport{pipelineRt: newRespRoundTripper(tt.code, tt.err)}
+		p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, make(chan error))
 		err := p.post([]byte("some data"))
 		p.stop()
 
@@ -203,7 +209,8 @@ func TestPipelinePostErrorc(t *testing.T) {
 	for i, tt := range tests {
 		picker := mustNewURLPicker(t, []string{tt.u})
 		errorc := make(chan error, 1)
-		p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, errorc)
+		tp := &Transport{pipelineRt: newRespRoundTripper(tt.code, tt.err)}
+		p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, errorc)
 		p.post([]byte("some data"))
 		p.stop()
 		select {
@@ -216,7 +223,8 @@ func TestPipelinePostErrorc(t *testing.T) {
 
 func TestStopBlockedPipeline(t *testing.T) {
 	picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
-	p := newPipeline(newRoundTripperBlocker(), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil)
+	tp := &Transport{pipelineRt: newRoundTripperBlocker()}
+	p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil)
 	// send many messages that most of them will be blocked in buffer
 	for i := 0; i < connPerPipeline*10; i++ {
 		p.msgc <- raftpb.Message{}

+ 1 - 3
rafthttp/remote.go

@@ -15,8 +15,6 @@
 package rafthttp
 
 import (
-	"net/http"
-
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/raft/raftpb"
 )
@@ -27,7 +25,7 @@ type remote struct {
 	pipeline *pipeline
 }
 
-func startRemote(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r Raft, errorc chan error) *remote {
+func startRemote(tr *Transport, urls types.URLs, local, to, cid types.ID, r Raft, errorc chan error) *remote {
 	picker := newURLPicker(urls)
 	status := newPeerStatus(to)
 	return &remote{

+ 5 - 5
rafthttp/snapshot_sender.go

@@ -37,7 +37,7 @@ type snapshotSender struct {
 	from, to types.ID
 	cid      types.ID
 
-	tr     http.RoundTripper
+	tr     *Transport
 	picker *urlPicker
 	status *peerStatus
 	r      Raft
@@ -46,7 +46,7 @@ type snapshotSender struct {
 	stopc chan struct{}
 }
 
-func newSnapshotSender(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID, status *peerStatus, r Raft, errorc chan error) *snapshotSender {
+func newSnapshotSender(tr *Transport, picker *urlPicker, from, to, cid types.ID, status *peerStatus, r Raft, errorc chan error) *snapshotSender {
 	return &snapshotSender{
 		from:   from,
 		to:     to,
@@ -71,7 +71,7 @@ func (s *snapshotSender) send(merged snap.Message) {
 	defer body.Close()
 
 	u := s.picker.pick()
-	req := createPostRequest(u, RaftSnapshotPrefix, body, "application/octet-stream", s.from, s.cid)
+	req := createPostRequest(u, RaftSnapshotPrefix, body, "application/octet-stream", s.tr.URLs, s.from, s.cid)
 
 	plog.Infof("start to send database snapshot [index: %d, to %s]...", m.Snapshot.Metadata.Index, types.ID(m.To))
 
@@ -105,7 +105,7 @@ func (s *snapshotSender) send(merged snap.Message) {
 // post posts the given request.
 // It returns nil when request is sent out and processed successfully.
 func (s *snapshotSender) post(req *http.Request) (err error) {
-	cancel := httputil.RequestCanceler(s.tr, req)
+	cancel := httputil.RequestCanceler(s.tr.pipelineRt, req)
 
 	type responseAndError struct {
 		resp *http.Response
@@ -115,7 +115,7 @@ func (s *snapshotSender) post(req *http.Request) (err error) {
 	result := make(chan responseAndError, 1)
 
 	go func() {
-		resp, err := s.tr.RoundTrip(req)
+		resp, err := s.tr.pipelineRt.RoundTrip(req)
 		if err != nil {
 			result <- responseAndError{resp, nil, err}
 			return

+ 17 - 23
rafthttp/stream.go

@@ -226,8 +226,7 @@ func (cw *streamWriter) stop() {
 // streamReader is a long-running go-routine that dials to the remote stream
 // endpoint and reads messages from the response body returned.
 type streamReader struct {
-	localPeer     Peer
-	tr            http.RoundTripper
+	tr            *Transport
 	picker        *urlPicker
 	t             streamType
 	local, remote types.ID
@@ -244,21 +243,20 @@ type streamReader struct {
 	done   chan struct{}
 }
 
-func startStreamReader(p Peer, tr http.RoundTripper, picker *urlPicker, t streamType, local, remote, cid types.ID, status *peerStatus, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
+func startStreamReader(tr *Transport, picker *urlPicker, t streamType, local, remote, cid types.ID, status *peerStatus, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
 	r := &streamReader{
-		localPeer: p,
-		tr:        tr,
-		picker:    picker,
-		t:         t,
-		local:     local,
-		remote:    remote,
-		cid:       cid,
-		status:    status,
-		recvc:     recvc,
-		propc:     propc,
-		errorc:    errorc,
-		stopc:     make(chan struct{}),
-		done:      make(chan struct{}),
+		tr:     tr,
+		picker: picker,
+		t:      t,
+		local:  local,
+		remote: remote,
+		cid:    cid,
+		status: status,
+		recvc:  recvc,
+		propc:  propc,
+		errorc: errorc,
+		stopc:  make(chan struct{}),
+		done:   make(chan struct{}),
 	}
 	go r.run()
 	return r
@@ -374,11 +372,7 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
 	req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String())
 	req.Header.Set("X-Raft-To", cr.remote.String())
 
-	var peerURLs []string
-	for _, url := range cr.localPeer.urls() {
-		peerURLs = append(peerURLs, url.String())
-	}
-	req.Header.Set("X-Server-Peers", strings.Join(peerURLs, ","))
+	setPeerURLsHeader(req, cr.tr.URLs)
 
 	cr.mu.Lock()
 	select {
@@ -387,10 +381,10 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
 		return nil, fmt.Errorf("stream reader is stopped")
 	default:
 	}
-	cr.cancel = httputil.RequestCanceler(cr.tr, req)
+	cr.cancel = httputil.RequestCanceler(cr.tr.streamRt, req)
 	cr.mu.Unlock()
 
-	resp, err := cr.tr.RoundTrip(req)
+	resp, err := cr.tr.streamRt.RoundTrip(req)
 	if err != nil {
 		cr.picker.unreachable(u)
 		return nil, err

+ 18 - 22
rafthttp/stream_test.go

@@ -116,12 +116,11 @@ func TestStreamReaderDialRequest(t *testing.T) {
 	for i, tt := range []streamType{streamTypeMessage, streamTypeMsgAppV2} {
 		tr := &roundTripperRecorder{}
 		sr := &streamReader{
-			tr:        tr,
-			localPeer: newFakePeer(),
-			picker:    mustNewURLPicker(t, []string{"http://localhost:2380"}),
-			local:     types.ID(1),
-			remote:    types.ID(2),
-			cid:       types.ID(1),
+			tr:     &Transport{streamRt: tr},
+			picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
+			local:  types.ID(1),
+			remote: types.ID(2),
+			cid:    types.ID(1),
 		}
 		sr.dial(tt)
 
@@ -167,13 +166,12 @@ func TestStreamReaderDialResult(t *testing.T) {
 			err:    tt.err,
 		}
 		sr := &streamReader{
-			tr:        tr,
-			localPeer: newFakePeer(),
-			picker:    mustNewURLPicker(t, []string{"http://localhost:2380"}),
-			local:     types.ID(1),
-			remote:    types.ID(2),
-			cid:       types.ID(1),
-			errorc:    make(chan error, 1),
+			tr:     &Transport{streamRt: tr},
+			picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
+			local:  types.ID(1),
+			remote: types.ID(2),
+			cid:    types.ID(1),
+			errorc: make(chan error, 1),
 		}
 
 		_, err := sr.dial(streamTypeMessage)
@@ -196,12 +194,11 @@ func TestStreamReaderDialDetectUnsupport(t *testing.T) {
 			header: http.Header{},
 		}
 		sr := &streamReader{
-			tr:        tr,
-			localPeer: newFakePeer(),
-			picker:    mustNewURLPicker(t, []string{"http://localhost:2380"}),
-			local:     types.ID(1),
-			remote:    types.ID(2),
-			cid:       types.ID(1),
+			tr:     &Transport{streamRt: tr},
+			picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
+			local:  types.ID(1),
+			remote: types.ID(2),
+			cid:    types.ID(1),
 		}
 
 		_, err := sr.dial(typ)
@@ -257,9 +254,8 @@ func TestStream(t *testing.T) {
 		h.sw = sw
 
 		picker := mustNewURLPicker(t, []string{srv.URL})
-		tr := &http.Transport{}
-		peer := newFakePeer()
-		sr := startStreamReader(peer, tr, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), newPeerStatus(types.ID(1)), recvc, propc, nil)
+		tr := &Transport{streamRt: &http.Transport{}}
+		sr := startStreamReader(tr, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), newPeerStatus(types.ID(1)), recvc, propc, nil)
 		defer sr.stop()
 		// wait for stream to work
 		var writec chan<- raftpb.Message

+ 10 - 6
rafthttp/transport.go

@@ -97,9 +97,10 @@ type Transport struct {
 	DialTimeout time.Duration     // maximum duration before timing out dial of the request
 	TLSInfo     transport.TLSInfo // TLS information used when creating connection
 
-	ID          types.ID // local member ID
-	ClusterID   types.ID // raft cluster ID for request validation
-	Raft        Raft     // raft state machine, to which the Transport forwards received messages and reports status
+	ID          types.ID   // local member ID
+	URLs        types.URLs // local peer URLs
+	ClusterID   types.ID   // raft cluster ID for request validation
+	Raft        Raft       // raft state machine, to which the Transport forwards received messages and reports status
 	Snapshotter *snap.Snapshotter
 	ServerStats *stats.ServerStats // used to record general transportation statistics
 	// used to record transportation statistics with followers when
@@ -139,9 +140,9 @@ func (t *Transport) Start() error {
 }
 
 func (t *Transport) Handler() http.Handler {
-	pipelineHandler := newPipelineHandler(t.Raft, t.ClusterID)
+	pipelineHandler := newPipelineHandler(t, t.Raft, t.ClusterID)
 	streamHandler := newStreamHandler(t, t, t.Raft, t.ID, t.ClusterID)
-	snapHandler := newSnapshotHandler(t.Raft, t.Snapshotter, t.ClusterID)
+	snapHandler := newSnapshotHandler(t, t.Raft, t.Snapshotter, t.ClusterID)
 	mux := http.NewServeMux()
 	mux.Handle(RaftPrefix, pipelineHandler)
 	mux.Handle(RaftStreamPrefix+"/", streamHandler)
@@ -205,6 +206,9 @@ func (t *Transport) Stop() {
 func (t *Transport) AddRemote(id types.ID, us []string) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
+	if _, ok := t.peers[id]; ok {
+		return
+	}
 	if _, ok := t.remotes[id]; ok {
 		return
 	}
@@ -212,7 +216,7 @@ func (t *Transport) AddRemote(id types.ID, us []string) {
 	if err != nil {
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 	}
-	t.remotes[id] = startRemote(t.pipelineRt, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC)
+	t.remotes[id] = startRemote(t, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC)
 }
 
 func (t *Transport) AddPeer(id types.ID, us []string) {

+ 1 - 1
rafthttp/transport_test.go

@@ -121,7 +121,7 @@ func TestTransportUpdate(t *testing.T) {
 	tr.UpdatePeer(types.ID(1), []string{u})
 	wurls := types.URLs(testutil.MustNewURLs(t, []string{"http://localhost:2380"}))
 	if !reflect.DeepEqual(peer.peerURLs, wurls) {
-		t.Errorf("urls = %+v, want %+v", peer.urls, wurls)
+		t.Errorf("urls = %+v, want %+v", peer.peerURLs, wurls)
 	}
 }
 

+ 16 - 1
rafthttp/util.go

@@ -86,7 +86,7 @@ func readEntryFrom(r io.Reader, ent *raftpb.Entry) error {
 }
 
 // createPostRequest creates a HTTP POST request that sends raft message.
-func createPostRequest(u url.URL, path string, body io.Reader, ct string, from, cid types.ID) *http.Request {
+func createPostRequest(u url.URL, path string, body io.Reader, ct string, urls types.URLs, from, cid types.ID) *http.Request {
 	uu := u
 	uu.Path = path
 	req, err := http.NewRequest("POST", uu.String(), body)
@@ -98,6 +98,8 @@ func createPostRequest(u url.URL, path string, body io.Reader, ct string, from,
 	req.Header.Set("X-Server-Version", version.Version)
 	req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
 	req.Header.Set("X-Etcd-Cluster-ID", cid.String())
+	setPeerURLsHeader(req, urls)
+
 	return req
 }
 
@@ -187,3 +189,16 @@ func checkVersionCompability(name string, server, minCluster *semver.Version) er
 	}
 	return nil
 }
+
+// setPeerURLsHeader reports local urls for peer discovery
+func setPeerURLsHeader(req *http.Request, urls types.URLs) {
+	if urls == nil {
+		// often not set in unit tests
+		return
+	}
+	var peerURLs []string
+	for _, url := range urls {
+		peerURLs = append(peerURLs, url.String())
+	}
+	req.Header.Set("X-PeerURLs", strings.Join(peerURLs, ","))
+}