|
@@ -15,7 +15,9 @@
|
|
|
package transport
|
|
package transport
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
|
+ "context"
|
|
|
"crypto/tls"
|
|
"crypto/tls"
|
|
|
|
|
+ "crypto/x509"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"net"
|
|
"net"
|
|
|
"sync"
|
|
"sync"
|
|
@@ -40,11 +42,16 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ hf := tlsinfo.HandshakeFailure
|
|
|
|
|
+ if hf == nil {
|
|
|
|
|
+ hf = func(*tls.Conn, error) {}
|
|
|
|
|
+ }
|
|
|
tlsl := &tlsListener{
|
|
tlsl := &tlsListener{
|
|
|
Listener: tls.NewListener(l, tlscfg),
|
|
Listener: tls.NewListener(l, tlscfg),
|
|
|
connc: make(chan net.Conn),
|
|
connc: make(chan net.Conn),
|
|
|
donec: make(chan struct{}),
|
|
donec: make(chan struct{}),
|
|
|
- handshakeFailure: tlsinfo.HandshakeFailure,
|
|
|
|
|
|
|
+ handshakeFailure: hf,
|
|
|
}
|
|
}
|
|
|
go tlsl.acceptLoop()
|
|
go tlsl.acceptLoop()
|
|
|
return tlsl, nil
|
|
return tlsl, nil
|
|
@@ -66,9 +73,9 @@ func (l *tlsListener) acceptLoop() {
|
|
|
var pendingMu sync.Mutex
|
|
var pendingMu sync.Mutex
|
|
|
|
|
|
|
|
pending := make(map[net.Conn]struct{})
|
|
pending := make(map[net.Conn]struct{})
|
|
|
- stopc := make(chan struct{})
|
|
|
|
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
defer func() {
|
|
defer func() {
|
|
|
- close(stopc)
|
|
|
|
|
|
|
+ cancel()
|
|
|
pendingMu.Lock()
|
|
pendingMu.Lock()
|
|
|
for c := range pending {
|
|
for c := range pending {
|
|
|
c.Close()
|
|
c.Close()
|
|
@@ -104,32 +111,58 @@ func (l *tlsListener) acceptLoop() {
|
|
|
delete(pending, conn)
|
|
delete(pending, conn)
|
|
|
pendingMu.Unlock()
|
|
pendingMu.Unlock()
|
|
|
if herr != nil {
|
|
if herr != nil {
|
|
|
- if l.handshakeFailure != nil {
|
|
|
|
|
- l.handshakeFailure(tlsConn, herr)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ l.handshakeFailure(tlsConn, herr)
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
st := tlsConn.ConnectionState()
|
|
st := tlsConn.ConnectionState()
|
|
|
if len(st.PeerCertificates) > 0 {
|
|
if len(st.PeerCertificates) > 0 {
|
|
|
cert := st.PeerCertificates[0]
|
|
cert := st.PeerCertificates[0]
|
|
|
- if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
|
|
|
|
|
- addr := tlsConn.RemoteAddr().String()
|
|
|
|
|
- h, _, herr := net.SplitHostPort(addr)
|
|
|
|
|
- if herr != nil || cert.VerifyHostname(h) != nil {
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ addr := tlsConn.RemoteAddr().String()
|
|
|
|
|
+ if cerr := checkCert(ctx, cert, addr); cerr != nil {
|
|
|
|
|
+ l.handshakeFailure(tlsConn, cerr)
|
|
|
|
|
+ return
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
select {
|
|
select {
|
|
|
case l.connc <- tlsConn:
|
|
case l.connc <- tlsConn:
|
|
|
conn = nil
|
|
conn = nil
|
|
|
- case <-stopc:
|
|
|
|
|
|
|
+ case <-ctx.Done():
|
|
|
}
|
|
}
|
|
|
}()
|
|
}()
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
|
|
|
|
|
+ h, _, herr := net.SplitHostPort(remoteAddr)
|
|
|
|
|
+ if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
|
|
|
|
|
+ return nil
|
|
|
|
|
+ }
|
|
|
|
|
+ if herr != nil {
|
|
|
|
|
+ return herr
|
|
|
|
|
+ }
|
|
|
|
|
+ if len(cert.IPAddresses) > 0 {
|
|
|
|
|
+ if cerr := cert.VerifyHostname(h); cerr != nil && len(cert.DNSNames) == 0 {
|
|
|
|
|
+ return cerr
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if len(cert.DNSNames) > 0 {
|
|
|
|
|
+ for _, dns := range cert.DNSNames {
|
|
|
|
|
+ addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
|
|
|
|
|
+ if lerr != nil {
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ for _, addr := range addrs {
|
|
|
|
|
+ if addr == h {
|
|
|
|
|
+ return nil
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames)
|
|
|
|
|
+ }
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func (l *tlsListener) Close() error {
|
|
func (l *tlsListener) Close() error {
|
|
|
err := l.Listener.Close()
|
|
err := l.Listener.Close()
|
|
|
<-l.donec
|
|
<-l.donec
|