Browse Source

Add comprehensive host test (#429)

Add table driven test for handling of host in request URL, request
header and TLS server name. In addition to testing various uses of host
names, this test also confirms that host names are handled the same as
the net/http client.

The new table driven test replaces TestDialTLS, TestDialTLSNoverify,
TestDialTLSBadCert and TestHostHeader.

Eliminate duplicated code for constructing root CA.
Steven Scott 7 years ago
parent
commit
cdd40f587d
1 changed files with 194 additions and 79 deletions
  1. 194 79
      client_server_test.go

+ 194 - 79
client_server_test.go

@@ -11,8 +11,10 @@ import (
 	"crypto/x509"
 	"encoding/base64"
 	"encoding/binary"
+	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
 	"net"
 	"net/http"
 	"net/http/cookiejar"
@@ -42,17 +44,12 @@ var cstDialer = Dialer{
 	HandshakeTimeout: 30 * time.Second,
 }
 
-var cstDialerWithoutHandshakeTimeout = Dialer{
-	Subprotocols:    []string{"p1", "p2"},
-	ReadBufferSize:  1024,
-	WriteBufferSize: 1024,
-}
-
 type cstHandler struct{ *testing.T }
 
 type cstServer struct {
 	*httptest.Server
 	URL string
+	t   *testing.T
 }
 
 const (
@@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
 	sendRecv(t, ws)
 }
 
-func TestDialTLS(t *testing.T) {
-	s := newTLSServer(t)
-	defer s.Close()
-
+func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
 	certs := x509.NewCertPool()
 	for _, c := range s.TLS.Certificates {
 		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
@@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
 			certs.AddCert(root)
 		}
 	}
-
-	d := cstDialer
-	d.TLSClientConfig = &tls.Config{RootCAs: certs}
-	ws, _, err := d.Dial(s.URL, nil)
-	if err != nil {
-		t.Fatalf("Dial: %v", err)
-	}
-	defer ws.Close()
-	sendRecv(t, ws)
-}
-
-func xTestDialTLSBadCert(t *testing.T) {
-	// This test is deactivated because of noisy logging from the net/http package.
-	s := newTLSServer(t)
-	defer s.Close()
-
-	ws, _, err := cstDialer.Dial(s.URL, nil)
-	if err == nil {
-		ws.Close()
-		t.Fatalf("Dial: nil")
-	}
+	return certs
 }
 
-func TestDialTLSNoVerify(t *testing.T) {
+func TestDialTLS(t *testing.T) {
 	s := newTLSServer(t)
 	defer s.Close()
 
 	d := cstDialer
-	d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
 	ws, _, err := d.Dial(s.URL, nil)
 	if err != nil {
 		t.Fatalf("Dial: %v", err)
@@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()
 
-	d := cstDialerWithoutHandshakeTimeout
+	d := cstDialer
+	d.HandshakeTimeout = 0
 	d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
 		netDialer := &net.Dialer{}
 		c, err := netDialer.DialContext(ctx, n, a)
@@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
 	}
 }
 
-// TestHostHeader confirms that the host header provided in the call to Dial is
-// sent to the server.
-func TestHostHeader(t *testing.T) {
-	s := newServer(t)
-	defer s.Close()
+type testLogWriter struct {
+	t *testing.T
+}
 
-	specifiedHost := make(chan string, 1)
-	origHandler := s.Server.Config.Handler
+func (w testLogWriter) Write(p []byte) (int, error) {
+	w.t.Logf("%s", p)
+	return len(p), nil
+}
 
-	// Capture the request Host header.
-	s.Server.Config.Handler = http.HandlerFunc(
-		func(w http.ResponseWriter, r *http.Request) {
-			specifiedHost <- r.Host
-			origHandler.ServeHTTP(w, r)
-		})
+// TestHost tests handling of host names and confirms that it matches net/http.
+func TestHost(t *testing.T) {
 
-	ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
-	if err != nil {
-		t.Fatalf("Dial: %v", err)
-	}
-	defer ws.Close()
+	upgrader := Upgrader{}
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if IsWebSocketUpgrade(r) {
+			c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
+			if err != nil {
+				t.Fatal(err)
+			}
+			c.Close()
+		} else {
+			w.Header().Set("X-Test-Host", r.Host)
+		}
+	})
+
+	server := httptest.NewServer(handler)
+	defer server.Close()
+
+	tlsServer := httptest.NewTLSServer(handler)
+	defer tlsServer.Close()
+
+	addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
+	wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
+	httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
+
+	// Avoid log noise from net/http server by logging to testing.T
+	server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
+	tlsServer.Config.ErrorLog = server.Config.ErrorLog
+
+	cas := rootCAs(t, tlsServer)
+
+	tests := []struct {
+		fail               bool             // true if dial / get should fail
+		server             *httptest.Server // server to use
+		url                string           // host for request URI
+		header             string           // optional request host header
+		tls                string           // optiona host for tls ServerName
+		wantAddr           string           // expected host for dial
+		wantHeader         string           // expected request header on server
+		insecureSkipVerify bool
+	}{
+		{
+			server:     server,
+			url:        addrs[server],
+			wantAddr:   addrs[server],
+			wantHeader: addrs[server],
+		},
+		{
+			server:     tlsServer,
+			url:        addrs[tlsServer],
+			wantAddr:   addrs[tlsServer],
+			wantHeader: addrs[tlsServer],
+		},
+
+		{
+			server:     server,
+			url:        addrs[server],
+			header:     "badhost.com",
+			wantAddr:   addrs[server],
+			wantHeader: "badhost.com",
+		},
+		{
+			server:     tlsServer,
+			url:        addrs[tlsServer],
+			header:     "badhost.com",
+			wantAddr:   addrs[tlsServer],
+			wantHeader: "badhost.com",
+		},
+
+		{
+			server:     server,
+			url:        "example.com",
+			header:     "badhost.com",
+			wantAddr:   "example.com:80",
+			wantHeader: "badhost.com",
+		},
+		{
+			server:     tlsServer,
+			url:        "example.com",
+			header:     "badhost.com",
+			wantAddr:   "example.com:443",
+			wantHeader: "badhost.com",
+		},
 
-	if gotHost := <-specifiedHost; gotHost != "testhost" {
-		t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
+		{
+			server:     server,
+			url:        "badhost.com",
+			header:     "example.com",
+			wantAddr:   "badhost.com:80",
+			wantHeader: "example.com",
+		},
+		{
+			fail:     true,
+			server:   tlsServer,
+			url:      "badhost.com",
+			header:   "example.com",
+			wantAddr: "badhost.com:443",
+		},
+		{
+			server:             tlsServer,
+			url:                "badhost.com",
+			insecureSkipVerify: true,
+			wantAddr:           "badhost.com:443",
+			wantHeader:         "badhost.com",
+		},
+		{
+			server:     tlsServer,
+			url:        "badhost.com",
+			tls:        "example.com",
+			wantAddr:   "badhost.com:443",
+			wantHeader: "badhost.com",
+		},
 	}
 
-	sendRecv(t, ws)
+	for i, tt := range tests {
+
+		tls := &tls.Config{
+			RootCAs:            cas,
+			ServerName:         tt.tls,
+			InsecureSkipVerify: tt.insecureSkipVerify,
+		}
+
+		var gotAddr string
+		dialer := Dialer{
+			NetDial: func(network, addr string) (net.Conn, error) {
+				gotAddr = addr
+				return net.Dial(network, addrs[tt.server])
+			},
+			TLSClientConfig: tls,
+		}
+
+		// Test websocket dial
+
+		h := http.Header{}
+		if tt.header != "" {
+			h.Set("Host", tt.header)
+		}
+		c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
+		if err == nil {
+			c.Close()
+		}
+
+		check := func(protos map[*httptest.Server]string) {
+			name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
+			if gotAddr != tt.wantAddr {
+				t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
+			}
+			switch {
+			case tt.fail && err == nil:
+				t.Errorf("%s: unexpected success", name)
+			case !tt.fail && err != nil:
+				t.Errorf("%s: unexpected error %v", name, err)
+			case !tt.fail && err == nil:
+				if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
+					t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
+				}
+			}
+		}
+
+		check(wsProtos)
+
+		// Confirm that net/http has same result
+
+		transport := &http.Transport{
+			Dial:            dialer.NetDial,
+			TLSClientConfig: dialer.TLSClientConfig,
+		}
+		req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
+		if tt.header != "" {
+			req.Host = tt.header
+		}
+		client := &http.Client{Transport: transport}
+		resp, err = client.Do(req)
+		if err == nil {
+			resp.Body.Close()
+		}
+		transport.CloseIdleConnections()
+		check(httpProtos)
+	}
 }
 
 func TestDialCompression(t *testing.T) {
@@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
 	s := newTLSServer(t)
 	defer s.Close()
 
-	certs := x509.NewCertPool()
-	for _, c := range s.TLS.Certificates {
-		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
-		if err != nil {
-			t.Fatalf("error parsing server's root cert: %v", err)
-		}
-		for _, root := range roots {
-			certs.AddCert(root)
-		}
-	}
-
 	d := cstDialer
-	d.TLSClientConfig = &tls.Config{RootCAs: certs}
+	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
 
 	ws, _, err := d.DialContext(ctx, s.URL, nil)
 	if err != nil {
@@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
 	s := newTLSServer(t)
 	defer s.Close()
 
-	certs := x509.NewCertPool()
-	for _, c := range s.TLS.Certificates {
-		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
-		if err != nil {
-			t.Fatalf("error parsing server's root cert: %v", err)
-		}
-		for _, root := range roots {
-			certs.AddCert(root)
-		}
-	}
-
 	d := cstDialer
-	d.TLSClientConfig = &tls.Config{RootCAs: certs}
+	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
 
 	ws, _, err := d.DialContext(ctx, s.URL, nil)
 	if err != nil {