Browse Source

transport: resolve DNSNames when SAN checking

The current transport client TLS checking will pass an IP address into
VerifyHostnames if there is DNSNames SAN. However, the go runtime will
not resolve the DNS names to match the client IP. Intead, resolve the
names when checking.
Anthony Romano 8 years ago
parent
commit
05582ad5b2
1 changed files with 46 additions and 13 deletions
  1. 46 13
      pkg/transport/listener_tls.go

+ 46 - 13
pkg/transport/listener_tls.go

@@ -15,7 +15,9 @@
 package transport
 
 import (
+	"context"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"net"
 	"sync"
@@ -40,11 +42,16 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
 	if err != nil {
 		return nil, err
 	}
+
+	hf := tlsinfo.HandshakeFailure
+	if hf == nil {
+		hf = func(*tls.Conn, error) {}
+	}
 	tlsl := &tlsListener{
 		Listener:         tls.NewListener(l, tlscfg),
 		connc:            make(chan net.Conn),
 		donec:            make(chan struct{}),
-		handshakeFailure: tlsinfo.HandshakeFailure,
+		handshakeFailure: hf,
 	}
 	go tlsl.acceptLoop()
 	return tlsl, nil
@@ -66,9 +73,9 @@ func (l *tlsListener) acceptLoop() {
 	var pendingMu sync.Mutex
 
 	pending := make(map[net.Conn]struct{})
-	stopc := make(chan struct{})
+	ctx, cancel := context.WithCancel(context.Background())
 	defer func() {
-		close(stopc)
+		cancel()
 		pendingMu.Lock()
 		for c := range pending {
 			c.Close()
@@ -104,32 +111,58 @@ func (l *tlsListener) acceptLoop() {
 			delete(pending, conn)
 			pendingMu.Unlock()
 			if herr != nil {
-				if l.handshakeFailure != nil {
-					l.handshakeFailure(tlsConn, herr)
-				}
+				l.handshakeFailure(tlsConn, herr)
 				return
 			}
 
 			st := tlsConn.ConnectionState()
 			if len(st.PeerCertificates) > 0 {
 				cert := st.PeerCertificates[0]
-				if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
-					addr := tlsConn.RemoteAddr().String()
-					h, _, herr := net.SplitHostPort(addr)
-					if herr != nil || cert.VerifyHostname(h) != nil {
-						return
-					}
+				addr := tlsConn.RemoteAddr().String()
+				if cerr := checkCert(ctx, cert, addr); cerr != nil {
+					l.handshakeFailure(tlsConn, cerr)
+					return
 				}
 			}
 			select {
 			case l.connc <- tlsConn:
 				conn = nil
-			case <-stopc:
+			case <-ctx.Done():
 			}
 		}()
 	}
 }
 
+func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
+	h, _, herr := net.SplitHostPort(remoteAddr)
+	if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
+		return nil
+	}
+	if herr != nil {
+		return herr
+	}
+	if len(cert.IPAddresses) > 0 {
+		if cerr := cert.VerifyHostname(h); cerr != nil && len(cert.DNSNames) == 0 {
+			return cerr
+		}
+	}
+	if len(cert.DNSNames) > 0 {
+		for _, dns := range cert.DNSNames {
+			addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
+			if lerr != nil {
+				continue
+			}
+			for _, addr := range addrs {
+				if addr == h {
+					return nil
+				}
+			}
+		}
+		return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames)
+	}
+	return nil
+}
+
 func (l *tlsListener) Close() error {
 	err := l.Listener.Close()
 	<-l.donec