Browse Source

pkg/transport: Added test for SkipClientVerify flag.

Martin Weindel 6 years ago
parent
commit
2f476f2b5a
2 changed files with 109 additions and 4 deletions
  1. 2 2
      pkg/transport/listener.go
  2. 107 2
      pkg/transport/listener_test.go

+ 2 - 2
pkg/transport/listener.go

@@ -113,7 +113,7 @@ func (info TLSInfo) Empty() bool {
 	return info.CertFile == "" && info.KeyFile == ""
 }
 
-func SelfCert(lg *zap.Logger, dirpath string, hosts []string) (info TLSInfo, err error) {
+func SelfCert(lg *zap.Logger, dirpath string, hosts []string, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) {
 	if err = os.MkdirAll(dirpath, 0700); err != nil {
 		return
 	}
@@ -149,7 +149,7 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string) (info TLSInfo, err
 		NotAfter:     time.Now().Add(365 * (24 * time.Hour)),
 
 		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
-		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		ExtKeyUsage:           append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...),
 		BasicConstraintsValid: true,
 	}
 

+ 107 - 2
pkg/transport/listener_test.go

@@ -16,8 +16,10 @@ package transport
 
 import (
 	"crypto/tls"
+	"crypto/x509"
 	"errors"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"os"
 	"testing"
@@ -26,12 +28,16 @@ import (
 	"go.uber.org/zap"
 )
 
-func createSelfCert() (*TLSInfo, func(), error) {
+func createSelfCert(hosts ...string) (*TLSInfo, func(), error) {
+	return createSelfCertEx("127.0.0.1")
+}
+
+func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) {
 	d, terr := ioutil.TempDir("", "etcd-test-tls-")
 	if terr != nil {
 		return nil, nil, terr
 	}
-	info, err := SelfCert(zap.NewExample(), d, []string{"127.0.0.1"})
+	info, err := SelfCert(zap.NewExample(), d, []string{host + ":0"}, additionalUsages...)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -76,6 +82,105 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
 	}
 }
 
+// TestNewListenerTLSInfoSkipClientVerify tests that if client IP address mismatches
+// with specified address in its certificate the connection is still accepted
+// if the flag SkipClientVerify is set (i.e. checkSAN() is disabled for the client side)
+func TestNewListenerTLSInfoSkipClientVerify(t *testing.T) {
+	tests := []struct {
+		skipClientVerify bool
+		goodClientHost   bool
+		acceptExpected   bool
+	}{
+		{false, true, true},
+		{false, false, false},
+		{true, true, true},
+		{true, false, true},
+	}
+	for _, test := range tests {
+		testNewListenerTLSInfoClientCheck(t, test.skipClientVerify, test.goodClientHost, test.acceptExpected)
+	}
+}
+
+func testNewListenerTLSInfoClientCheck(t *testing.T, skipClientVerify, goodClientHost, acceptExpected bool) {
+	tlsInfo, del, err := createSelfCert()
+	if err != nil {
+		t.Fatalf("unable to create cert: %v", err)
+	}
+	defer del()
+
+	host := "127.0.0.222"
+	if goodClientHost {
+		host = "127.0.0.1"
+	}
+	clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth)
+	if err != nil {
+		t.Fatalf("unable to create cert: %v", err)
+	}
+	defer del2()
+
+	tlsInfo.SkipClientVerify = skipClientVerify
+	tlsInfo.TrustedCAFile = clientTLSInfo.CertFile
+
+	rootCAs := x509.NewCertPool()
+	loaded, err := ioutil.ReadFile(tlsInfo.CertFile)
+	if err != nil {
+		t.Fatalf("unexpected missing certfile: %v", err)
+	}
+	rootCAs.AppendCertsFromPEM(loaded)
+
+	clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile)
+	if err != nil {
+		t.Fatalf("unable to create peer cert: %v", err)
+	}
+
+	tlsConfig := &tls.Config{}
+	tlsConfig.InsecureSkipVerify = false
+	tlsConfig.Certificates = []tls.Certificate{clientCert}
+	tlsConfig.RootCAs = rootCAs
+
+	ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
+	if err != nil {
+		t.Fatalf("unexpected NewListener error: %v", err)
+	}
+	defer ln.Close()
+
+	tr := &http.Transport{TLSClientConfig: tlsConfig}
+	cli := &http.Client{Transport: tr}
+	chClientErr := make(chan error)
+	go func() {
+		_, err := cli.Get("https://" + ln.Addr().String())
+		chClientErr <- err
+	}()
+
+	chAcceptErr := make(chan error)
+	chAcceptConn := make(chan net.Conn)
+	go func() {
+		conn, err := ln.Accept()
+		if err != nil {
+			chAcceptErr <- err
+		} else {
+			chAcceptConn <- conn
+		}
+	}()
+
+	select {
+	case <-chClientErr:
+		if acceptExpected {
+			t.Errorf("accepted for good client address: skipClientVerify=%t, goodClientHost=%t", skipClientVerify, goodClientHost)
+		}
+	case acceptErr := <-chAcceptErr:
+		t.Fatalf("unexpected Accept error: %v", acceptErr)
+	case conn := <-chAcceptConn:
+		defer conn.Close()
+		if _, ok := conn.(*tls.Conn); !ok {
+			t.Errorf("failed to accept *tls.Conn")
+		}
+		if !acceptExpected {
+			t.Errorf("accepted for bad client address: skipClientVerify=%t, goodClientHost=%t", skipClientVerify, goodClientHost)
+		}
+	}
+}
+
 func TestNewListenerTLSEmptyInfo(t *testing.T) {
 	_, err := NewListener("127.0.0.1:0", "https", nil)
 	if err == nil {