Browse Source

Merge pull request #3664 from yichengq/transport-more

rafthttp: build transport inside pkg instead of passed-in
Xiang Li 10 years ago
parent
commit
df7074911e

+ 1 - 12
etcdmain/etcd.go

@@ -199,11 +199,6 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 		return nil, fmt.Errorf("error setting up initial cluster: %v", err)
 	}
 
-	pt, err := transport.NewTimeoutTransport(cfg.peerTLSInfo, peerDialTimeout(cfg.ElectionMs), rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout)
-	if err != nil {
-		return nil, err
-	}
-
 	if !cfg.peerTLSInfo.Empty() {
 		plog.Infof("peerTLS: %s", cfg.peerTLSInfo)
 	}
@@ -284,7 +279,7 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 		DiscoveryProxy:      cfg.dproxy,
 		NewCluster:          cfg.isNewCluster(),
 		ForceNewCluster:     cfg.forceNewCluster,
-		Transport:           pt,
+		PeerTLSInfo:         cfg.peerTLSInfo,
 		TickMs:              cfg.TickMs,
 		ElectionTicks:       cfg.electionTicks(),
 		V3demo:              cfg.v3demo,
@@ -534,9 +529,3 @@ func setupLogging(cfg *config) {
 		repoLog.SetLogLevel(settings)
 	}
 }
-
-func peerDialTimeout(electionMs uint) time.Duration {
-	// 1s for queue wait and system delay
-	// + one RTT, which is smaller than 1/5 election timeout
-	return time.Second + time.Duration(electionMs)*time.Millisecond/5
-}

+ 8 - 2
etcdserver/config.go

@@ -16,13 +16,13 @@ package etcdserver
 
 import (
 	"fmt"
-	"net/http"
 	"path"
 	"sort"
 	"strings"
 	"time"
 
 	"github.com/coreos/etcd/pkg/netutil"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 )
 
@@ -44,7 +44,7 @@ type ServerConfig struct {
 	InitialClusterToken string
 	NewCluster          bool
 	ForceNewCluster     bool
-	Transport           *http.Transport
+	PeerTLSInfo         transport.TLSInfo
 
 	TickMs        uint
 	ElectionTicks int
@@ -132,6 +132,12 @@ func (c *ServerConfig) ReqTimeout() time.Duration {
 	return 5*time.Second + 2*time.Duration(c.ElectionTicks)*time.Duration(c.TickMs)*time.Millisecond
 }
 
+func (c *ServerConfig) peerDialTimeout() time.Duration {
+	// 1s for queue wait and system delay
+	// + one RTT, which is smaller than 1/5 election timeout
+	return time.Second + time.Duration(c.ElectionTicks)*time.Duration(c.TickMs)*time.Millisecond/5
+}
+
 func (c *ServerConfig) PrintWithInitial() { c.print(true) }
 
 func (c *ServerConfig) Print() { c.print(false) }

+ 24 - 13
etcdserver/server.go

@@ -40,6 +40,7 @@ import (
 	"github.com/coreos/etcd/pkg/pbutil"
 	"github.com/coreos/etcd/pkg/runtime"
 	"github.com/coreos/etcd/pkg/timeutil"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/pkg/wait"
 	"github.com/coreos/etcd/raft"
@@ -167,7 +168,9 @@ type EtcdServer struct {
 
 	SyncTicker <-chan time.Time
 
-	reqIDGen *idutil.Generator
+	// versionTr is used to send requests for peer version
+	versionTr *http.Transport
+	reqIDGen  *idutil.Generator
 
 	// forceVersionC is used to force the version monitor loop
 	// to detect the cluster version immediately.
@@ -205,6 +208,10 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 	haveWAL := wal.Exist(cfg.WALDir())
 	ss := snap.New(cfg.SnapDir())
 
+	pt, err := transport.NewTransport(cfg.PeerTLSInfo, cfg.peerDialTimeout())
+	if err != nil {
+		return nil, err
+	}
 	var remotes []*Member
 	switch {
 	case !haveWAL && !cfg.NewCluster:
@@ -215,14 +222,14 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 		if err != nil {
 			return nil, err
 		}
-		existingCluster, err := GetClusterFromRemotePeers(getRemotePeerURLs(cl, cfg.Name), cfg.Transport)
+		existingCluster, err := GetClusterFromRemotePeers(getRemotePeerURLs(cl, cfg.Name), pt)
 		if err != nil {
 			return nil, fmt.Errorf("cannot fetch cluster info from peer urls: %v", err)
 		}
 		if err := ValidateClusterAndAssignIDs(cl, existingCluster); err != nil {
 			return nil, fmt.Errorf("error validating peerURLs %s: %v", existingCluster, err)
 		}
-		if !isCompatibleWithCluster(cl, cl.MemberByName(cfg.Name).ID, cfg.Transport) {
+		if !isCompatibleWithCluster(cl, cl.MemberByName(cfg.Name).ID, pt) {
 			return nil, fmt.Errorf("incomptible with current running cluster")
 		}
 
@@ -240,7 +247,7 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 			return nil, err
 		}
 		m := cl.MemberByName(cfg.Name)
-		if isMemberBootstrapped(cl, cfg.Name, cfg.Transport) {
+		if isMemberBootstrapped(cl, cfg.Name, pt) {
 			return nil, fmt.Errorf("member %s has already been bootstrapped", m.ID)
 		}
 		if cfg.ShouldDiscover() {
@@ -328,6 +335,7 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 		stats:         sstats,
 		lstats:        lstats,
 		SyncTicker:    time.Tick(500 * time.Millisecond),
+		versionTr:     pt,
 		reqIDGen:      idutil.NewGenerator(uint8(id), time.Now()),
 		forceVersionC: make(chan struct{}),
 	}
@@ -346,15 +354,18 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) {
 
 	// TODO: move transport initialization near the definition of remote
 	tr := &rafthttp.Transport{
-		RoundTripper: cfg.Transport,
-		ID:           id,
-		ClusterID:    cl.ID(),
-		Raft:         srv,
-		ServerStats:  sstats,
-		LeaderStats:  lstats,
-		ErrorC:       srv.errorc,
+		TLSInfo:     cfg.PeerTLSInfo,
+		DialTimeout: cfg.peerDialTimeout(),
+		ID:          id,
+		ClusterID:   cl.ID(),
+		Raft:        srv,
+		ServerStats: sstats,
+		LeaderStats: lstats,
+		ErrorC:      srv.errorc,
+	}
+	if err := tr.Start(); err != nil {
+		return nil, err
 	}
-	tr.Start()
 	// add all remotes into transport
 	for _, m := range remotes {
 		if m.ID != id {
@@ -1032,7 +1043,7 @@ func (s *EtcdServer) monitorVersions() {
 			continue
 		}
 
-		v := decideClusterVersion(getVersions(s.cluster, s.id, s.cfg.Transport))
+		v := decideClusterVersion(getVersions(s.cluster, s.id, s.versionTr))
 		if v != nil {
 			// only keep major.minor version for comparison
 			v = &semver.Version{

+ 1 - 1
etcdserver/server_test.go

@@ -1468,7 +1468,7 @@ func (n *readyNode) Ready() <-chan raft.Ready { return n.readyc }
 
 type nopTransporter struct{}
 
-func (s *nopTransporter) Start()                              {}
+func (s *nopTransporter) Start() error                        { return nil }
 func (s *nopTransporter) Handler() http.Handler               { return nil }
 func (s *nopTransporter) Send(m []raftpb.Message)             {}
 func (s *nopTransporter) AddRemote(id types.ID, us []string)  {}

+ 1 - 2
integration/cluster_test.go

@@ -688,7 +688,7 @@ func mustNewMember(t *testing.T, name string, usePeerTLS bool) *member {
 	}
 	m.InitialClusterToken = clusterName
 	m.NewCluster = true
-	m.Transport = mustNewTransport(t, m.PeerTLSInfo)
+	m.ServerConfig.PeerTLSInfo = m.PeerTLSInfo
 	m.ElectionTicks = electionTicks
 	m.TickMs = uint(tickDuration / time.Millisecond)
 	return m
@@ -720,7 +720,6 @@ func (m *member) Clone(t *testing.T) *member {
 		panic(err)
 	}
 	mm.InitialClusterToken = m.InitialClusterToken
-	mm.Transport = mustNewTransport(t, m.PeerTLSInfo)
 	mm.ElectionTicks = m.ElectionTicks
 	mm.PeerTLSInfo = m.PeerTLSInfo
 	return mm

+ 20 - 25
rafthttp/functional_test.go

@@ -15,7 +15,6 @@
 package rafthttp
 
 import (
-	"net/http"
 	"net/http/httptest"
 	"reflect"
 	"testing"
@@ -31,12 +30,11 @@ import (
 func TestSendMessage(t *testing.T) {
 	// member 1
 	tr := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(1),
-		ClusterID:    types.ID(1),
-		Raft:         &fakeRaft{},
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("1"),
+		ID:          types.ID(1),
+		ClusterID:   types.ID(1),
+		Raft:        &fakeRaft{},
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("1"),
 	}
 	tr.Start()
 	srv := httptest.NewServer(tr.Handler())
@@ -46,12 +44,11 @@ func TestSendMessage(t *testing.T) {
 	recvc := make(chan raftpb.Message, 1)
 	p := &fakeRaft{recvc: recvc}
 	tr2 := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(2),
-		ClusterID:    types.ID(1),
-		Raft:         p,
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("2"),
+		ID:          types.ID(2),
+		ClusterID:   types.ID(1),
+		Raft:        p,
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("2"),
 	}
 	tr2.Start()
 	srv2 := httptest.NewServer(tr2.Handler())
@@ -92,12 +89,11 @@ func TestSendMessage(t *testing.T) {
 func TestSendMessageWhenStreamIsBroken(t *testing.T) {
 	// member 1
 	tr := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(1),
-		ClusterID:    types.ID(1),
-		Raft:         &fakeRaft{},
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("1"),
+		ID:          types.ID(1),
+		ClusterID:   types.ID(1),
+		Raft:        &fakeRaft{},
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("1"),
 	}
 	tr.Start()
 	srv := httptest.NewServer(tr.Handler())
@@ -107,12 +103,11 @@ func TestSendMessageWhenStreamIsBroken(t *testing.T) {
 	recvc := make(chan raftpb.Message, 1)
 	p := &fakeRaft{recvc: recvc}
 	tr2 := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(2),
-		ClusterID:    types.ID(1),
-		Raft:         p,
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("2"),
+		ID:          types.ID(2),
+		ClusterID:   types.ID(1),
+		Raft:        p,
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("2"),
 	}
 	tr2.Start()
 	srv2 := httptest.NewServer(tr2.Handler())

+ 4 - 4
rafthttp/peer.go

@@ -111,7 +111,7 @@ type peer struct {
 	done  chan struct{}
 }
 
-func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r Raft, fs *stats.FollowerStats, errorc chan error, term uint64) *peer {
+func startPeer(streamRt, pipelineRt http.RoundTripper, urls types.URLs, local, to, cid types.ID, r Raft, fs *stats.FollowerStats, errorc chan error, term uint64) *peer {
 	picker := newURLPicker(urls)
 	status := newPeerStatus(to)
 	p := &peer{
@@ -120,7 +120,7 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
 		status:       status,
 		msgAppWriter: startStreamWriter(to, status, fs, r),
 		writer:       startStreamWriter(to, status, fs, r),
-		pipeline:     newPipeline(tr, picker, local, to, cid, status, fs, r, errorc),
+		pipeline:     newPipeline(pipelineRt, picker, local, to, cid, status, fs, r, errorc),
 		sendc:        make(chan raftpb.Message),
 		recvc:        make(chan raftpb.Message, recvBufSize),
 		propc:        make(chan raftpb.Message, maxPendingProposals),
@@ -148,8 +148,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
 		}
 	}()
 
-	p.msgAppReader = startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc, term)
-	reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc, term)
+	p.msgAppReader = startStreamReader(streamRt, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc, term)
+	reader := startStreamReader(streamRt, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc, term)
 	go func() {
 		var paused bool
 		for {

+ 33 - 11
rafthttp/transport.go

@@ -23,6 +23,7 @@ import (
 	"github.com/coreos/etcd/Godeps/_workspace/src/github.com/xiang90/probing"
 	"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/net/context"
 	"github.com/coreos/etcd/etcdserver/stats"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/raft"
 	"github.com/coreos/etcd/raft/raftpb"
@@ -40,7 +41,7 @@ type Raft interface {
 type Transporter interface {
 	// Start starts the given Transporter.
 	// Start MUST be called before calling other functions in the interface.
-	Start()
+	Start() error
 	// Handler returns the HTTP handler of the transporter.
 	// A transporter HTTP handler handles the HTTP requests
 	// from remote peers.
@@ -88,11 +89,13 @@ type Transporter interface {
 // User needs to call Start before calling other functions, and call
 // Stop when the Transport is no longer used.
 type Transport struct {
-	RoundTripper http.RoundTripper  // roundTripper to send requests
-	ID           types.ID           // local member ID
-	ClusterID    types.ID           // raft cluster ID for request validation
-	Raft         Raft               // raft state machine, to which the Transport forwards received messages and reports status
-	ServerStats  *stats.ServerStats // used to record general transportation statistics
+	DialTimeout time.Duration     // maximum duration before timing out dial of the request
+	TLSInfo     transport.TLSInfo // TLS information used when creating connection
+
+	ID          types.ID           // local member ID
+	ClusterID   types.ID           // raft cluster ID for request validation
+	Raft        Raft               // raft state machine, to which the Transport forwards received messages and reports status
+	ServerStats *stats.ServerStats // used to record general transportation statistics
 	// used to record transportation statistics with followers when
 	// performing as leader in raft protocol
 	LeaderStats *stats.LeaderStats
@@ -102,6 +105,9 @@ type Transport struct {
 	// machine and thus stop the Transport.
 	ErrorC chan error
 
+	streamRt   http.RoundTripper // roundTripper used by streams
+	pipelineRt http.RoundTripper // roundTripper used by pipelines
+
 	mu      sync.RWMutex         // protect the term, remote and peer map
 	term    uint64               // the latest term that has been observed
 	remotes map[types.ID]*remote // remotes map that helps newly joined member to catch up
@@ -110,10 +116,23 @@ type Transport struct {
 	prober probing.Prober
 }
 
-func (t *Transport) Start() {
+func (t *Transport) Start() error {
+	var err error
+	// Read/write timeouts are set on the stream roundTripper so that a
+	// broken connection is detected promptly, which minimizes the number
+	// of messages sent on a broken connection.
+	t.streamRt, err = transport.NewTimeoutTransport(t.TLSInfo, t.DialTimeout, ConnReadTimeout, ConnWriteTimeout)
+	if err != nil {
+		return err
+	}
+	t.pipelineRt, err = transport.NewTransport(t.TLSInfo, t.DialTimeout)
+	if err != nil {
+		return err
+	}
 	t.remotes = make(map[types.ID]*remote)
 	t.peers = make(map[types.ID]Peer)
-	t.prober = probing.NewProber(t.RoundTripper)
+	t.prober = probing.NewProber(t.pipelineRt)
+	return nil
 }
 
 func (t *Transport) Handler() http.Handler {
@@ -183,7 +202,10 @@ func (t *Transport) Stop() {
 		p.Stop()
 	}
 	t.prober.RemoveAll()
-	if tr, ok := t.RoundTripper.(*http.Transport); ok {
+	if tr, ok := t.streamRt.(*http.Transport); ok {
+		tr.CloseIdleConnections()
+	}
+	if tr, ok := t.pipelineRt.(*http.Transport); ok {
 		tr.CloseIdleConnections()
 	}
 }
@@ -198,7 +220,7 @@ func (t *Transport) AddRemote(id types.ID, us []string) {
 	if err != nil {
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 	}
-	t.remotes[id] = startRemote(t.RoundTripper, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC)
+	t.remotes[id] = startRemote(t.pipelineRt, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC)
 }
 
 func (t *Transport) AddPeer(id types.ID, us []string) {
@@ -212,7 +234,7 @@ func (t *Transport) AddPeer(id types.ID, us []string) {
 		plog.Panicf("newURLs %+v should never fail: %+v", us, err)
 	}
 	fs := t.LeaderStats.Follower(id.String())
-	t.peers[id] = startPeer(t.RoundTripper, urls, t.ID, id, t.ClusterID, t.Raft, fs, t.ErrorC, t.term)
+	t.peers[id] = startPeer(t.streamRt, t.pipelineRt, urls, t.ID, id, t.ClusterID, t.Raft, fs, t.ErrorC, t.term)
 	addPeerToProber(t.prober, id.String(), us)
 }
 

+ 10 - 13
rafthttp/transport_bench_test.go

@@ -15,7 +15,6 @@
 package rafthttp
 
 import (
-	"net/http"
 	"net/http/httptest"
 	"sync"
 	"testing"
@@ -31,12 +30,11 @@ import (
 func BenchmarkSendingMsgApp(b *testing.B) {
 	// member 1
 	tr := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(1),
-		ClusterID:    types.ID(1),
-		Raft:         &fakeRaft{},
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("1"),
+		ID:          types.ID(1),
+		ClusterID:   types.ID(1),
+		Raft:        &fakeRaft{},
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("1"),
 	}
 	tr.Start()
 	srv := httptest.NewServer(tr.Handler())
@@ -45,12 +43,11 @@ func BenchmarkSendingMsgApp(b *testing.B) {
 	// member 2
 	r := &countRaft{}
 	tr2 := &Transport{
-		RoundTripper: &http.Transport{},
-		ID:           types.ID(2),
-		ClusterID:    types.ID(1),
-		Raft:         r,
-		ServerStats:  newServerStats(),
-		LeaderStats:  stats.NewLeaderStats("2"),
+		ID:          types.ID(2),
+		ClusterID:   types.ID(1),
+		Raft:        r,
+		ServerStats: newServerStats(),
+		LeaderStats: stats.NewLeaderStats("2"),
 	}
 	tr2.Start()
 	srv2 := httptest.NewServer(tr2.Handler())

+ 15 - 14
rafthttp/transport_test.go

@@ -70,11 +70,11 @@ func TestTransportAdd(t *testing.T) {
 	ls := stats.NewLeaderStats("")
 	term := uint64(10)
 	tr := &Transport{
-		RoundTripper: &roundTripperRecorder{},
-		LeaderStats:  ls,
-		term:         term,
-		peers:        make(map[types.ID]Peer),
-		prober:       probing.NewProber(nil),
+		LeaderStats: ls,
+		streamRt:    &roundTripperRecorder{},
+		term:        term,
+		peers:       make(map[types.ID]Peer),
+		prober:      probing.NewProber(nil),
 	}
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 
@@ -103,10 +103,10 @@ func TestTransportAdd(t *testing.T) {
 
 func TestTransportRemove(t *testing.T) {
 	tr := &Transport{
-		RoundTripper: &roundTripperRecorder{},
-		LeaderStats:  stats.NewLeaderStats(""),
-		peers:        make(map[types.ID]Peer),
-		prober:       probing.NewProber(nil),
+		LeaderStats: stats.NewLeaderStats(""),
+		streamRt:    &roundTripperRecorder{},
+		peers:       make(map[types.ID]Peer),
+		prober:      probing.NewProber(nil),
 	}
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 	tr.RemovePeer(types.ID(1))
@@ -134,11 +134,12 @@ func TestTransportUpdate(t *testing.T) {
 func TestTransportErrorc(t *testing.T) {
 	errorc := make(chan error, 1)
 	tr := &Transport{
-		RoundTripper: newRespRoundTripper(http.StatusForbidden, nil),
-		LeaderStats:  stats.NewLeaderStats(""),
-		ErrorC:       errorc,
-		peers:        make(map[types.ID]Peer),
-		prober:       probing.NewProber(nil),
+		LeaderStats: stats.NewLeaderStats(""),
+		ErrorC:      errorc,
+		streamRt:    newRespRoundTripper(http.StatusForbidden, nil),
+		pipelineRt:  newRespRoundTripper(http.StatusForbidden, nil),
+		peers:       make(map[types.ID]Peer),
+		prober:      probing.NewProber(nil),
 	}
 	tr.AddPeer(1, []string{"http://localhost:2380"})
 	defer tr.Stop()