Explorar o código

*: http and https on the same port

Xiang Li %!s(int64=9) %!d(string=hai) anos
pai
achega
900a61b023

+ 70 - 42
etcdmain/etcd.go

@@ -18,6 +18,7 @@
 package etcdmain
 package etcdmain
 
 
 import (
 import (
+	"crypto/tls"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
@@ -33,7 +34,6 @@ import (
 
 
 	"github.com/coreos/etcd/discovery"
 	"github.com/coreos/etcd/discovery"
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver"
-	"github.com/coreos/etcd/etcdserver/api/v3rpc"
 	"github.com/coreos/etcd/etcdserver/etcdhttp"
 	"github.com/coreos/etcd/etcdserver/etcdhttp"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/fileutil"
 	"github.com/coreos/etcd/pkg/fileutil"
@@ -49,7 +49,6 @@ import (
 	systemdutil "github.com/coreos/go-systemd/util"
 	systemdutil "github.com/coreos/go-systemd/util"
 	"github.com/coreos/pkg/capnslog"
 	"github.com/coreos/pkg/capnslog"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus"
-	"google.golang.org/grpc"
 )
 )
 
 
 type dirType string
 type dirType string
@@ -220,14 +219,24 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 	if !cfg.peerTLSInfo.Empty() {
 	if !cfg.peerTLSInfo.Empty() {
 		plog.Infof("peerTLS: %s", cfg.peerTLSInfo)
 		plog.Infof("peerTLS: %s", cfg.peerTLSInfo)
 	}
 	}
-
 	plns := make([]net.Listener, 0)
 	plns := make([]net.Listener, 0)
 	for _, u := range cfg.lpurls {
 	for _, u := range cfg.lpurls {
 		if u.Scheme == "http" && !cfg.peerTLSInfo.Empty() {
 		if u.Scheme == "http" && !cfg.peerTLSInfo.Empty() {
 			plog.Warningf("The scheme of peer url %s is http while peer key/cert files are presented. Ignored peer key/cert files.", u.String())
 			plog.Warningf("The scheme of peer url %s is http while peer key/cert files are presented. Ignored peer key/cert files.", u.String())
 		}
 		}
-		var l net.Listener
-		l, err = rafthttp.NewListener(u, cfg.peerTLSInfo)
+		var (
+			l      net.Listener
+			tlscfg *tls.Config
+		)
+
+		if !cfg.peerTLSInfo.Empty() {
+			tlscfg, err = cfg.peerTLSInfo.ServerConfig()
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		l, err = rafthttp.NewListener(u, tlscfg)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -243,15 +252,40 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 		plns = append(plns, l)
 		plns = append(plns, l)
 	}
 	}
 
 
+	var ctlscfg *tls.Config
 	if !cfg.clientTLSInfo.Empty() {
 	if !cfg.clientTLSInfo.Empty() {
 		plog.Infof("clientTLS: %s", cfg.clientTLSInfo)
 		plog.Infof("clientTLS: %s", cfg.clientTLSInfo)
+		ctlscfg, err = cfg.clientTLSInfo.ServerConfig()
+		if err != nil {
+			return nil, err
+		}
 	}
 	}
-	clns := make([]net.Listener, 0)
+	sctxs := make(map[string]*serveCtx)
 	for _, u := range cfg.lcurls {
 	for _, u := range cfg.lcurls {
 		if u.Scheme == "http" && !cfg.clientTLSInfo.Empty() {
 		if u.Scheme == "http" && !cfg.clientTLSInfo.Empty() {
 			plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
 			plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
 		}
 		}
+
+		ctx := &serveCtx{host: u.Host}
+
+		if u.Scheme == "https" {
+			ctx.secure = true
+		} else {
+			ctx.insecure = true
+		}
+
+		if sctxs[u.Host] != nil {
+			if ctx.secure {
+				sctxs[u.Host].secure = true
+			}
+			if ctx.insecure {
+				sctxs[u.Host].insecure = true
+			}
+			continue
+		}
+
 		var l net.Listener
 		var l net.Listener
+
 		l, err = net.Listen("tcp", u.Host)
 		l, err = net.Listen("tcp", u.Host)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
@@ -265,22 +299,20 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 			l = transport.LimitListener(l, int(fdLimit-reservedInternalFDNum))
 			l = transport.LimitListener(l, int(fdLimit-reservedInternalFDNum))
 		}
 		}
 
 
-		// Do not wrap around this listener if TLS Info is set.
-		// HTTPS server expects TLS Conn created by TLSListener.
-		l, err = transport.NewKeepAliveListener(l, u.Scheme, cfg.clientTLSInfo)
+		l, err = transport.NewKeepAliveListener(l, "tcp", nil)
+		ctx.l = l
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
-		urlStr := u.String()
-		plog.Info("listening for client requests on ", urlStr)
+		plog.Info("listening for client requests on ", u.Host)
 		defer func() {
 		defer func() {
 			if err != nil {
 			if err != nil {
 				l.Close()
 				l.Close()
-				plog.Info("stopping listening for client requests on ", urlStr)
+				plog.Info("stopping listening for client requests on ", u.Host)
 			}
 			}
 		}()
 		}()
-		clns = append(clns, l)
+		sctxs[u.Host] = ctx
 	}
 	}
 
 
 	srvcfg := &etcdserver.ServerConfig{
 	srvcfg := &etcdserver.ServerConfig{
@@ -317,40 +349,25 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 	if cfg.corsInfo.String() != "" {
 	if cfg.corsInfo.String() != "" {
 		plog.Infof("cors = %s", cfg.corsInfo)
 		plog.Infof("cors = %s", cfg.corsInfo)
 	}
 	}
-	ch := &cors.CORSHandler{
+	ch := http.Handler(&cors.CORSHandler{
 		Handler: etcdhttp.NewClientHandler(s, srvcfg.ReqTimeout()),
 		Handler: etcdhttp.NewClientHandler(s, srvcfg.ReqTimeout()),
 		Info:    cfg.corsInfo,
 		Info:    cfg.corsInfo,
-	}
+	})
 	ph := etcdhttp.NewPeerHandler(s)
 	ph := etcdhttp.NewPeerHandler(s)
 
 
-	var grpcS *grpc.Server
-	if cfg.v3demo {
-		// set up v3 demo rpc
-		tls := &cfg.clientTLSInfo
-		if cfg.clientTLSInfo.Empty() {
-			tls = nil
-		}
-		grpcS, err = v3rpc.Server(s, tls)
-		if err != nil {
-			s.Stop()
-			<-s.StopNotify()
-			return nil, err
-		}
-	}
-
 	// Start the peer server in a goroutine
 	// Start the peer server in a goroutine
 	for _, l := range plns {
 	for _, l := range plns {
 		go func(l net.Listener) {
 		go func(l net.Listener) {
-			plog.Fatal(serve(l, nil, ph, 5*time.Minute))
+			plog.Fatal(servePeerHTTP(l, ph))
 		}(l)
 		}(l)
 	}
 	}
 	// Start a client server goroutine for each listen address
 	// Start a client server goroutine for each listen address
-	for _, l := range clns {
-		go func(l net.Listener) {
+	for _, sctx := range sctxs {
+		go func(sctx *serveCtx) {
 			// read timeout does not work with http close notify
 			// read timeout does not work with http close notify
 			// TODO: https://github.com/golang/go/issues/9524
 			// TODO: https://github.com/golang/go/issues/9524
-			plog.Fatal(serve(l, grpcS, ch, 0))
-		}(l)
+			plog.Fatal(serve(sctx, s, ctlscfg, ch))
+		}(sctx)
 	}
 	}
 
 
 	return s.StopNotify(), nil
 	return s.StopNotify(), nil
@@ -419,11 +436,11 @@ func startProxy(cfg *config) error {
 
 
 	clientURLs := []string{}
 	clientURLs := []string{}
 	uf := func() []string {
 	uf := func() []string {
-		gcls, err := etcdserver.GetClusterFromRemotePeers(peerURLs, tr)
+		gcls, gerr := etcdserver.GetClusterFromRemotePeers(peerURLs, tr)
 		// TODO: remove the 2nd check when we fix GetClusterFromRemotePeers
 		// TODO: remove the 2nd check when we fix GetClusterFromRemotePeers
 		// GetClusterFromRemotePeers should not return nil error with an invalid empty cluster
 		// GetClusterFromRemotePeers should not return nil error with an invalid empty cluster
-		if err != nil {
-			plog.Warningf("proxy: %v", err)
+		if gerr != nil {
+			plog.Warningf("proxy: %v", gerr)
 			return []string{}
 			return []string{}
 		}
 		}
 		if len(gcls.Members()) == 0 {
 		if len(gcls.Members()) == 0 {
@@ -432,9 +449,9 @@ func startProxy(cfg *config) error {
 		clientURLs = gcls.ClientURLs()
 		clientURLs = gcls.ClientURLs()
 
 
 		urls := struct{ PeerURLs []string }{gcls.PeerURLs()}
 		urls := struct{ PeerURLs []string }{gcls.PeerURLs()}
-		b, err := json.Marshal(urls)
-		if err != nil {
-			plog.Warningf("proxy: error on marshal peer urls %s", err)
+		b, jerr := json.Marshal(urls)
+		if jerr != nil {
+			plog.Warningf("proxy: error on marshal peer urls %s", jerr)
 			return clientURLs
 			return clientURLs
 		}
 		}
 
 
@@ -466,7 +483,18 @@ func startProxy(cfg *config) error {
 	}
 	}
 	// Start a proxy server goroutine for each listen address
 	// Start a proxy server goroutine for each listen address
 	for _, u := range cfg.lcurls {
 	for _, u := range cfg.lcurls {
-		l, err := transport.NewListener(u.Host, u.Scheme, cfg.clientTLSInfo)
+		var (
+			l      net.Listener
+			tlscfg *tls.Config
+		)
+		if !cfg.clientTLSInfo.Empty() {
+			tlscfg, err = cfg.clientTLSInfo.ServerConfig()
+			if err != nil {
+				return err
+			}
+		}
+
+		l, err := transport.NewListener(u.Host, u.Scheme, tlscfg)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}

+ 62 - 12
etcdmain/serve.go

@@ -15,37 +15,87 @@
 package etcdmain
 package etcdmain
 
 
 import (
 import (
+	"crypto/tls"
 	"io/ioutil"
 	"io/ioutil"
 	defaultLog "log"
 	defaultLog "log"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"strings"
 	"time"
 	"time"
 
 
 	"github.com/cockroachdb/cmux"
 	"github.com/cockroachdb/cmux"
+	"github.com/coreos/etcd/etcdserver"
+	"github.com/coreos/etcd/etcdserver/api/v3rpc"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc"
 )
 )
 
 
+type serveCtx struct {
+	l        net.Listener
+	host     string
+	secure   bool
+	insecure bool
+}
+
 // serve accepts incoming connections on the listener l,
 // serve accepts incoming connections on the listener l,
 // creating a new service goroutine for each. The service goroutines
 // creating a new service goroutine for each. The service goroutines
 // read requests and then call handler to reply to them.
 // read requests and then call handler to reply to them.
-func serve(l net.Listener, grpcS *grpc.Server, handler http.Handler, readTimeout time.Duration) error {
-	// TODO: assert net.Listener type? Arbitrary listener might break HTTPS server which
-	// expect a TLS Conn type.
-	httpl := l
-	if grpcS != nil {
-		m := cmux.New(l)
-		grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
-		httpl = m.Match(cmux.Any())
-		go func() { plog.Fatal(m.Serve()) }()
-		go plog.Fatal(grpcS.Serve(grpcl))
+func serve(sctx *serveCtx, s *etcdserver.EtcdServer, tlscfg *tls.Config, handler http.Handler) error {
+	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
+
+	m := cmux.New(sctx.l)
+
+	if sctx.insecure {
+		gs := v3rpc.Server(s, nil)
+		grpcl := m.Match(cmux.HTTP2())
+		go func() { plog.Fatal(gs.Serve(grpcl)) }()
+
+		srvhttp := &http.Server{
+			Handler:  handler,
+			ErrorLog: logger, // do not log user error
+		}
+		httpl := m.Match(cmux.HTTP1())
+		go func() { plog.Fatal(srvhttp.Serve(httpl)) }()
+		plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.host)
+	}
+
+	if sctx.secure {
+		gs := v3rpc.Server(s, tlscfg)
+		handler = grpcHandlerFunc(gs, handler)
+
+		tlsl := tls.NewListener(m.Match(cmux.Any()), tlscfg)
+		// TODO: add debug flag; enable logging when debug flag is set
+		srv := &http.Server{
+			Handler:   handler,
+			TLSConfig: tlscfg,
+			ErrorLog:  logger, // do not log user error
+		}
+		go func() { plog.Fatal(srv.Serve(tlsl)) }()
+
+		plog.Infof("serving client requests on %s", sctx.host)
 	}
 	}
 
 
+	return m.Serve()
+}
+
+// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
+// connections or otherHandler otherwise. Copied from cockroachdb.
+func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
+			grpcServer.ServeHTTP(w, r)
+		} else {
+			otherHandler.ServeHTTP(w, r)
+		}
+	})
+}
+
+func servePeerHTTP(l net.Listener, handler http.Handler) error {
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
 	// TODO: add debug flag; enable logging when debug flag is set
 	// TODO: add debug flag; enable logging when debug flag is set
 	srv := &http.Server{
 	srv := &http.Server{
 		Handler:     handler,
 		Handler:     handler,
-		ReadTimeout: readTimeout,
+		ReadTimeout: 5 * time.Minute,
 		ErrorLog:    logger, // do not log user error
 		ErrorLog:    logger, // do not log user error
 	}
 	}
-	return srv.Serve(httpl)
+	return srv.Serve(l)
 }
 }

+ 5 - 8
etcdserver/api/v3rpc/grpc.go

@@ -15,21 +15,18 @@
 package v3rpc
 package v3rpc
 
 
 import (
 import (
+	"crypto/tls"
+
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
-	"github.com/coreos/etcd/pkg/transport"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/credentials"
 )
 )
 
 
-func Server(s *etcdserver.EtcdServer, tls *transport.TLSInfo) (*grpc.Server, error) {
+func Server(s *etcdserver.EtcdServer, tls *tls.Config) *grpc.Server {
 	var opts []grpc.ServerOption
 	var opts []grpc.ServerOption
 	if tls != nil {
 	if tls != nil {
-		creds, err := credentials.NewServerTLSFromFile(tls.CertFile, tls.KeyFile)
-		if err != nil {
-			return nil, err
-		}
-		opts = append(opts, grpc.Creds(creds))
+		opts = append(opts, grpc.Creds(credentials.NewTLS(tls)))
 	}
 	}
 
 
 	grpcServer := grpc.NewServer(opts...)
 	grpcServer := grpc.NewServer(opts...)
@@ -39,5 +36,5 @@ func Server(s *etcdserver.EtcdServer, tls *transport.TLSInfo) (*grpc.Server, err
 	pb.RegisterClusterServer(grpcServer, NewClusterServer(s))
 	pb.RegisterClusterServer(grpcServer, NewClusterServer(s))
 	pb.RegisterAuthServer(grpcServer, NewAuthServer(s))
 	pb.RegisterAuthServer(grpcServer, NewAuthServer(s))
 	pb.RegisterMaintenanceServer(grpcServer, NewMaintenanceServer(s))
 	pb.RegisterMaintenanceServer(grpcServer, NewMaintenanceServer(s))
-	return grpcServer, nil
+	return grpcServer
 }
 }

+ 11 - 1
integration/cluster.go

@@ -15,6 +15,7 @@
 package integration
 package integration
 
 
 import (
 import (
+	"crypto/tls"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 	"math/rand"
 	"math/rand"
@@ -585,7 +586,16 @@ func (m *member) Launch() error {
 		m.hss = append(m.hss, hs)
 		m.hss = append(m.hss, hs)
 	}
 	}
 	if m.grpcListener != nil {
 	if m.grpcListener != nil {
-		m.grpcServer, err = v3rpc.Server(m.s, m.ClientTLSInfo)
+		var (
+			tlscfg *tls.Config
+		)
+		if m.ClientTLSInfo != nil && !m.ClientTLSInfo.Empty() {
+			tlscfg, err = m.ClientTLSInfo.ServerConfig()
+			if err != nil {
+				return err
+			}
+		}
+		m.grpcServer = v3rpc.Server(m.s, tlscfg)
 		go m.grpcServer.Serve(m.grpcListener)
 		go m.grpcServer.Serve(m.grpcListener)
 	}
 	}
 	return nil
 	return nil

+ 3 - 8
pkg/transport/keepalive_listener.go

@@ -30,17 +30,12 @@ type keepAliveConn interface {
 // Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
 // Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
 // Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
 // Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
-func NewKeepAliveListener(l net.Listener, scheme string, info TLSInfo) (net.Listener, error) {
+func NewKeepAliveListener(l net.Listener, scheme string, tlscfg *tls.Config) (net.Listener, error) {
 	if scheme == "https" {
 	if scheme == "https" {
-		if info.Empty() {
+		if tlscfg == nil {
 			return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
 			return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
 		}
 		}
-		cfg, err := info.ServerConfig()
-		if err != nil {
-			return nil, err
-		}
-
-		return newTLSKeepaliveListener(l, cfg), nil
+		return newTLSKeepaliveListener(l, tlscfg), nil
 	}
 	}
 
 
 	return &keepaliveListener{
 	return &keepaliveListener{

+ 8 - 4
pkg/transport/keepalive_listener_test.go

@@ -31,7 +31,7 @@ func TestNewKeepAliveListener(t *testing.T) {
 		t.Fatalf("unexpected listen error: %v", err)
 		t.Fatalf("unexpected listen error: %v", err)
 	}
 	}
 
 
-	ln, err = NewKeepAliveListener(ln, "http", TLSInfo{})
+	ln, err = NewKeepAliveListener(ln, "http", nil)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
 	}
@@ -53,7 +53,11 @@ func TestNewKeepAliveListener(t *testing.T) {
 	defer os.Remove(tmp)
 	defer os.Remove(tmp)
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
-	tlsln, err := NewKeepAliveListener(ln, "https", tlsInfo)
+	tlscfg, err := tlsInfo.ServerConfig()
+	if err != nil {
+		t.Fatalf("unexpected serverConfig error: %v", err)
+	}
+	tlsln, err := NewKeepAliveListener(ln, "https", tlscfg)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
 	}
@@ -70,13 +74,13 @@ func TestNewKeepAliveListener(t *testing.T) {
 	tlsln.Close()
 	tlsln.Close()
 }
 }
 
 
-func TestNewKeepAliveListenerTLSEmptyInfo(t *testing.T) {
+func TestNewKeepAliveListenerTLSEmptyConfig(t *testing.T) {
 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected listen error: %v", err)
 		t.Fatalf("unexpected listen error: %v", err)
 	}
 	}
 
 
-	_, err = NewKeepAliveListener(ln, "https", TLSInfo{})
+	_, err = NewKeepAliveListener(ln, "https", nil)
 	if err == nil {
 	if err == nil {
 		t.Errorf("err = nil, want not presented error")
 		t.Errorf("err = nil, want not presented error")
 	}
 	}

+ 3 - 7
pkg/transport/listener.go

@@ -33,7 +33,7 @@ import (
 	"time"
 	"time"
 )
 )
 
 
-func NewListener(addr string, scheme string, info TLSInfo) (net.Listener, error) {
+func NewListener(addr string, scheme string, tlscfg *tls.Config) (net.Listener, error) {
 	nettype := "tcp"
 	nettype := "tcp"
 	if scheme == "unix" {
 	if scheme == "unix" {
 		// unix sockets via unix://laddr
 		// unix sockets via unix://laddr
@@ -46,15 +46,11 @@ func NewListener(addr string, scheme string, info TLSInfo) (net.Listener, error)
 	}
 	}
 
 
 	if scheme == "https" {
 	if scheme == "https" {
-		if info.Empty() {
+		if tlscfg == nil {
 			return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
 			return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
 		}
 		}
-		cfg, err := info.ServerConfig()
-		if err != nil {
-			return nil, err
-		}
 
 
-		l = tls.NewListener(l, cfg)
+		l = tls.NewListener(l, tlscfg)
 	}
 	}
 
 
 	return l, nil
 	return l, nil

+ 20 - 16
pkg/transport/listener_test.go

@@ -58,7 +58,11 @@ func TestNewListenerTLSInfo(t *testing.T) {
 }
 }
 
 
 func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
 func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
-	ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
+	tlscfg, err := tlsInfo.ServerConfig()
+	if err != nil {
+		t.Fatalf("unexpected serverConfig error: %v", err)
+	}
+	ln, err := NewListener("127.0.0.1:0", "https", tlscfg)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewListener error: %v", err)
 		t.Fatalf("unexpected NewListener error: %v", err)
 	}
 	}
@@ -76,25 +80,12 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
 }
 }
 
 
 func TestNewListenerTLSEmptyInfo(t *testing.T) {
 func TestNewListenerTLSEmptyInfo(t *testing.T) {
-	_, err := NewListener("127.0.0.1:0", "https", TLSInfo{})
+	_, err := NewListener("127.0.0.1:0", "https", nil)
 	if err == nil {
 	if err == nil {
 		t.Errorf("err = nil, want not presented error")
 		t.Errorf("err = nil, want not presented error")
 	}
 	}
 }
 }
 
 
-func TestNewListenerTLSInfoNonexist(t *testing.T) {
-	tlsInfo := TLSInfo{CertFile: "@badname", KeyFile: "@badname"}
-	_, err := NewListener("127.0.0.1:0", "https", tlsInfo)
-	werr := &os.PathError{
-		Op:   "open",
-		Path: "@badname",
-		Err:  errors.New("no such file or directory"),
-	}
-	if err.Error() != werr.Error() {
-		t.Errorf("err = %v, want %v", err, werr)
-	}
-}
-
 func TestNewTransportTLSInfo(t *testing.T) {
 func TestNewTransportTLSInfo(t *testing.T) {
 	tmp, err := createTempFile([]byte("XXX"))
 	tmp, err := createTempFile([]byte("XXX"))
 	if err != nil {
 	if err != nil {
@@ -131,6 +122,19 @@ func TestNewTransportTLSInfo(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestTLSInfoNonexist(t *testing.T) {
+	tlsInfo := TLSInfo{CertFile: "@badname", KeyFile: "@badname"}
+	_, err := tlsInfo.ServerConfig()
+	werr := &os.PathError{
+		Op:   "open",
+		Path: "@badname",
+		Err:  errors.New("no such file or directory"),
+	}
+	if err.Error() != werr.Error() {
+		t.Errorf("err = %v, want %v", err, werr)
+	}
+}
+
 func TestTLSInfoEmpty(t *testing.T) {
 func TestTLSInfoEmpty(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		info TLSInfo
 		info TLSInfo
@@ -247,7 +251,7 @@ func TestTLSInfoConfigFuncs(t *testing.T) {
 }
 }
 
 
 func TestNewListenerUnixSocket(t *testing.T) {
 func TestNewListenerUnixSocket(t *testing.T) {
-	l, err := NewListener("testsocket", "unix", TLSInfo{})
+	l, err := NewListener("testsocket", "unix", nil)
 	if err != nil {
 	if err != nil {
 		t.Errorf("error listening on unix socket (%v)", err)
 		t.Errorf("error listening on unix socket (%v)", err)
 	}
 	}

+ 3 - 2
pkg/transport/timeout_listener.go

@@ -15,6 +15,7 @@
 package transport
 package transport
 
 
 import (
 import (
+	"crypto/tls"
 	"net"
 	"net"
 	"time"
 	"time"
 )
 )
@@ -22,8 +23,8 @@ import (
 // NewTimeoutListener returns a listener that listens on the given address.
 // NewTimeoutListener returns a listener that listens on the given address.
 // If read/write on the accepted connection blocks longer than its time limit,
 // If read/write on the accepted connection blocks longer than its time limit,
 // it will return timeout error.
 // it will return timeout error.
-func NewTimeoutListener(addr string, scheme string, info TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
-	ln, err := NewListener(addr, scheme, info)
+func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
+	ln, err := NewListener(addr, scheme, tlscfg)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 1 - 1
pkg/transport/timeout_listener_test.go

@@ -23,7 +23,7 @@ import (
 // TestNewTimeoutListener tests that NewTimeoutListener returns a
 // TestNewTimeoutListener tests that NewTimeoutListener returns a
 // rwTimeoutListener struct with timeouts set.
 // rwTimeoutListener struct with timeouts set.
 func TestNewTimeoutListener(t *testing.T) {
 func TestNewTimeoutListener(t *testing.T) {
-	l, err := NewTimeoutListener("127.0.0.1:0", "http", TLSInfo{}, time.Hour, time.Hour)
+	l, err := NewTimeoutListener("127.0.0.1:0", "http", nil, time.Hour, time.Hour)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewTimeoutListener error: %v", err)
 		t.Fatalf("unexpected NewTimeoutListener error: %v", err)
 	}
 	}

+ 3 - 2
rafthttp/util.go

@@ -15,6 +15,7 @@
 package rafthttp
 package rafthttp
 
 
 import (
 import (
+	"crypto/tls"
 	"encoding/binary"
 	"encoding/binary"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
@@ -38,8 +39,8 @@ var (
 
 
 // NewListener returns a listener for raft message transfer between peers.
 // NewListener returns a listener for raft message transfer between peers.
 // It uses timeout listener to identify broken streams promptly.
 // It uses timeout listener to identify broken streams promptly.
-func NewListener(u url.URL, tlsInfo transport.TLSInfo) (net.Listener, error) {
-	return transport.NewTimeoutListener(u.Host, u.Scheme, tlsInfo, ConnReadTimeout, ConnWriteTimeout)
+func NewListener(u url.URL, tlscfg *tls.Config) (net.Listener, error) {
+	return transport.NewTimeoutListener(u.Host, u.Scheme, tlscfg, ConnReadTimeout, ConnWriteTimeout)
 }
 }
 
 
 // NewRoundTripper returns a roundTripper used to send requests
 // NewRoundTripper returns a roundTripper used to send requests