|
@@ -22,6 +22,7 @@ import (
|
|
|
"crypto/x509"
|
|
"crypto/x509"
|
|
|
"crypto/x509/pkix"
|
|
"crypto/x509/pkix"
|
|
|
"encoding/pem"
|
|
"encoding/pem"
|
|
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"math/big"
|
|
"math/big"
|
|
|
"net"
|
|
"net"
|
|
@@ -76,6 +77,9 @@ type TLSInfo struct {
|
|
|
// parseFunc exists to simplify testing. Typically, parseFunc
|
|
// parseFunc exists to simplify testing. Typically, parseFunc
|
|
|
// should be left nil. In that case, tls.X509KeyPair will be used.
|
|
// should be left nil. In that case, tls.X509KeyPair will be used.
|
|
|
parseFunc func([]byte, []byte) (tls.Certificate, error)
|
|
parseFunc func([]byte, []byte) (tls.Certificate, error)
|
|
|
|
|
+
|
|
|
|
|
+ // AllowedCN is a CN which must be provided by a client.
|
|
|
|
|
+ AllowedCN string
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (info TLSInfo) String() string {
|
|
func (info TLSInfo) String() string {
|
|
@@ -174,6 +178,20 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
|
|
|
MinVersion: tls.VersionTLS12,
|
|
MinVersion: tls.VersionTLS12,
|
|
|
ServerName: info.ServerName,
|
|
ServerName: info.ServerName,
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ if info.AllowedCN != "" {
|
|
|
|
|
+ cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
|
|
|
+ for _, chains := range verifiedChains {
|
|
|
|
|
+ if len(chains) != 0 {
|
|
|
|
|
+ if info.AllowedCN == chains[0].Subject.CommonName {
|
|
|
|
|
+ return nil
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return errors.New("CommonName authentication failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
// this only reloads certs when there's a client request
|
|
// this only reloads certs when there's a client request
|
|
|
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
|
|
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
|
|
|
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|