Bladeren bron

etcdserver: add unit test.

Andy Liu 6 jaren geleden
bovenliggende
commit
d851911f86
2 gewijzigde bestanden met toevoegingen van 109 en 5 verwijderingen
  1. 2 2
      pkg/transport/listener.go
  2. 107 3
      pkg/transport/listener_test.go

+ 2 - 2
pkg/transport/listener.go

@@ -100,7 +100,7 @@ func (info TLSInfo) Empty() bool {
 	return info.CertFile == "" && info.KeyFile == ""
 }
 
-func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) {
+func SelfCert(dirpath string, hosts []string, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) {
 	if err = os.MkdirAll(dirpath, 0700); err != nil {
 		return
 	}
@@ -129,7 +129,7 @@ func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) {
 		NotAfter:     time.Now().Add(365 * (24 * time.Hour)),
 
 		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
-		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		ExtKeyUsage:           append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...),
 		BasicConstraintsValid: true,
 	}
 

+ 107 - 3
pkg/transport/listener_test.go

@@ -16,20 +16,26 @@ package transport
 
 import (
 	"crypto/tls"
+	"crypto/x509"
 	"errors"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"os"
 	"testing"
 	"time"
 )
 
-func createSelfCert() (*TLSInfo, func(), error) {
+func createSelfCert(hosts ...string) (*TLSInfo, func(), error) {
+	return createSelfCertEx("127.0.0.1")
+}
+
+func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) {
 	d, terr := ioutil.TempDir("", "etcd-test-tls-")
 	if terr != nil {
 		return nil, nil, terr
 	}
-	info, err := SelfCert(d, []string{"127.0.0.1"})
+	info, err := SelfCert(d, []string{host + ":0"}, additionalUsages...)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -70,10 +76,108 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
 	}
 	defer conn.Close()
 	if _, ok := conn.(*tls.Conn); !ok {
-		t.Errorf("failed to accept *tls.Conn")
+		t.Error("failed to accept *tls.Conn")
 	}
 }
 
+// TestNewListenerTLSInfoSkipClientSANVerify tests that if client IP address mismatches
+// with specified address in its certificate the connection is still accepted
+// if the flag SkipClientSANVerify is set (i.e. checkSAN() is disabled for the client side)
+func TestNewListenerTLSInfoSkipClientSANVerify(t *testing.T) {
+	tests := []struct {
+		skipClientSANVerify bool
+		goodClientHost      bool
+		acceptExpected      bool
+	}{
+		{false, true, true},
+		{false, false, false},
+		{true, true, true},
+		{true, false, true},
+	}
+	for _, test := range tests {
+		testNewListenerTLSInfoClientCheck(t, test.skipClientSANVerify, test.goodClientHost, test.acceptExpected)
+	}
+}
+
+func testNewListenerTLSInfoClientCheck(t *testing.T, skipClientSANVerify, goodClientHost, acceptExpected bool) {
+	tlsInfo, del, err := createSelfCert()
+	if err != nil {
+		t.Fatalf("unable to create cert: %v", err)
+	}
+	defer del()
+
+	host := "127.0.0.222"
+	if goodClientHost {
+		host = "127.0.0.1"
+	}
+	clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth)
+	if err != nil {
+		t.Fatalf("unable to create cert: %v", err)
+	}
+	defer del2()
+
+	tlsInfo.SkipClientSANVerify = skipClientSANVerify
+	tlsInfo.CAFile = clientTLSInfo.CertFile
+
+	rootCAs := x509.NewCertPool()
+	loaded, err := ioutil.ReadFile(tlsInfo.CertFile)
+	if err != nil {
+		t.Fatalf("unexpected missing certfile: %v", err)
+	}
+	rootCAs.AppendCertsFromPEM(loaded)
+
+	clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile)
+	if err != nil {
+		t.Fatalf("unable to create peer cert: %v", err)
+	}
+
+	tlsConfig := &tls.Config{}
+	tlsConfig.InsecureSkipVerify = false
+	tlsConfig.Certificates = []tls.Certificate{clientCert}
+	tlsConfig.RootCAs = rootCAs
+
+	ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
+	if err != nil {
+		t.Fatalf("unexpected NewListener error: %v", err)
+	}
+	defer ln.Close()
+
+	tr := &http.Transport{TLSClientConfig: tlsConfig}
+	cli := &http.Client{Transport: tr}
+	chClientErr := make(chan error)
+	go func() {
+		_, err := cli.Get("https://" + ln.Addr().String())
+		chClientErr <- err
+	}()
+
+	chAcceptErr := make(chan error)
+	chAcceptConn := make(chan net.Conn)
+	go func() {
+		conn, err := ln.Accept()
+		if err != nil {
+			chAcceptErr <- err
+		} else {
+			chAcceptConn <- conn
+		}
+	}()
+
+	select {
+	case <-chClientErr:
+		if acceptExpected {
+			t.Errorf("accepted for good client address: skipClientSANVerify=%v, goodClientHost=%v", skipClientSANVerify, goodClientHost)
+		}
+	case acceptErr := <-chAcceptErr:
+		t.Fatalf("unexpected Accept error: %v", acceptErr)
+	case conn := <-chAcceptConn:
+		defer conn.Close()
+		if _, ok := conn.(*tls.Conn); !ok {
+			t.Errorf("failed to accept *tls.Conn")
+		}
+		if !acceptExpected {
+			t.Errorf("accepted for bad client address: skipClientSANVerify=%v, goodClientHost=%v", skipClientSANVerify, goodClientHost)
+		}
+	}
+}
 func TestNewListenerTLSEmptyInfo(t *testing.T) {
 	_, err := NewListener("127.0.0.1:0", "https", nil)
 	if err == nil {