Просмотр исходного кода

transport: resolve DNSNames when SAN checking

The current transport client TLS checking will pass an IP address into
VerifyHostnames if there is DNSNames SAN. However, the go runtime will
not resolve the DNS names to match the client IP. Intead, resolve the
names when checking.
Anthony Romano 9 лет назад
Родитель
Сommit
05582ad5b2
1 измененных файлов с 46 добавлено и 13 удалено
  1. 46 13
      pkg/transport/listener_tls.go

+ 46 - 13
pkg/transport/listener_tls.go

@@ -15,7 +15,9 @@
 package transport
 
 import (
+	"context"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"net"
 	"sync"
@@ -40,11 +42,16 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
 	if err != nil {
 		return nil, err
 	}
+
+	hf := tlsinfo.HandshakeFailure
+	if hf == nil {
+		hf = func(*tls.Conn, error) {}
+	}
 	tlsl := &tlsListener{
 		Listener:         tls.NewListener(l, tlscfg),
 		connc:            make(chan net.Conn),
 		donec:            make(chan struct{}),
-		handshakeFailure: tlsinfo.HandshakeFailure,
+		handshakeFailure: hf,
 	}
 	go tlsl.acceptLoop()
 	return tlsl, nil
@@ -66,9 +73,9 @@ func (l *tlsListener) acceptLoop() {
 	var pendingMu sync.Mutex
 
 	pending := make(map[net.Conn]struct{})
-	stopc := make(chan struct{})
+	ctx, cancel := context.WithCancel(context.Background())
 	defer func() {
-		close(stopc)
+		cancel()
 		pendingMu.Lock()
 		for c := range pending {
 			c.Close()
@@ -104,32 +111,58 @@ func (l *tlsListener) acceptLoop() {
 			delete(pending, conn)
 			pendingMu.Unlock()
 			if herr != nil {
-				if l.handshakeFailure != nil {
-					l.handshakeFailure(tlsConn, herr)
-				}
+				l.handshakeFailure(tlsConn, herr)
 				return
 			}
 
 			st := tlsConn.ConnectionState()
 			if len(st.PeerCertificates) > 0 {
 				cert := st.PeerCertificates[0]
-				if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
-					addr := tlsConn.RemoteAddr().String()
-					h, _, herr := net.SplitHostPort(addr)
-					if herr != nil || cert.VerifyHostname(h) != nil {
-						return
-					}
+				addr := tlsConn.RemoteAddr().String()
+				if cerr := checkCert(ctx, cert, addr); cerr != nil {
+					l.handshakeFailure(tlsConn, cerr)
+					return
 				}
 			}
 			select {
 			case l.connc <- tlsConn:
 				conn = nil
-			case <-stopc:
+			case <-ctx.Done():
 			}
 		}()
 	}
 }
 
+func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
+	h, _, herr := net.SplitHostPort(remoteAddr)
+	if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
+		return nil
+	}
+	if herr != nil {
+		return herr
+	}
+	if len(cert.IPAddresses) > 0 {
+		if cerr := cert.VerifyHostname(h); cerr != nil && len(cert.DNSNames) == 0 {
+			return cerr
+		}
+	}
+	if len(cert.DNSNames) > 0 {
+		for _, dns := range cert.DNSNames {
+			addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
+			if lerr != nil {
+				continue
+			}
+			for _, addr := range addrs {
+				if addr == h {
+					return nil
+				}
+			}
+		}
+		return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames)
+	}
+	return nil
+}
+
 func (l *tlsListener) Close() error {
 	err := l.Listener.Close()
 	<-l.donec