浏览代码

rafthttp: build transport inside pkg instead of passed-in

rafthttp has different requirements for connections created by the
transport for different usage, and this is hard to achieve when giving
one http.RoundTripper. Pass into pkg the data needed to build transport
now, and let rafthttp build its own transports.
Yicheng Qin 10 年之前
父节点
当前提交
207c92b627

+ 1 - 12
etcdmain/etcd.go

@@ -199,11 +199,6 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 		return nil, fmt.Errorf("error setting up initial cluster: %v", err)
 		return nil, fmt.Errorf("error setting up initial cluster: %v", err)
 	}
 	}
 
 
-	pt, err := transport.NewTimeoutTransport(cfg.peerTLSInfo, peerDialTimeout(cfg.ElectionMs), rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout)
-	if err != nil {
-		return nil, err
-	}
-
 	if !cfg.peerTLSInfo.Empty() {
 	if !cfg.peerTLSInfo.Empty() {
 		plog.Infof("peerTLS: %s", cfg.peerTLSInfo)
 		plog.Infof("peerTLS: %s", cfg.peerTLSInfo)
 	}
 	}
@@ -284,7 +279,7 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 		DiscoveryProxy:      cfg.dproxy,
 		DiscoveryProxy:      cfg.dproxy,
 		NewCluster:          cfg.isNewCluster(),
 		NewCluster:          cfg.isNewCluster(),
 		ForceNewCluster:     cfg.forceNewCluster,
 		ForceNewCluster:     cfg.forceNewCluster,
-		Transport:           pt,
+		PeerTLSInfo:         cfg.peerTLSInfo,
 		TickMs:              cfg.TickMs,
 		TickMs:              cfg.TickMs,
 		ElectionTicks:       cfg.electionTicks(),
 		ElectionTicks:       cfg.electionTicks(),
 		V3demo:              cfg.v3demo,
 		V3demo:              cfg.v3demo,
@@ -534,9 +529,3 @@ func setupLogging(cfg *config) {
 		repoLog.SetLogLevel(settings)
 		repoLog.SetLogLevel(settings)
 	}
 	}
 }
 }
-
-func peerDialTimeout(electionMs uint) time.Duration {
-	// 1s for queue wait and system delay
-	// + one RTT, which is smaller than 1/5 election timeout
-	return time.Second + time.Duration(electionMs)*time.Millisecond/5
-}

+ 8 - 2
etcdserver/config.go

@@ -16,13 +16,13 @@ package etcdserver
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"net/http"
 	"path"
 	"path"
 	"sort"
 	"sort"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
 	"github.com/coreos/etcd/pkg/netutil"
 	"github.com/coreos/etcd/pkg/netutil"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
 )
 )
 
 
@@ -44,7 +44,7 @@ type ServerConfig struct {
 	InitialClusterToken string
 	InitialClusterToken string
 	NewCluster          bool
 	NewCluster          bool
 	ForceNewCluster     bool
 	ForceNewCluster     bool
-	Transport           *http.Transport
+	PeerTLSInfo         transport.TLSInfo
 
 
 	TickMs        uint
 	TickMs        uint
 	ElectionTicks int
 	ElectionTicks int
@@ -132,6 +132,12 @@ func (c *ServerConfig) ReqTimeout() time.Duration {
 	return 5*time.Second + 2*time.Duration(c.ElectionTicks)*time.Duration(c.TickMs)*time.Millisecond
 	return 5*time.Second + 2*time.Duration(c.ElectionTicks)*time.Duration(c.TickMs)*time.Millisecond
 }
 }
 
 
+func (c *ServerConfig) peerDialTimeout() time.Duration {
+	// 1s for queue wait and system delay
+	// + one RTT, which is smaller than 1/5 election timeout
+	return time.Second + time.Duration(c.ElectionTicks)*time.Duration(c.TickMs)*time.Millisecond/5
+}
+
 func (c *ServerConfig) PrintWithInitial() { c.print(true) }
 func (c *ServerConfig) PrintWithInitial() { c.print(true) }
 
 
 func (c *ServerConfig) Print() { c.print(false) }
 func (c *ServerConfig) Print() { c.print(false) }

+ 24 - 13
etcdserver/server.go

@@ -40,6 +40,7 @@ import (
 	"github.com/coreos/etcd/pkg/pbutil"
 	"github.com/coreos/etcd/pkg/pbutil"
 	"github.com/coreos/etcd/pkg/runtime"
 	"github.com/coreos/etcd/pkg/runtime"
 	"github.com/coreos/etcd/pkg/timeutil"
 	"github.com/coreos/etcd/pkg/timeutil"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/wait"
 	"github.com/coreos/etcd/pkg/wait"
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft"
@@ -167,7 +168,9 @@ type EtcdServer struct {
 
 
 	SyncTicker <-chan time.Time
 	SyncTicker <-chan time.Time
 
 
-	reqIDGen *idutil.Generator
+	// versionTr used to send requests for peer version
+	versionTr *http.Transport
+	reqIDGen  *idutil.Generator
 
 
 	// forceVersionC is used to force the version monitor loop
 	// forceVersionC is used to force the version monitor loop
 	// to detect the cluster version immediately.
 	// to detect the cluster version immediately.
@@ -205,6 +208,10 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 	haveWAL := wal.Exist(cfg.WALDir())
 	haveWAL := wal.Exist(cfg.WALDir())
 	ss := snap.New(cfg.SnapDir())
 	ss := snap.New(cfg.SnapDir())
 
 
+	pt, err := transport.NewTransport(cfg.PeerTLSInfo, cfg.peerDialTimeout())
+	if err != nil {
+		return nil, err
+	}
 	var remotes []*Member
 	var remotes []*Member
 	switch {
 	switch {
 	case !haveWAL && !cfg.NewCluster:
 	case !haveWAL && !cfg.NewCluster:
@@ -215,14 +222,14 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		existingCluster, err := GetClusterFromRemotePeers(getRemotePeerURLs(cl, cfg.Name), cfg.Transport)
+		existingCluster, err := GetClusterFromRemotePeers(getRemotePeerURLs(cl, cfg.Name), pt)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("cannot fetch cluster info from peer urls: %v", err)
 			return nil, fmt.Errorf("cannot fetch cluster info from peer urls: %v", err)
 		}
 		}
 		if err := ValidateClusterAndAssignIDs(cl, existingCluster); err != nil {
 		if err := ValidateClusterAndAssignIDs(cl, existingCluster); err != nil {
 			return nil, fmt.Errorf("error validating peerURLs %s: %v", existingCluster, err)
 			return nil, fmt.Errorf("error validating peerURLs %s: %v", existingCluster, err)
 		}
 		}
-		if !isCompatibleWithCluster(cl, cl.MemberByName(cfg.Name).ID, cfg.Transport) {
+		if !isCompatibleWithCluster(cl, cl.MemberByName(cfg.Name).ID, pt) {
 			return nil, fmt.Errorf("incomptible with current running cluster")
 			return nil, fmt.Errorf("incomptible with current running cluster")
 		}
 		}
 
 
@@ -240,7 +247,7 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 			return nil, err
 			return nil, err
 		}
 		}
 		m := cl.MemberByName(cfg.Name)
 		m := cl.MemberByName(cfg.Name)
-		if isMemberBootstrapped(cl, cfg.Name, cfg.Transport) {
+		if isMemberBootstrapped(cl, cfg.Name, pt) {
 			return nil, fmt.Errorf("member %s has already been bootstrapped", m.ID)
 			return nil, fmt.Errorf("member %s has already been bootstrapped", m.ID)
 		}
 		}
 		if cfg.ShouldDiscover() {
 		if cfg.ShouldDiscover() {
@@ -328,6 +335,7 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 		stats:         sstats,
 		stats:         sstats,
 		lstats:        lstats,
 		lstats:        lstats,
 		SyncTicker:    time.Tick(500 * time.Millisecond),
 		SyncTicker:    time.Tick(500 * time.Millisecond),
+		versionTr:     pt,
 		reqIDGen:      idutil.NewGenerator(uint8(id), time.Now()),
 		reqIDGen:      idutil.NewGenerator(uint8(id), time.Now()),
 		forceVersionC: make(chan struct{}),
 		forceVersionC: make(chan struct{}),
 	}
 	}
@@ -346,15 +354,18 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 
 
 	// TODO: move transport initialization near the definition of remote
 	// TODO: move transport initialization near the definition of remote
 	tr := &rafthttp.Transport{
 	tr := &rafthttp.Transport{
-		RoundTripper: cfg.Transport,
-		ID:           id,
-		ClusterID:    cl.ID(),
-		Raft:         srv,
-		ServerStats:  sstats,
-		LeaderStats:  lstats,
-		ErrorC:       srv.errorc,
+		TLSInfo:     cfg.PeerTLSInfo,
+		DialTimeout: cfg.peerDialTimeout(),
+		ID:          id,
+		ClusterID:   cl.ID(),
+		Raft:        srv,
+		ServerStats: sstats,
+		LeaderStats: lstats,
+		ErrorC:      srv.errorc,
+	}
+	if err := tr.Start(); err != nil {
+		return nil, err
 	}
 	}
-	tr.Start()
 	// add all remotes into transport
 	// add all remotes into transport
 	for _, m := range remotes {
 	for _, m := range remotes {
 		if m.ID != id {
 		if m.ID != id {
@@ -1032,7 +1043,7 @@ func (s *EtcdServer) monitorVersions() {
 			continue
 			continue
 		}
 		}
 
 
-		v := decideClusterVersion(getVersions(s.cluster, s.id, s.cfg.Transport))
+		v := decideClusterVersion(getVersions(s.cluster, s.id, s.versionTr))
 		if v != nil {
 		if v != nil {
 			// only keep major.minor version for comparasion
 			// only keep major.minor version for comparasion
 			v = &semver.Version{
 			v = &semver.Version{

+ 1 - 1
etcdserver/server_test.go

@@ -1468,7 +1468,7 @@ func (n *readyNode) Ready() <-chan raft.Ready { return n.readyc }
 
 
 type nopTransporter struct{}
 type nopTransporter struct{}
 
 
-func (s *nopTransporter) Start()                              {}
+func (s *nopTransporter) Start() error                        { return nil }
 func (s *nopTransporter) Handler() http.Handler               { return nil }
 func (s *nopTransporter) Handler() http.Handler               { return nil }
 func (s *nopTransporter) Send(m []raftpb.Message)             {}
 func (s *nopTransporter) Send(m []raftpb.Message)             {}
 func (s *nopTransporter) AddRemote(id types.ID, us []string)  {}
 func (s *nopTransporter) AddRemote(id types.ID, us []string)  {}

+ 1 - 2
integration/cluster_test.go

@@ -688,7 +688,7 @@ func mustNewMember(t *testing.T, name string, usePeerTLS bool) *member {
 	}
 	}
 	m.InitialClusterToken = clusterName
 	m.InitialClusterToken = clusterName
 	m.NewCluster = true
 	m.NewCluster = true
-	m.Transport = mustNewTransport(t, m.PeerTLSInfo)
+	m.ServerConfig.PeerTLSInfo = m.PeerTLSInfo
 	m.ElectionTicks = electionTicks
 	m.ElectionTicks = electionTicks
 	m.TickMs = uint(tickDuration / time.Millisecond)
 	m.TickMs = uint(tickDuration / time.Millisecond)
 	return m
 	return m
@@ -720,7 +720,6 @@ func (m *member) Clone(t *testing.T) *member {
 		panic(err)
 		panic(err)
 	}
 	}
 	mm.InitialClusterToken = m.InitialClusterToken
 	mm.InitialClusterToken = m.InitialClusterToken
-	mm.Transport = mustNewTransport(t, m.PeerTLSInfo)
 	mm.ElectionTicks = m.ElectionTicks
 	mm.ElectionTicks = m.ElectionTicks
 	mm.PeerTLSInfo = m.PeerTLSInfo
 	mm.PeerTLSInfo = m.PeerTLSInfo
 	return mm
 	return mm

+ 20 - 25
rafthttp/functional_test.go

@@ -15,7 +15,6 @@
 package rafthttp
 package rafthttp
 
 
 import (
 import (
-	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
@@ -31,12 +30,11 @@ import (
 func TestSendMessage(t *testing.T) {
 func TestSendMessage(t *testing.T) {
 	// member 1
 	// member 1
 	tr := &Transport{
 	tr := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(1),
-		ClusterID:    types.ID(1),
-		Raft:         &fakeRaft{},
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("1"),
+		ID:          types.ID(1),
+		ClusterID:   types.ID(1),
+		Raft:        &fakeRaft{},
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("1"),
 	}
 	}
 	tr.Start()
 	tr.Start()
 	srv := httptest.NewServer(tr.Handler())
 	srv := httptest.NewServer(tr.Handler())
@@ -46,12 +44,11 @@ func TestSendMessage(t *testing.T) {
 	recvc := make(chan raftpb.Message, 1)
 	recvc := make(chan raftpb.Message, 1)
 	p := &fakeRaft{recvc: recvc}
 	p := &fakeRaft{recvc: recvc}
 	tr2 := &Transport{
 	tr2 := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(2),
-		ClusterID:    types.ID(1),
-		Raft:         p,
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("2"),
+		ID:          types.ID(2),
+		ClusterID:   types.ID(1),
+		Raft:        p,
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("2"),
 	}
 	}
 	tr2.Start()
 	tr2.Start()
 	srv2 := httptest.NewServer(tr2.Handler())
 	srv2 := httptest.NewServer(tr2.Handler())
@@ -92,12 +89,11 @@ func TestSendMessage(t *testing.T) {
 func TestSendMessageWhenStreamIsBroken(t *testing.T) {
 func TestSendMessageWhenStreamIsBroken(t *testing.T) {
 	// member 1
 	// member 1
 	tr := &Transport{
 	tr := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(1),
-		ClusterID:    types.ID(1),
-		Raft:         &fakeRaft{},
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("1"),
+		ID:          types.ID(1),
+		ClusterID:   types.ID(1),
+		Raft:        &fakeRaft{},
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("1"),
 	}
 	}
 	tr.Start()
 	tr.Start()
 	srv := httptest.NewServer(tr.Handler())
 	srv := httptest.NewServer(tr.Handler())
@@ -107,12 +103,11 @@ func TestSendMessageWhenStreamIsBroken(t *testing.T) {
 	recvc := make(chan raftpb.Message, 1)
 	recvc := make(chan raftpb.Message, 1)
 	p := &fakeRaft{recvc: recvc}
 	p := &fakeRaft{recvc: recvc}
 	tr2 := &Transport{
 	tr2 := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(2),
-		ClusterID:    types.ID(1),
-		Raft:         p,
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("2"),
+		ID:          types.ID(2),
+		ClusterID:   types.ID(1),
+		Raft:        p,
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("2"),
 	}
 	}
 	tr2.Start()
 	tr2.Start()
 	srv2 := httptest.NewServer(tr2.Handler())
 	srv2 := httptest.NewServer(tr2.Handler())

+ 4 - 4
rafthttp/peer.go

@@ -111,7 +111,7 @@ type peer struct {
 	done  chan struct{}
 	done  chan struct{}
 }
 }
 
 
-func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r Raft, fs *stats.FollowerStats, errorc chan error, term uint64) *peer {
+func startPeer(streamRt, pipelineRt http.RoundTripper, urls types.URLs, local, to, cid types.ID, r Raft, fs *stats.FollowerStats, errorc chan error, term uint64) *peer {
 	picker := newURLPicker(urls)
 	picker := newURLPicker(urls)
 	status := newPeerStatus(to)
 	status := newPeerStatus(to)
 	p := &peer{
 	p := &peer{
@@ -120,7 +120,7 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
 		status:       status,
 		status:       status,
 		msgAppWriter: startStreamWriter(to, status, fs, r),
 		msgAppWriter: startStreamWriter(to, status, fs, r),
 		writer:       startStreamWriter(to, status, fs, r),
 		writer:       startStreamWriter(to, status, fs, r),
-		pipeline:     newPipeline(tr, picker, local, to, cid, status, fs, r, errorc),
+		pipeline:     newPipeline(pipelineRt, picker, local, to, cid, status, fs, r, errorc),
 		sendc:        make(chan raftpb.Message),
 		sendc:        make(chan raftpb.Message),
 		recvc:        make(chan raftpb.Message, recvBufSize),
 		recvc:        make(chan raftpb.Message, recvBufSize),
 		propc:        make(chan raftpb.Message, maxPendingProposals),
 		propc:        make(chan raftpb.Message, maxPendingProposals),
@@ -148,8 +148,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
 		}
 		}
 	}()
 	}()
 
 
-	p.msgAppReader = startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc, term)
-	reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc, term)
+	p.msgAppReader = startStreamReader(streamRt, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc, term)
+	reader := startStreamReader(streamRt, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc, term)
 	go func() {
 	go func() {
 		var paused bool
 		var paused bool
 		for {
 		for {

+ 33 - 11
rafthttp/transport.go

@@ -23,6 +23,7 @@ import (
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/xiang90/probing"
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/xiang90/probing"
 	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 	"github.com/coreos/etcd/etcdserver/stats"
 	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft/raftpb"
 	"github.com/coreos/etcd/raft/raftpb"
@@ -40,7 +41,7 @@ type Raft interface {
 type Transporter interface {
 type Transporter interface {
 	// Start starts the given Transporter.
 	// Start starts the given Transporter.
 	// Start MUST be called before calling other functions in the interface.
 	// Start MUST be called before calling other functions in the interface.
-	Start()
+	Start() error
 	// Handler returns the HTTP handler of the transporter.
 	// Handler returns the HTTP handler of the transporter.
 	// A transporter HTTP handler handles the HTTP requests
 	// A transporter HTTP handler handles the HTTP requests
 	// from remote peers.
 	// from remote peers.
@@ -88,11 +89,13 @@ type Transporter interface {
 // User needs to call Start before calling other functions, and call
 // User needs to call Start before calling other functions, and call
 // Stop when the Transport is no longer used.
 // Stop when the Transport is no longer used.
 type Transport struct {
 type Transport struct {
-	RoundTripper http.RoundTripper  // roundTripper to send requests
-	ID           types.ID           // local member ID
-	ClusterID    types.ID           // raft cluster ID for request validation
-	Raft         Raft               // raft state machine, to which the Transport forwards received messages and reports status
-	ServerStats  *stats.ServerStats // used to record general transportation statistics
+	DialTimeout time.Duration     // maximum duration before timing out dial of the request
+	TLSInfo     transport.TLSInfo // TLS information used when creating connection
+
+	ID          types.ID           // local member ID
+	ClusterID   types.ID           // raft cluster ID for request validation
+	Raft        Raft               // raft state machine, to which the Transport forwards received messages and reports status
+	ServerStats *stats.ServerStats // used to record general transportation statistics
 	// used to record transportation statistics with followers when
 	// used to record transportation statistics with followers when
 	// performing as leader in raft protocol
 	// performing as leader in raft protocol
 	LeaderStats *stats.LeaderStats
 	LeaderStats *stats.LeaderStats
@@ -102,6 +105,9 @@ type Transport struct {
 	// machine and thus stop the Transport.
 	// machine and thus stop the Transport.
 	ErrorC chan error
 	ErrorC chan error
 
 
+	streamRt   http.RoundTripper // roundTripper used by streams
+	pipelineRt http.RoundTripper // roundTripper used by pipelines
+
 	mu      sync.RWMutex         // protect the term, remote and peer map
 	mu      sync.RWMutex         // protect the term, remote and peer map
 	term    uint64               // the latest term that has been observed
 	term    uint64               // the latest term that has been observed
 	remotes map[types.ID]*remote // remotes map that helps newly joined member to catch up
 	remotes map[types.ID]*remote // remotes map that helps newly joined member to catch up
@@ -110,10 +116,23 @@ type Transport struct {
 	prober probing.Prober
 	prober probing.Prober
 }
 }
 
 
-func (t *Transport) Start() {
+func (t *Transport) Start() error {
+	var err error
+	// Read/write timeout is set for stream roundTripper to promptly
+	// find out broken status, which minimizes the number of messages
+	// sent on broken connection.
+	t.streamRt, err = transport.NewTimeoutTransport(t.TLSInfo, t.DialTimeout, ConnReadTimeout, ConnWriteTimeout)
+	if err != nil {
+		return err
+	}
+	t.pipelineRt, err = transport.NewTransport(t.TLSInfo, t.DialTimeout)
+	if err != nil {
+		return err
+	}
 	t.remotes = make(map[types.ID]*remote)
 	t.remotes = make(map[types.ID]*remote)
 	t.peers = make(map[types.ID]Peer)
 	t.peers = make(map[types.ID]Peer)
-	t.prober = probing.NewProber(t.RoundTripper)
+	t.prober = probing.NewProber(t.pipelineRt)
+	return nil
 }
 }
 
 
 func (t *Transport) Handler() http.Handler {
 func (t *Transport) Handler() http.Handler {
@@ -183,7 +202,10 @@ func (t *Transport) Stop() {
 		p.Stop()
 		p.Stop()
 	}
 	}
 	t.prober.RemoveAll()
 	t.prober.RemoveAll()
-	if tr, ok := t.RoundTripper.(*http.Transport); ok {
+	if tr, ok := t.streamRt.(*http.Transport); ok {
+		tr.CloseIdleConnections()
+	}
+	if tr, ok := t.pipelineRt.(*http.Transport); ok {
 		tr.CloseIdleConnections()
 		tr.CloseIdleConnections()
 	}
 	}
 }
 }
@@ -198,7 +220,7 @@ func (t *Transport) AddRemote(id types.ID, us []string) {
 	if err != nil {
 	if err != nil {
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 	}
 	}
-	t.remotes[id] = startRemote(t.RoundTripper, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC)
+	t.remotes[id] = startRemote(t.pipelineRt, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC)
 }
 }
 
 
 func (t *Transport) AddPeer(id types.ID, us []string) {
 func (t *Transport) AddPeer(id types.ID, us []string) {
@@ -212,7 +234,7 @@ func (t *Transport) AddPeer(id types.ID, us []string) {
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 	}
 	}
 	fs := t.LeaderStats.Follower(id.String())
 	fs := t.LeaderStats.Follower(id.String())
-	t.peers[id] = startPeer(t.RoundTripper, urls, t.ID, id, t.ClusterID, t.Raft, fs, t.ErrorC, t.term)
+	t.peers[id] = startPeer(t.streamRt, t.pipelineRt, urls, t.ID, id, t.ClusterID, t.Raft, fs, t.ErrorC, t.term)
 	addPeerToProber(t.prober, id.String(), us)
 	addPeerToProber(t.prober, id.String(), us)
 }
 }
 
 

+ 10 - 13
rafthttp/transport_bench_test.go

@@ -15,7 +15,6 @@
 package rafthttp
 package rafthttp
 
 
 import (
 import (
-	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
@@ -31,12 +30,11 @@ import (
 func BenchmarkSendingMsgApp(b *testing.B) {
 func BenchmarkSendingMsgApp(b *testing.B) {
 	// member 1
 	// member 1
 	tr := &Transport{
 	tr := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(1),
-		ClusterID:    types.ID(1),
-		Raft:         &fakeRaft{},
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("1"),
+		ID:          types.ID(1),
+		ClusterID:   types.ID(1),
+		Raft:        &fakeRaft{},
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("1"),
 	}
 	}
 	tr.Start()
 	tr.Start()
 	srv := httptest.NewServer(tr.Handler())
 	srv := httptest.NewServer(tr.Handler())
@@ -45,12 +43,11 @@ func BenchmarkSendingMsgApp(b *testing.B) {
 	// member 2
 	// member 2
 	r := &countRaft{}
 	r := &countRaft{}
 	tr2 := &Transport{
 	tr2 := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(2),
-		ClusterID:    types.ID(1),
-		Raft:         r,
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("2"),
+		ID:          types.ID(2),
+		ClusterID:   types.ID(1),
+		Raft:        r,
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("2"),
 	}
 	}
 	tr2.Start()
 	tr2.Start()
 	srv2 := httptest.NewServer(tr2.Handler())
 	srv2 := httptest.NewServer(tr2.Handler())

+ 15 - 14
rafthttp/transport_test.go

@@ -70,11 +70,11 @@ func TestTransportAdd(t *testing.T) {
 	ls := stats.NewLeaderStats("")
 	ls := stats.NewLeaderStats("")
 	term := uint64(10)
 	term := uint64(10)
 	tr := &Transport{
 	tr := &Transport{
-		RoundTripper: &roundTripperRecorder{},
-		LeaderStats:  ls,
-		term:         term,
-		peers:        make(map[types.ID]Peer),
-		prober:       probing.NewProber(nil),
+		LeaderStats: ls,
+		streamRt:    &roundTripperRecorder{},
+		term:        term,
+		peers:       make(map[types.ID]Peer),
+		prober:      probing.NewProber(nil),
 	}
 	}
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 
 
@@ -103,10 +103,10 @@ func TestTransportAdd(t *testing.T) {
 
 
 func TestTransportRemove(t *testing.T) {
 func TestTransportRemove(t *testing.T) {
 	tr := &Transport{
 	tr := &Transport{
-		RoundTripper: &roundTripperRecorder{},
-		LeaderStats:  stats.NewLeaderStats(""),
-		peers:        make(map[types.ID]Peer),
-		prober:       probing.NewProber(nil),
+		LeaderStats: stats.NewLeaderStats(""),
+		streamRt:    &roundTripperRecorder{},
+		peers:       make(map[types.ID]Peer),
+		prober:      probing.NewProber(nil),
 	}
 	}
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 	tr.RemovePeer(types.ID(1))
 	tr.RemovePeer(types.ID(1))
@@ -134,11 +134,12 @@ func TestTransportUpdate(t *testing.T) {
 func TestTransportErrorc(t *testing.T) {
 func TestTransportErrorc(t *testing.T) {
 	errorc := make(chan error, 1)
 	errorc := make(chan error, 1)
 	tr := &Transport{
 	tr := &Transport{
-		RoundTripper: newRespRoundTripper(http.StatusForbidden, nil),
-		LeaderStats:  stats.NewLeaderStats(""),
-		ErrorC:       errorc,
-		peers:        make(map[types.ID]Peer),
-		prober:       probing.NewProber(nil),
+		LeaderStats: stats.NewLeaderStats(""),
+		ErrorC:      errorc,
+		streamRt:    newRespRoundTripper(http.StatusForbidden, nil),
+		pipelineRt:  newRespRoundTripper(http.StatusForbidden, nil),
+		peers:       make(map[types.ID]Peer),
+		prober:      probing.NewProber(nil),
 	}
 	}
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 	defer tr.Stop()
 	defer tr.Stop()