Browse Source

transport: use reverse lookup to match wildcard DNS SAN

Fixes #8268
Anthony Romano 8 years ago
parent
commit
b1aa962233
1 changed files with 53 additions and 10 deletions
  1. 53 10
      pkg/transport/listener_tls.go

+ 53 - 10
pkg/transport/listener_tls.go

@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"io/ioutil"
 	"net"
+	"strings"
 	"sync"
 )
 
@@ -206,20 +207,62 @@ func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string
 		}
 	}
 	if len(cert.DNSNames) > 0 {
-		for _, dns := range cert.DNSNames {
-			addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
-			if lerr != nil {
-				continue
+		ok, err := isHostInDNS(ctx, h, cert.DNSNames)
+		if ok {
+			return nil
+		}
+		errStr := ""
+		if err != nil {
+			errStr = " (" + err.Error() + ")"
+		}
+		return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames)
+	}
+	return nil
+}
+
+func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) {
+	// reverse lookup
+	wildcards, names := []string{}, []string{}
+	for _, dns := range dnsNames {
+		if strings.HasPrefix(dns, "*.") {
+			wildcards = append(wildcards, dns[1:])
+		} else {
+			names = append(names, dns)
+		}
+	}
+	lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host)
+	for _, name := range lnames {
+		// strip trailing '.' from PTR record
+		if name[len(name)-1] == '.' {
+			name = name[:len(name)-1]
+		}
+		for _, wc := range wildcards {
+			if strings.HasSuffix(name, wc) {
+				return true, nil
 			}
-			for _, addr := range addrs {
-				if addr == h {
-					return nil
-				}
+		}
+		for _, n := range names {
+			if n == name {
+				return true, nil
 			}
 		}
-		return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames)
 	}
-	return nil
+	err = lerr
+
+	// forward lookup
+	for _, dns := range names {
+		addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
+		if lerr != nil {
+			err = lerr
+			continue
+		}
+		for _, addr := range addrs {
+			if addr == host {
+				return true, nil
+			}
+		}
+	}
+	return false, err
 }
 
 func (l *tlsListener) Close() error {