|
|
@@ -16,8 +16,10 @@ package transport
|
|
|
|
|
|
import (
|
|
|
"crypto/tls"
|
|
|
+ "crypto/x509"
|
|
|
"errors"
|
|
|
"io/ioutil"
|
|
|
+ "net"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
"testing"
|
|
|
@@ -26,12 +28,16 @@ import (
|
|
|
"go.uber.org/zap"
|
|
|
)
|
|
|
|
|
|
-func createSelfCert() (*TLSInfo, func(), error) {
|
|
|
+func createSelfCert(hosts ...string) (*TLSInfo, func(), error) {
|
|
|
+ return createSelfCertEx("127.0.0.1")
|
|
|
+}
|
|
|
+
|
|
|
+func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) {
|
|
|
d, terr := ioutil.TempDir("", "etcd-test-tls-")
|
|
|
if terr != nil {
|
|
|
return nil, nil, terr
|
|
|
}
|
|
|
- info, err := SelfCert(zap.NewExample(), d, []string{"127.0.0.1"})
|
|
|
+ info, err := SelfCert(zap.NewExample(), d, []string{host + ":0"}, additionalUsages...)
|
|
|
if err != nil {
|
|
|
return nil, nil, err
|
|
|
}
|
|
|
@@ -76,6 +82,105 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// TestNewListenerTLSInfoSkipClientVerify tests that if client IP address mismatches
|
|
|
+// with specified address in its certificate the connection is still accepted
|
|
|
+// if the flag SkipClientVerify is set (i.e. checkSAN() is disabled for the client side)
|
|
|
+func TestNewListenerTLSInfoSkipClientVerify(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ skipClientVerify bool
|
|
|
+ goodClientHost bool
|
|
|
+ acceptExpected bool
|
|
|
+ }{
|
|
|
+ {false, true, true},
|
|
|
+ {false, false, false},
|
|
|
+ {true, true, true},
|
|
|
+ {true, false, true},
|
|
|
+ }
|
|
|
+ for _, test := range tests {
|
|
|
+ testNewListenerTLSInfoClientCheck(t, test.skipClientVerify, test.goodClientHost, test.acceptExpected)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func testNewListenerTLSInfoClientCheck(t *testing.T, skipClientVerify, goodClientHost, acceptExpected bool) {
|
|
|
+ tlsInfo, del, err := createSelfCert()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("unable to create cert: %v", err)
|
|
|
+ }
|
|
|
+ defer del()
|
|
|
+
|
|
|
+ host := "127.0.0.222"
|
|
|
+ if goodClientHost {
|
|
|
+ host = "127.0.0.1"
|
|
|
+ }
|
|
|
+ clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("unable to create cert: %v", err)
|
|
|
+ }
|
|
|
+ defer del2()
|
|
|
+
|
|
|
+ tlsInfo.SkipClientVerify = skipClientVerify
|
|
|
+ tlsInfo.TrustedCAFile = clientTLSInfo.CertFile
|
|
|
+
|
|
|
+ rootCAs := x509.NewCertPool()
|
|
|
+ loaded, err := ioutil.ReadFile(tlsInfo.CertFile)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("unexpected missing certfile: %v", err)
|
|
|
+ }
|
|
|
+ rootCAs.AppendCertsFromPEM(loaded)
|
|
|
+
|
|
|
+ clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("unable to create peer cert: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ tlsConfig := &tls.Config{}
|
|
|
+ tlsConfig.InsecureSkipVerify = false
|
|
|
+ tlsConfig.Certificates = []tls.Certificate{clientCert}
|
|
|
+ tlsConfig.RootCAs = rootCAs
|
|
|
+
|
|
|
+ ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("unexpected NewListener error: %v", err)
|
|
|
+ }
|
|
|
+ defer ln.Close()
|
|
|
+
|
|
|
+ tr := &http.Transport{TLSClientConfig: tlsConfig}
|
|
|
+ cli := &http.Client{Transport: tr}
|
|
|
+ chClientErr := make(chan error)
|
|
|
+ go func() {
|
|
|
+ _, err := cli.Get("https://" + ln.Addr().String())
|
|
|
+ chClientErr <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ chAcceptErr := make(chan error)
|
|
|
+ chAcceptConn := make(chan net.Conn)
|
|
|
+ go func() {
|
|
|
+ conn, err := ln.Accept()
|
|
|
+ if err != nil {
|
|
|
+ chAcceptErr <- err
|
|
|
+ } else {
|
|
|
+ chAcceptConn <- conn
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-chClientErr:
|
|
|
+ if acceptExpected {
|
|
|
+ t.Errorf("accepted for good client address: skipClientVerify=%t, goodClientHost=%t", skipClientVerify, goodClientHost)
|
|
|
+ }
|
|
|
+ case acceptErr := <-chAcceptErr:
|
|
|
+ t.Fatalf("unexpected Accept error: %v", acceptErr)
|
|
|
+ case conn := <-chAcceptConn:
|
|
|
+ defer conn.Close()
|
|
|
+ if _, ok := conn.(*tls.Conn); !ok {
|
|
|
+ t.Errorf("failed to accept *tls.Conn")
|
|
|
+ }
|
|
|
+ if !acceptExpected {
|
|
|
+ t.Errorf("accepted for bad client address: skipClientVerify=%t, goodClientHost=%t", skipClientVerify, goodClientHost)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestNewListenerTLSEmptyInfo(t *testing.T) {
|
|
|
_, err := NewListener("127.0.0.1:0", "https", nil)
|
|
|
if err == nil {
|