Browse Source

pkg/transport: fix tlskeepalive

Xiang Li 11 years ago
parent
commit
4960324876
2 changed files with 75 additions and 5 deletions
  1. 48 5
      pkg/transport/keepalive_listener.go
  2. 27 0
      pkg/transport/keepalive_listener_test.go

+ 48 - 5
pkg/transport/keepalive_listener.go

@@ -15,6 +15,7 @@
 package transport
 
 import (
+	"crypto/tls"
 	"net"
 	"time"
 )
@@ -22,18 +23,26 @@ import (
 // NewKeepAliveListener returns a listener that listens on the given address.
 // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
 func NewKeepAliveListener(addr string, scheme string, info TLSInfo) (net.Listener, error) {
-	ln, err := NewListener(addr, scheme, info)
+	l, err := net.Listen("tcp", addr)
 	if err != nil {
 		return nil, err
 	}
+
+	if !info.Empty() && scheme == "https" {
+		cfg, err := info.ServerConfig()
+		if err != nil {
+			return nil, err
+		}
+
+		return newTLSKeepaliveListener(l, cfg), nil
+	}
+
 	return &keepaliveListener{
-		Listener: ln,
+		Listener: l,
 	}, nil
 }
 
-type keepaliveListener struct {
-	net.Listener
-}
+type keepaliveListener struct{ net.Listener }
 
 func (kln *keepaliveListener) Accept() (net.Conn, error) {
 	c, err := kln.Listener.Accept()
@@ -48,3 +57,37 @@ func (kln *keepaliveListener) Accept() (net.Conn, error) {
 	tcpc.SetKeepAlivePeriod(30 * time.Second)
 	return tcpc, nil
 }
+
+// A tlsKeepaliveListener implements a network listener (net.Listener) for TLS connections.
+type tlsKeepaliveListener struct {
+	net.Listener
+	config *tls.Config
+}
+
+// Accept waits for and returns the next incoming TLS connection.
+// The returned connection c is a *tls.Conn.
+func (l *tlsKeepaliveListener) Accept() (c net.Conn, err error) {
+	c, err = l.Listener.Accept()
+	if err != nil {
+		return
+	}
+	tcpc := c.(*net.TCPConn)
+	// detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl
+	// default on linux:  30 + 8 * 30
+	// default on osx:    30 + 8 * 75
+	tcpc.SetKeepAlive(true)
+	tcpc.SetKeepAlivePeriod(30 * time.Second)
+	c = tls.Server(c, l.config)
+	return
+}
+
+// NewListener creates a Listener which accepts connections from an inner
+// Listener and wraps each connection with Server.
+// The configuration config must be non-nil and must have
+// at least one certificate.
+func newTLSKeepaliveListener(inner net.Listener, config *tls.Config) net.Listener {
+	l := &tlsKeepaliveListener{}
+	l.Listener = inner
+	l.config = config
+	return l
+}

+ 27 - 0
pkg/transport/keepalive_listener_test.go

@@ -15,7 +15,9 @@
 package transport
 
 import (
+	"crypto/tls"
 	"net/http"
+	"os"
 	"testing"
 )
 
@@ -34,4 +36,29 @@ func TestNewKeepAliveListener(t *testing.T) {
 		t.Fatalf("unexpected Accept error: %v", err)
 	}
 	conn.Close()
+	ln.Close()
+
+	// tls
+	tmp, err := createTempFile([]byte("XXX"))
+	if err != nil {
+		t.Fatalf("unable to create tmpfile: %v", err)
+	}
+	defer os.Remove(tmp)
+	tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
+	tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
+	tlsln, err := NewKeepAliveListener("127.0.0.1:0", "https", tlsInfo)
+	if err != nil {
+		t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
+	}
+
+	go http.Get("https://" + tlsln.Addr().String())
+	conn, err = tlsln.Accept()
+	if err != nil {
+		t.Fatalf("unexpected Accept error: %v", err)
+	}
+	if _, ok := conn.(*tls.Conn); !ok {
+		t.Errorf("failed to accept *tls.Conn")
+	}
+	conn.Close()
+	tlsln.Close()
 }