Browse Source

pkg/transport: generate TLS client config w/ only CAFile

Brian Waldon 11 years ago
parent
commit
902f06c5c4
2 changed files with 32 additions and 44 deletions
  1. 17 17
      pkg/transport/listener.go
  2. 15 27
      pkg/transport/listener_test.go

+ 17 - 17
pkg/transport/listener.go

@@ -46,6 +46,11 @@ func NewListener(addr string, info TLSInfo) (net.Listener, error) {
 }
 
 func NewTransport(info TLSInfo) (*http.Transport, error) {
+	cfg, err := info.ClientConfig()
+	if err != nil {
+		return nil, err
+	}
+
 	t := &http.Transport{
 		// timeouts taken from http.DefaultTransport
 		Dial: (&net.Dialer{
@@ -53,14 +58,7 @@ func NewTransport(info TLSInfo) (*http.Transport, error) {
 			KeepAlive: 30 * time.Second,
 		}).Dial,
 		TLSHandshakeTimeout: 10 * time.Second,
-	}
-
-	if !info.Empty() {
-		tlsCfg, err := info.ClientConfig()
-		if err != nil {
-			return nil, err
-		}
-		t.TLSClientConfig = tlsCfg
+		TLSClientConfig:     cfg,
 	}
 
 	return t, nil
@@ -134,22 +132,24 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) {
 }
 
 // ClientConfig generates a tls.Config object for use by an HTTP client
-func (info TLSInfo) ClientConfig() (*tls.Config, error) {
-	cfg, err := info.baseConfig()
-	if err != nil {
-		return nil, err
+func (info TLSInfo) ClientConfig() (cfg *tls.Config, err error) {
+	if !info.Empty() {
+		cfg, err = info.baseConfig()
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		cfg = &tls.Config{}
 	}
 
 	if info.CAFile != "" {
-		cp, err := newCertPool(info.CAFile)
+		cfg.RootCAs, err = newCertPool(info.CAFile)
 		if err != nil {
-			return nil, err
+			return
 		}
-
-		cfg.RootCAs = cp
 	}
 
-	return cfg, nil
+	return
 }
 
 // newCertPool creates x509 certPool with provided CA file

+ 15 - 27
pkg/transport/listener_test.go

@@ -51,41 +51,31 @@ func TestNewTransportTLSInfo(t *testing.T) {
 	}
 	defer os.Remove(tmp)
 
-	tests := []struct {
-		info                TLSInfo
-		wantTLSClientConfig bool
-	}{
-		{
-			info:                TLSInfo{},
-			wantTLSClientConfig: false,
+	tests := []TLSInfo{
+		TLSInfo{},
+		TLSInfo{
+			CertFile: tmp,
+			KeyFile:  tmp,
 		},
-		{
-			info: TLSInfo{
-				CertFile: tmp,
-				KeyFile:  tmp,
-			},
-			wantTLSClientConfig: true,
+		TLSInfo{
+			CertFile: tmp,
+			KeyFile:  tmp,
+			CAFile:   tmp,
 		},
-		{
-			info: TLSInfo{
-				CertFile: tmp,
-				KeyFile:  tmp,
-				CAFile:   tmp,
-			},
-			wantTLSClientConfig: true,
+		TLSInfo{
+			CAFile: tmp,
 		},
 	}
 
 	for i, tt := range tests {
-		tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
-		trans, err := NewTransport(tt.info)
+		tt.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
+		trans, err := NewTransport(tt)
 		if err != nil {
 			t.Fatalf("Received unexpected error from NewTransport: %v", err)
 		}
 
-		gotTLSClientConfig := trans.TLSClientConfig != nil
-		if tt.wantTLSClientConfig != gotTLSClientConfig {
-			t.Fatalf("#%d: wantTLSClientConfig=%t but gotTLSClientConfig=%t", i, tt.wantTLSClientConfig, gotTLSClientConfig)
+		if trans.TLSClientConfig == nil {
+			t.Fatalf("#%d: want non-nil TLSClientConfig", i)
 		}
 	}
 }
@@ -121,8 +111,6 @@ func TestTLSInfoMissingFields(t *testing.T) {
 	defer os.Remove(tmp)
 
 	tests := []TLSInfo{
-		TLSInfo{},
-		TLSInfo{CAFile: tmp},
 		TLSInfo{CertFile: tmp},
 		TLSInfo{KeyFile: tmp},
 		TLSInfo{CertFile: tmp, CAFile: tmp},