Browse Source

transport: wrap timeout listener with tls listener

Otherwise the listener will return timeoutConn's, causing a type
assertion to tls.Conn in net.http to fail so http.Request.TLS is never set.
Anthony Romano 9 years ago
parent
commit
99e0655c2f
2 changed files with 25 additions and 18 deletions
  1. 18 15
      pkg/transport/listener.go
  2. 7 3
      pkg/transport/timeout_listener.go

+ 18 - 15
pkg/transport/listener.go

@@ -34,27 +34,30 @@ import (
 	"github.com/coreos/etcd/pkg/tlsutil"
 )
 
-func NewListener(addr string, scheme string, tlscfg *tls.Config) (l net.Listener, err error) {
+func NewListener(addr, scheme string, tlscfg *tls.Config) (l net.Listener, err error) {
+	if l, err = newListener(addr, scheme); err != nil {
+		return nil, err
+	}
+	return wrapTLS(addr, scheme, tlscfg, l)
+}
+
+func newListener(addr string, scheme string) (net.Listener, error) {
 	if scheme == "unix" || scheme == "unixs" {
 		// unix sockets via unix://laddr
-		l, err = NewUnixListener(addr)
-	} else {
-		l, err = net.Listen("tcp", addr)
+		return NewUnixListener(addr)
 	}
+	return net.Listen("tcp", addr)
+}
 
-	if err != nil {
-		return nil, err
+func wrapTLS(addr, scheme string, tlscfg *tls.Config, l net.Listener) (net.Listener, error) {
+	if scheme != "https" && scheme != "unixs" {
+		return l, nil
 	}
-
-	if scheme == "https" || scheme == "unixs" {
-		if tlscfg == nil {
-			return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
-		}
-
-		l = tls.NewListener(l, tlscfg)
+	if tlscfg == nil {
+		l.Close()
+		return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
 	}
-
-	return l, nil
+	return tls.NewListener(l, tlscfg), nil
 }
 
 type TLSInfo struct {

+ 7 - 3
pkg/transport/timeout_listener.go

@@ -24,15 +24,19 @@ import (
 // If read/write on the accepted connection blocks longer than its time limit,
 // it will return timeout error.
 func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
-	ln, err := NewListener(addr, scheme, tlscfg)
+	ln, err := newListener(addr, scheme)
 	if err != nil {
 		return nil, err
 	}
-	return &rwTimeoutListener{
+	ln = &rwTimeoutListener{
 		Listener:   ln,
 		rdtimeoutd: rdtimeoutd,
 		wtimeoutd:  wtimeoutd,
-	}, nil
+	}
+	if ln, err = wrapTLS(addr, scheme, tlscfg, ln); err != nil {
+		return nil, err
+	}
+	return ln, nil
 }
 
 type rwTimeoutListener struct {