Browse Source

transport: CRL checking

Anthony Romano 8 years ago
parent
commit
322976bedc
2 changed files with 71 additions and 15 deletions
  1. 3 2
      pkg/transport/listener.go
  2. 68 13
      pkg/transport/listener_tls.go

+ 3 - 2
pkg/transport/listener.go

@@ -52,7 +52,7 @@ func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listene
 	if scheme != "https" && scheme != "unixs" {
 	if scheme != "https" && scheme != "unixs" {
 		return l, nil
 		return l, nil
 	}
 	}
-	return newTLSListener(l, tlsinfo)
+	return newTLSListener(l, tlsinfo, checkSAN)
 }
 }
 
 
 type TLSInfo struct {
 type TLSInfo struct {
@@ -61,6 +61,7 @@ type TLSInfo struct {
 	CAFile         string
 	CAFile         string
 	TrustedCAFile  string
 	TrustedCAFile  string
 	ClientCertAuth bool
 	ClientCertAuth bool
+	CRLFile        string
 
 
 	// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
 	// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
 	ServerName string
 	ServerName string
@@ -77,7 +78,7 @@ type TLSInfo struct {
 }
 }
 
 
 func (info TLSInfo) String() string {
 func (info TLSInfo) String() string {
-	return fmt.Sprintf("cert = %s, key = %s, ca = %s, trusted-ca = %s, client-cert-auth = %v", info.CertFile, info.KeyFile, info.CAFile, info.TrustedCAFile, info.ClientCertAuth)
+	return fmt.Sprintf("cert = %s, key = %s, ca = %s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.CAFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
 }
 }
 
 
 func (info TLSInfo) Empty() bool {
 func (info TLSInfo) Empty() bool {

+ 68 - 13
pkg/transport/listener_tls.go

@@ -19,21 +19,32 @@ import (
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
 	"fmt"
 	"fmt"
+	"io/ioutil"
 	"net"
 	"net"
 	"sync"
 	"sync"
 )
 )
 
 
 // tlsListener overrides a TLS listener so it will reject client
 // tlsListener overrides a TLS listener so it will reject client
-// certificates with insufficient SAN credentials.
+// certificates with insufficient SAN credentials or CRL revoked
+// certificates.
 type tlsListener struct {
 type tlsListener struct {
 	net.Listener
 	net.Listener
 	connc            chan net.Conn
 	connc            chan net.Conn
 	donec            chan struct{}
 	donec            chan struct{}
 	err              error
 	err              error
 	handshakeFailure func(*tls.Conn, error)
 	handshakeFailure func(*tls.Conn, error)
+	check            tlsCheckFunc
 }
 }
 
 
-func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
+type tlsCheckFunc func(context.Context, *tls.Conn) error
+
+// NewTLSListener handshakes TLS connections and performs optional CRL checking.
+func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
+	check := func(context.Context, *tls.Conn) error { return nil }
+	return newTLSListener(l, tlsinfo, check)
+}
+
+func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) {
 	if tlsinfo == nil || tlsinfo.Empty() {
 	if tlsinfo == nil || tlsinfo.Empty() {
 		l.Close()
 		l.Close()
 		return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
 		return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
@@ -47,11 +58,27 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
 	if hf == nil {
 	if hf == nil {
 		hf = func(*tls.Conn, error) {}
 		hf = func(*tls.Conn, error) {}
 	}
 	}
+
+	if len(tlsinfo.CRLFile) > 0 {
+		prevCheck := check
+		check = func(ctx context.Context, tlsConn *tls.Conn) error {
+			if err := prevCheck(ctx, tlsConn); err != nil {
+				return err
+			}
+			st := tlsConn.ConnectionState()
+			if certs := st.PeerCertificates; len(certs) > 0 {
+				return checkCRL(tlsinfo.CRLFile, certs)
+			}
+			return nil
+		}
+	}
+
 	tlsl := &tlsListener{
 	tlsl := &tlsListener{
 		Listener:         tls.NewListener(l, tlscfg),
 		Listener:         tls.NewListener(l, tlscfg),
 		connc:            make(chan net.Conn),
 		connc:            make(chan net.Conn),
 		donec:            make(chan struct{}),
 		donec:            make(chan struct{}),
 		handshakeFailure: hf,
 		handshakeFailure: hf,
+		check:            check,
 	}
 	}
 	go tlsl.acceptLoop()
 	go tlsl.acceptLoop()
 	return tlsl, nil
 	return tlsl, nil
@@ -66,6 +93,15 @@ func (l *tlsListener) Accept() (net.Conn, error) {
 	}
 	}
 }
 }
 
 
+func checkSAN(ctx context.Context, tlsConn *tls.Conn) error {
+	st := tlsConn.ConnectionState()
+	if certs := st.PeerCertificates; len(certs) > 0 {
+		addr := tlsConn.RemoteAddr().String()
+		return checkCertSAN(ctx, certs[0], addr)
+	}
+	return nil
+}
+
 // acceptLoop launches each TLS handshake in a separate goroutine
 // acceptLoop launches each TLS handshake in a separate goroutine
 // to prevent a hanging TLS connection from blocking other connections.
 // to prevent a hanging TLS connection from blocking other connections.
 func (l *tlsListener) acceptLoop() {
 func (l *tlsListener) acceptLoop() {
@@ -110,20 +146,16 @@ func (l *tlsListener) acceptLoop() {
 			pendingMu.Lock()
 			pendingMu.Lock()
 			delete(pending, conn)
 			delete(pending, conn)
 			pendingMu.Unlock()
 			pendingMu.Unlock()
+
 			if herr != nil {
 			if herr != nil {
 				l.handshakeFailure(tlsConn, herr)
 				l.handshakeFailure(tlsConn, herr)
 				return
 				return
 			}
 			}
-
-			st := tlsConn.ConnectionState()
-			if len(st.PeerCertificates) > 0 {
-				cert := st.PeerCertificates[0]
-				addr := tlsConn.RemoteAddr().String()
-				if cerr := checkCert(ctx, cert, addr); cerr != nil {
-					l.handshakeFailure(tlsConn, cerr)
-					return
-				}
+			if err := l.check(ctx, tlsConn); err != nil {
+				l.handshakeFailure(tlsConn, err)
+				return
 			}
 			}
+
 			select {
 			select {
 			case l.connc <- tlsConn:
 			case l.connc <- tlsConn:
 				conn = nil
 				conn = nil
@@ -133,11 +165,34 @@ func (l *tlsListener) acceptLoop() {
 	}
 	}
 }
 }
 
 
-func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
-	h, _, herr := net.SplitHostPort(remoteAddr)
+func checkCRL(crlPath string, cert []*x509.Certificate) error {
+	// TODO: cache
+	crlBytes, err := ioutil.ReadFile(crlPath)
+	if err != nil {
+		return err
+	}
+	certList, err := x509.ParseCRL(crlBytes)
+	if err != nil {
+		return err
+	}
+	revokedSerials := make(map[string]struct{})
+	for _, rc := range certList.TBSCertList.RevokedCertificates {
+		revokedSerials[string(rc.SerialNumber.Bytes())] = struct{}{}
+	}
+	for _, c := range cert {
+		serial := string(c.SerialNumber.Bytes())
+		if _, ok := revokedSerials[serial]; ok {
+			return fmt.Errorf("transport: certificate serial %x revoked", serial)
+		}
+	}
+	return nil
+}
+
+func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
 	if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
 	if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
 		return nil
 		return nil
 	}
 	}
+	h, _, herr := net.SplitHostPort(remoteAddr)
 	if herr != nil {
 	if herr != nil {
 		return herr
 		return herr
 	}
 	}