Browse Source

etcdmain: tls listener MUST be at the outer layer of all listeners

go HTTP library uses type assertion to determine if a connection
is a TLS connection. If we wrapper TLS Listener with any customized
Listener that can create customized Conn, HTTPs will be broken.

This commit fixes the issue.
Xiang Li 10 years ago
parent
commit
1f97f2dc36

+ 9 - 1
etcdmain/etcd.go

@@ -236,10 +236,11 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 			plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
 			plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
 		}
 		}
 		var l net.Listener
 		var l net.Listener
-		l, err = transport.NewKeepAliveListener(u.Host, u.Scheme, cfg.clientTLSInfo)
+		l, err = net.Listen("tcp", u.Host)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
+
 		if fdLimit, err := runtimeutil.FDLimit(); err == nil {
 		if fdLimit, err := runtimeutil.FDLimit(); err == nil {
 			if fdLimit <= reservedInternalFDNum {
 			if fdLimit <= reservedInternalFDNum {
 				plog.Fatalf("file descriptor limit[%d] of etcd process is too low, and should be set higher than %d to ensure internal usage", fdLimit, reservedInternalFDNum)
 				plog.Fatalf("file descriptor limit[%d] of etcd process is too low, and should be set higher than %d to ensure internal usage", fdLimit, reservedInternalFDNum)
@@ -247,6 +248,13 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 			l = netutil.LimitListener(l, int(fdLimit-reservedInternalFDNum))
 			l = netutil.LimitListener(l, int(fdLimit-reservedInternalFDNum))
 		}
 		}
 
 
+		// Do not wrap around this listener if TLS Info is set.
+		// HTTPS server expects TLS Conn created by TLSListener.
+		l, err = transport.NewKeepAliveListener(l, u.Scheme, cfg.clientTLSInfo)
+		if err != nil {
+			return nil, err
+		}
+
 		urlStr := u.String()
 		urlStr := u.String()
 		plog.Info("listening for client requests on ", urlStr)
 		plog.Info("listening for client requests on ", urlStr)
 		defer func() {
 		defer func() {

+ 3 - 0
etcdmain/http.go

@@ -26,6 +26,9 @@ import (
 // creating a new service goroutine for each. The service goroutines
 // creating a new service goroutine for each. The service goroutines
 // read requests and then call handler to reply to them.
 // read requests and then call handler to reply to them.
 func serveHTTP(l net.Listener, handler http.Handler, readTimeout time.Duration) error {
 func serveHTTP(l net.Listener, handler http.Handler, readTimeout time.Duration) error {
+	// TODO: assert net.Listener type? Arbitrary listener might break HTTPS server which
+	// expect a TLS Conn type.
+
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
 	// TODO: add debug flag; enable logging when debug flag is set
 	// TODO: add debug flag; enable logging when debug flag is set
 	srv := &http.Server{
 	srv := &http.Server{

+ 4 - 7
pkg/transport/keepalive_listener.go

@@ -22,16 +22,13 @@ import (
 )
 )
 
 
 // NewKeepAliveListener returns a listener that listens on the given address.
 // NewKeepAliveListener returns a listener that listens on the given address.
+// Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
+// Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
-func NewKeepAliveListener(addr string, scheme string, info TLSInfo) (net.Listener, error) {
-	l, err := net.Listen("tcp", addr)
-	if err != nil {
-		return nil, err
-	}
-
+func NewKeepAliveListener(l net.Listener, scheme string, info TLSInfo) (net.Listener, error) {
 	if scheme == "https" {
 	if scheme == "https" {
 		if info.Empty() {
 		if info.Empty() {
-			return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
+			return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
 		}
 		}
 		cfg, err := info.ServerConfig()
 		cfg, err := info.ServerConfig()
 		if err != nil {
 		if err != nil {

+ 15 - 3
pkg/transport/keepalive_listener_test.go

@@ -16,6 +16,7 @@ package transport
 
 
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
+	"net"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
 	"testing"
 	"testing"
@@ -25,7 +26,12 @@ import (
 // that accepts connections.
 // that accepts connections.
 // TODO: verify the keepalive option is set correctly
 // TODO: verify the keepalive option is set correctly
 func TestNewKeepAliveListener(t *testing.T) {
 func TestNewKeepAliveListener(t *testing.T) {
-	ln, err := NewKeepAliveListener("127.0.0.1:0", "http", TLSInfo{})
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+
+	ln, err = NewKeepAliveListener(ln, "http", TLSInfo{})
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
 	}
@@ -38,6 +44,7 @@ func TestNewKeepAliveListener(t *testing.T) {
 	conn.Close()
 	conn.Close()
 	ln.Close()
 	ln.Close()
 
 
+	ln, err = net.Listen("tcp", "127.0.0.1:0")
 	// tls
 	// tls
 	tmp, err := createTempFile([]byte("XXX"))
 	tmp, err := createTempFile([]byte("XXX"))
 	if err != nil {
 	if err != nil {
@@ -46,7 +53,7 @@ func TestNewKeepAliveListener(t *testing.T) {
 	defer os.Remove(tmp)
 	defer os.Remove(tmp)
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
-	tlsln, err := NewKeepAliveListener("127.0.0.1:0", "https", tlsInfo)
+	tlsln, err := NewKeepAliveListener(ln, "https", tlsInfo)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
 	}
@@ -64,7 +71,12 @@ func TestNewKeepAliveListener(t *testing.T) {
 }
 }
 
 
 func TestNewKeepAliveListenerTLSEmptyInfo(t *testing.T) {
 func TestNewKeepAliveListenerTLSEmptyInfo(t *testing.T) {
-	_, err := NewListener("127.0.0.1:0", "https", TLSInfo{})
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+
+	_, err = NewKeepAliveListener(ln, "https", TLSInfo{})
 	if err == nil {
 	if err == nil {
 		t.Errorf("err = nil, want not presented error")
 		t.Errorf("err = nil, want not presented error")
 	}
 	}