Browse Source

etcdmain: tls listener MUST be at the outer layer of all listeners

go HTTP library uses type assertion to determine if a connection
is a TLS connection. If we wrapper TLS Listener with any customized
Listener that can create customized Conn, HTTPs will be broken.

This commit fixes the issue.
Xiang Li 10 years ago
parent
commit
1f97f2dc36

+ 9 - 1
etcdmain/etcd.go

@@ -236,10 +236,11 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 			plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
 		}
 		var l net.Listener
-		l, err = transport.NewKeepAliveListener(u.Host, u.Scheme, cfg.clientTLSInfo)
+		l, err = net.Listen("tcp", u.Host)
 		if err != nil {
 			return nil, err
 		}
+
 		if fdLimit, err := runtimeutil.FDLimit(); err == nil {
 			if fdLimit <= reservedInternalFDNum {
 				plog.Fatalf("file descriptor limit[%d] of etcd process is too low, and should be set higher than %d to ensure internal usage", fdLimit, reservedInternalFDNum)
@@ -247,6 +248,13 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
 			l = netutil.LimitListener(l, int(fdLimit-reservedInternalFDNum))
 		}
 
+		// Do not wrap around this listener if TLS Info is set.
+		// HTTPS server expects TLS Conn created by TLSListener.
+		l, err = transport.NewKeepAliveListener(l, u.Scheme, cfg.clientTLSInfo)
+		if err != nil {
+			return nil, err
+		}
+
 		urlStr := u.String()
 		plog.Info("listening for client requests on ", urlStr)
 		defer func() {

+ 3 - 0
etcdmain/http.go

@@ -26,6 +26,9 @@ import (
 // creating a new service goroutine for each. The service goroutines
 // read requests and then call handler to reply to them.
 func serveHTTP(l net.Listener, handler http.Handler, readTimeout time.Duration) error {
+	// TODO: assert net.Listener type? Arbitrary listener might break HTTPS server which
+	// expect a TLS Conn type.
+
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
 	// TODO: add debug flag; enable logging when debug flag is set
 	srv := &http.Server{

+ 4 - 7
pkg/transport/keepalive_listener.go

@@ -22,16 +22,13 @@ import (
 )
 
 // NewKeepAliveListener returns a listener that listens on the given address.
+// Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
+// Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
-func NewKeepAliveListener(addr string, scheme string, info TLSInfo) (net.Listener, error) {
-	l, err := net.Listen("tcp", addr)
-	if err != nil {
-		return nil, err
-	}
-
+func NewKeepAliveListener(l net.Listener, scheme string, info TLSInfo) (net.Listener, error) {
 	if scheme == "https" {
 		if info.Empty() {
-			return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
+			return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
 		}
 		cfg, err := info.ServerConfig()
 		if err != nil {

+ 15 - 3
pkg/transport/keepalive_listener_test.go

@@ -16,6 +16,7 @@ package transport
 
 import (
 	"crypto/tls"
+	"net"
 	"net/http"
 	"os"
 	"testing"
@@ -25,7 +26,12 @@ import (
 // that accepts connections.
 // TODO: verify the keepalive option is set correctly
 func TestNewKeepAliveListener(t *testing.T) {
-	ln, err := NewKeepAliveListener("127.0.0.1:0", "http", TLSInfo{})
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+
+	ln, err = NewKeepAliveListener(ln, "http", TLSInfo{})
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
@@ -38,6 +44,7 @@ func TestNewKeepAliveListener(t *testing.T) {
 	conn.Close()
 	ln.Close()
 
+	ln, err = net.Listen("tcp", "127.0.0.1:0")
 	// tls
 	tmp, err := createTempFile([]byte("XXX"))
 	if err != nil {
@@ -46,7 +53,7 @@ func TestNewKeepAliveListener(t *testing.T) {
 	defer os.Remove(tmp)
 	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
 	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
-	tlsln, err := NewKeepAliveListener("127.0.0.1:0", "https", tlsInfo)
+	tlsln, err := NewKeepAliveListener(ln, "https", tlsInfo)
 	if err != nil {
 		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
 	}
@@ -64,7 +71,12 @@ func TestNewKeepAliveListener(t *testing.T) {
 }
 
 func TestNewKeepAliveListenerTLSEmptyInfo(t *testing.T) {
-	_, err := NewListener("127.0.0.1:0", "https", TLSInfo{})
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected listen error: %v", err)
+	}
+
+	_, err = NewKeepAliveListener(ln, "https", TLSInfo{})
 	if err == nil {
 		t.Errorf("err = nil, want not presented error")
 	}