Browse Source

Merge pull request #7767 from heyitsanthony/transport-resolve-dnsnames

transport: resolve DNSNames when SAN checking
Anthony Romano 8 years ago
parent
commit
8fa4b8da6e
1 changed files with 46 additions and 13 deletions
  1. 46 13
      pkg/transport/listener_tls.go

+ 46 - 13
pkg/transport/listener_tls.go

@@ -15,7 +15,9 @@
 package transport
 package transport
 
 
 import (
 import (
+	"context"
 	"crypto/tls"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"sync"
 	"sync"
@@ -40,11 +42,16 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+
+	hf := tlsinfo.HandshakeFailure
+	if hf == nil {
+		hf = func(*tls.Conn, error) {}
+	}
 	tlsl := &tlsListener{
 	tlsl := &tlsListener{
 		Listener:         tls.NewListener(l, tlscfg),
 		Listener:         tls.NewListener(l, tlscfg),
 		connc:            make(chan net.Conn),
 		connc:            make(chan net.Conn),
 		donec:            make(chan struct{}),
 		donec:            make(chan struct{}),
-		handshakeFailure: tlsinfo.HandshakeFailure,
+		handshakeFailure: hf,
 	}
 	}
 	go tlsl.acceptLoop()
 	go tlsl.acceptLoop()
 	return tlsl, nil
 	return tlsl, nil
@@ -66,9 +73,9 @@ func (l *tlsListener) acceptLoop() {
 	var pendingMu sync.Mutex
 	var pendingMu sync.Mutex
 
 
 	pending := make(map[net.Conn]struct{})
 	pending := make(map[net.Conn]struct{})
-	stopc := make(chan struct{})
+	ctx, cancel := context.WithCancel(context.Background())
 	defer func() {
 	defer func() {
-		close(stopc)
+		cancel()
 		pendingMu.Lock()
 		pendingMu.Lock()
 		for c := range pending {
 		for c := range pending {
 			c.Close()
 			c.Close()
@@ -104,32 +111,58 @@ func (l *tlsListener) acceptLoop() {
 			delete(pending, conn)
 			delete(pending, conn)
 			pendingMu.Unlock()
 			pendingMu.Unlock()
 			if herr != nil {
 			if herr != nil {
-				if l.handshakeFailure != nil {
-					l.handshakeFailure(tlsConn, herr)
-				}
+				l.handshakeFailure(tlsConn, herr)
 				return
 				return
 			}
 			}
 
 
 			st := tlsConn.ConnectionState()
 			st := tlsConn.ConnectionState()
 			if len(st.PeerCertificates) > 0 {
 			if len(st.PeerCertificates) > 0 {
 				cert := st.PeerCertificates[0]
 				cert := st.PeerCertificates[0]
-				if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
-					addr := tlsConn.RemoteAddr().String()
-					h, _, herr := net.SplitHostPort(addr)
-					if herr != nil || cert.VerifyHostname(h) != nil {
-						return
-					}
+				addr := tlsConn.RemoteAddr().String()
+				if cerr := checkCert(ctx, cert, addr); cerr != nil {
+					l.handshakeFailure(tlsConn, cerr)
+					return
 				}
 				}
 			}
 			}
 			select {
 			select {
 			case l.connc <- tlsConn:
 			case l.connc <- tlsConn:
 				conn = nil
 				conn = nil
-			case <-stopc:
+			case <-ctx.Done():
 			}
 			}
 		}()
 		}()
 	}
 	}
 }
 }
 
 
+func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
+	h, _, herr := net.SplitHostPort(remoteAddr)
+	if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
+		return nil
+	}
+	if herr != nil {
+		return herr
+	}
+	if len(cert.IPAddresses) > 0 {
+		if cerr := cert.VerifyHostname(h); cerr != nil && len(cert.DNSNames) == 0 {
+			return cerr
+		}
+	}
+	if len(cert.DNSNames) > 0 {
+		for _, dns := range cert.DNSNames {
+			addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
+			if lerr != nil {
+				continue
+			}
+			for _, addr := range addrs {
+				if addr == h {
+					return nil
+				}
+			}
+		}
+		return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames)
+	}
+	return nil
+}
+
 func (l *tlsListener) Close() error {
 func (l *tlsListener) Close() error {
 	err := l.Listener.Close()
 	err := l.Listener.Close()
 	<-l.donec
 	<-l.donec