Browse Source

Merge pull request #4153 from xiang90/fix_listener

etcdmain: tls listener MUST be at the outer layer of all listeners
Xiang Li 10 years ago
parent
commit
9e0378998b

+ 9 - 1
etcdmain/etcd.go

@@ -236,10 +236,11 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 			plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
 			plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
 		}
 		}
 		var l net.Listener
 		var l net.Listener
-		l, err = transport.NewKeepAliveListener(u.Host, u.Scheme, cfg.clientTLSInfo)
+		l, err = net.Listen("tcp", u.Host)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
+
 		if fdLimit, err := runtimeutil.FDLimit(); err == nil {
 		if fdLimit, err := runtimeutil.FDLimit(); err == nil {
 			if fdLimit <= reservedInternalFDNum {
 			if fdLimit <= reservedInternalFDNum {
 				plog.Fatalf("file descriptor limit[%d] of etcd process is too low, and should be set higher than %d to ensure internal usage", fdLimit, reservedInternalFDNum)
 				plog.Fatalf("file descriptor limit[%d] of etcd process is too low, and should be set higher than %d to ensure internal usage", fdLimit, reservedInternalFDNum)
@@ -247,6 +248,13 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 			l = netutil.LimitListener(l, int(fdLimit-reservedInternalFDNum))
 			l = netutil.LimitListener(l, int(fdLimit-reservedInternalFDNum))
 		}
 		}
 
 
+		// Do not wrap around this listener if TLS Info is set.
+		// HTTPS server expects TLS Conn created by TLSListener.
+		l, err = transport.NewKeepAliveListener(l, u.Scheme, cfg.clientTLSInfo)
+		if err != nil {
+			return nil, err
+		}
+
 		urlStr := u.String()
 		urlStr := u.String()
 		plog.Info("listening for client requests on ", urlStr)
 		plog.Info("listening for client requests on ", urlStr)
 		defer func() {
 		defer func() {

+ 3 - 0
etcdmain/http.go

@@ -26,6 +26,9 @@ import (
 // creating a new service goroutine for each. The service goroutines
 // creating a new service goroutine for each. The service goroutines
 // read requests and then call handler to reply to them.
 // read requests and then call handler to reply to them.
 func serveHTTP(l net.Listener, handler http.Handler, readTimeout time.Duration) error {
 func serveHTTP(l net.Listener, handler http.Handler, readTimeout time.Duration) error {
+	// TODO: assert net.Listener type? Arbitrary listener might break HTTPS server which
+	// expect a TLS Conn type.
+
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
 	// TODO: add debug flag; enable logging when debug flag is set
 	// TODO: add debug flag; enable logging when debug flag is set
 	srv := &http.Server{
 	srv := &http.Server{

+ 4 - 7
pkg/transport/keepalive_listener.go

@@ -22,16 +22,13 @@ import (
 )
 )
 
 
 // NewKeepAliveListener returns a listener that listens on the given address.
 // NewKeepAliveListener returns a listener that listens on the given address.
+// Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
+// Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
-func NewKeepAliveListener(addr string, scheme string, info TLSInfo) (net.Listener, error) {
-	l, err := net.Listen("tcp", addr)
-	if err != nil {
-		return nil, err
-	}
-
+func NewKeepAliveListener(l net.Listener, scheme string, info TLSInfo) (net.Listener, error) {
 	if scheme == "https" {
 	if scheme == "https" {
 		if info.Empty() {
 		if info.Empty() {
-			return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
+			return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
 		}
 		}
 		cfg, err := info.ServerConfig()
 		cfg, err := info.ServerConfig()
 		if err != nil {
 		if err != nil {

+ 15 - 3
pkg/transport/keepalive_listener_test.go

@@ -16,6 +16,7 @@ package transport
 
 
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
+	"net"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
 	"testing"
 	"testing"
@@ -25,7 +26,12 @@ import (
 // that accepts connections.
 // that accepts connections.
 // TODO: verify the keepalive option is set correctly
 // TODO: verify the keepalive option is set correctly
 func TestNewKeepAliveListener(t *testing.T) {
 func TestNewKeepAliveListener(t *testing.T) {
-	ln, err := NewKeepAliveListener("127.0.0.1:0", "http", TLSInfo{})
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+
+	ln, err = NewKeepAliveListener(ln, "http", TLSInfo{})
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
 	}
@@ -38,6 +44,7 @@ func TestNewKeepAliveListener(t *testing.T) {
 	conn.Close()
 	conn.Close()
 	ln.Close()
 	ln.Close()
 
 
+	ln, err = net.Listen("tcp", "127.0.0.1:0")
 	// tls
 	// tls
 	tmp, err := createTempFile([]byte("XXX"))
 	tmp, err := createTempFile([]byte("XXX"))
 	if err != nil {
 	if err != nil {
@@ -46,7 +53,7 @@ func TestNewKeepAliveListener(t *testing.T) {
 	defer os.Remove(tmp)
 	defer os.Remove(tmp)
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
-	tlsln, err := NewKeepAliveListener("127.0.0.1:0", "https", tlsInfo)
+	tlsln, err := NewKeepAliveListener(ln, "https", tlsInfo)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
 	}
@@ -64,7 +71,12 @@ func TestNewKeepAliveListener(t *testing.T) {
 }
 }
 
 
 func TestNewKeepAliveListenerTLSEmptyInfo(t *testing.T) {
 func TestNewKeepAliveListenerTLSEmptyInfo(t *testing.T) {
-	_, err := NewListener("127.0.0.1:0", "https", TLSInfo{})
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+
+	_, err = NewKeepAliveListener(ln, "https", TLSInfo{})
 	if err == nil {
 	if err == nil {
 		t.Errorf("err = nil, want not presented error")
 		t.Errorf("err = nil, want not presented error")
 	}
 	}