|
@@ -19,21 +19,32 @@ import (
|
|
|
"crypto/tls"
|
|
"crypto/tls"
|
|
|
"crypto/x509"
|
|
"crypto/x509"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
|
|
+ "io/ioutil"
|
|
|
"net"
|
|
"net"
|
|
|
"sync"
|
|
"sync"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
// tlsListener overrides a TLS listener so it will reject client
|
|
// tlsListener overrides a TLS listener so it will reject client
|
|
|
-// certificates with insufficient SAN credentials.
|
|
|
|
|
|
|
+// certificates with insufficient SAN credentials or CRL revoked
|
|
|
|
|
+// certificates.
|
|
|
type tlsListener struct {
|
|
type tlsListener struct {
|
|
|
net.Listener
|
|
net.Listener
|
|
|
connc chan net.Conn
|
|
connc chan net.Conn
|
|
|
donec chan struct{}
|
|
donec chan struct{}
|
|
|
err error
|
|
err error
|
|
|
handshakeFailure func(*tls.Conn, error)
|
|
handshakeFailure func(*tls.Conn, error)
|
|
|
|
|
+ check tlsCheckFunc
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
|
|
|
|
|
|
+type tlsCheckFunc func(context.Context, *tls.Conn) error
|
|
|
|
|
+
|
|
|
|
|
+// NewTLSListener handshakes TLS connections and performs optional CRL checking.
|
|
|
|
|
+func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
|
|
|
|
+ check := func(context.Context, *tls.Conn) error { return nil }
|
|
|
|
|
+ return newTLSListener(l, tlsinfo, check)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) {
|
|
|
if tlsinfo == nil || tlsinfo.Empty() {
|
|
if tlsinfo == nil || tlsinfo.Empty() {
|
|
|
l.Close()
|
|
l.Close()
|
|
|
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
|
|
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
|
|
@@ -47,11 +58,27 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
|
|
if hf == nil {
|
|
if hf == nil {
|
|
|
hf = func(*tls.Conn, error) {}
|
|
hf = func(*tls.Conn, error) {}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ if len(tlsinfo.CRLFile) > 0 {
|
|
|
|
|
+ prevCheck := check
|
|
|
|
|
+ check = func(ctx context.Context, tlsConn *tls.Conn) error {
|
|
|
|
|
+ if err := prevCheck(ctx, tlsConn); err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ st := tlsConn.ConnectionState()
|
|
|
|
|
+ if certs := st.PeerCertificates; len(certs) > 0 {
|
|
|
|
|
+ return checkCRL(tlsinfo.CRLFile, certs)
|
|
|
|
|
+ }
|
|
|
|
|
+ return nil
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
tlsl := &tlsListener{
|
|
tlsl := &tlsListener{
|
|
|
Listener: tls.NewListener(l, tlscfg),
|
|
Listener: tls.NewListener(l, tlscfg),
|
|
|
connc: make(chan net.Conn),
|
|
connc: make(chan net.Conn),
|
|
|
donec: make(chan struct{}),
|
|
donec: make(chan struct{}),
|
|
|
handshakeFailure: hf,
|
|
handshakeFailure: hf,
|
|
|
|
|
+ check: check,
|
|
|
}
|
|
}
|
|
|
go tlsl.acceptLoop()
|
|
go tlsl.acceptLoop()
|
|
|
return tlsl, nil
|
|
return tlsl, nil
|
|
@@ -66,6 +93,15 @@ func (l *tlsListener) Accept() (net.Conn, error) {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func checkSAN(ctx context.Context, tlsConn *tls.Conn) error {
|
|
|
|
|
+ st := tlsConn.ConnectionState()
|
|
|
|
|
+ if certs := st.PeerCertificates; len(certs) > 0 {
|
|
|
|
|
+ addr := tlsConn.RemoteAddr().String()
|
|
|
|
|
+ return checkCertSAN(ctx, certs[0], addr)
|
|
|
|
|
+ }
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
// acceptLoop launches each TLS handshake in a separate goroutine
|
|
// acceptLoop launches each TLS handshake in a separate goroutine
|
|
|
// to prevent a hanging TLS connection from blocking other connections.
|
|
// to prevent a hanging TLS connection from blocking other connections.
|
|
|
func (l *tlsListener) acceptLoop() {
|
|
func (l *tlsListener) acceptLoop() {
|
|
@@ -110,20 +146,16 @@ func (l *tlsListener) acceptLoop() {
|
|
|
pendingMu.Lock()
|
|
pendingMu.Lock()
|
|
|
delete(pending, conn)
|
|
delete(pending, conn)
|
|
|
pendingMu.Unlock()
|
|
pendingMu.Unlock()
|
|
|
|
|
+
|
|
|
if herr != nil {
|
|
if herr != nil {
|
|
|
l.handshakeFailure(tlsConn, herr)
|
|
l.handshakeFailure(tlsConn, herr)
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- st := tlsConn.ConnectionState()
|
|
|
|
|
- if len(st.PeerCertificates) > 0 {
|
|
|
|
|
- cert := st.PeerCertificates[0]
|
|
|
|
|
- addr := tlsConn.RemoteAddr().String()
|
|
|
|
|
- if cerr := checkCert(ctx, cert, addr); cerr != nil {
|
|
|
|
|
- l.handshakeFailure(tlsConn, cerr)
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if err := l.check(ctx, tlsConn); err != nil {
|
|
|
|
|
+ l.handshakeFailure(tlsConn, err)
|
|
|
|
|
+ return
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
select {
|
|
select {
|
|
|
case l.connc <- tlsConn:
|
|
case l.connc <- tlsConn:
|
|
|
conn = nil
|
|
conn = nil
|
|
@@ -133,11 +165,34 @@ func (l *tlsListener) acceptLoop() {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
|
|
|
|
|
- h, _, herr := net.SplitHostPort(remoteAddr)
|
|
|
|
|
|
|
+func checkCRL(crlPath string, cert []*x509.Certificate) error {
|
|
|
|
|
+ // TODO: cache
|
|
|
|
|
+ crlBytes, err := ioutil.ReadFile(crlPath)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ certList, err := x509.ParseCRL(crlBytes)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ revokedSerials := make(map[string]struct{})
|
|
|
|
|
+ for _, rc := range certList.TBSCertList.RevokedCertificates {
|
|
|
|
|
+ revokedSerials[string(rc.SerialNumber.Bytes())] = struct{}{}
|
|
|
|
|
+ }
|
|
|
|
|
+ for _, c := range cert {
|
|
|
|
|
+ serial := string(c.SerialNumber.Bytes())
|
|
|
|
|
+ if _, ok := revokedSerials[serial]; ok {
|
|
|
|
|
+ return fmt.Errorf("transport: certificate serial %x revoked", serial)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
|
|
|
if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
|
|
if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
+ h, _, herr := net.SplitHostPort(remoteAddr)
|
|
|
if herr != nil {
|
|
if herr != nil {
|
|
|
return herr
|
|
return herr
|
|
|
}
|
|
}
|