|
|
@@ -11,8 +11,10 @@ import (
|
|
|
"crypto/x509"
|
|
|
"encoding/base64"
|
|
|
"encoding/binary"
|
|
|
+ "fmt"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
+ "log"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/http/cookiejar"
|
|
|
@@ -42,17 +44,12 @@ var cstDialer = Dialer{
|
|
|
HandshakeTimeout: 30 * time.Second,
|
|
|
}
|
|
|
|
|
|
-var cstDialerWithoutHandshakeTimeout = Dialer{
|
|
|
- Subprotocols: []string{"p1", "p2"},
|
|
|
- ReadBufferSize: 1024,
|
|
|
- WriteBufferSize: 1024,
|
|
|
-}
|
|
|
-
|
|
|
type cstHandler struct{ *testing.T }
|
|
|
|
|
|
type cstServer struct {
|
|
|
*httptest.Server
|
|
|
URL string
|
|
|
+ t *testing.T
|
|
|
}
|
|
|
|
|
|
const (
|
|
|
@@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
|
|
|
sendRecv(t, ws)
|
|
|
}
|
|
|
|
|
|
-func TestDialTLS(t *testing.T) {
|
|
|
- s := newTLSServer(t)
|
|
|
- defer s.Close()
|
|
|
-
|
|
|
+func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
|
|
|
certs := x509.NewCertPool()
|
|
|
for _, c := range s.TLS.Certificates {
|
|
|
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
|
|
@@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
|
|
|
certs.AddCert(root)
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- d := cstDialer
|
|
|
- d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
|
|
- ws, _, err := d.Dial(s.URL, nil)
|
|
|
- if err != nil {
|
|
|
- t.Fatalf("Dial: %v", err)
|
|
|
- }
|
|
|
- defer ws.Close()
|
|
|
- sendRecv(t, ws)
|
|
|
-}
|
|
|
-
|
|
|
-func xTestDialTLSBadCert(t *testing.T) {
|
|
|
- // This test is deactivated because of noisy logging from the net/http package.
|
|
|
- s := newTLSServer(t)
|
|
|
- defer s.Close()
|
|
|
-
|
|
|
- ws, _, err := cstDialer.Dial(s.URL, nil)
|
|
|
- if err == nil {
|
|
|
- ws.Close()
|
|
|
- t.Fatalf("Dial: nil")
|
|
|
- }
|
|
|
+ return certs
|
|
|
}
|
|
|
|
|
|
-func TestDialTLSNoVerify(t *testing.T) {
|
|
|
+func TestDialTLS(t *testing.T) {
|
|
|
s := newTLSServer(t)
|
|
|
defer s.Close()
|
|
|
|
|
|
d := cstDialer
|
|
|
- d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
|
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
|
|
ws, _, err := d.Dial(s.URL, nil)
|
|
|
if err != nil {
|
|
|
t.Fatalf("Dial: %v", err)
|
|
|
@@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
|
|
|
s := newServer(t)
|
|
|
defer s.Close()
|
|
|
|
|
|
- d := cstDialerWithoutHandshakeTimeout
|
|
|
+ d := cstDialer
|
|
|
+ d.HandshakeTimeout = 0
|
|
|
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
|
|
|
netDialer := &net.Dialer{}
|
|
|
c, err := netDialer.DialContext(ctx, n, a)
|
|
|
@@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// TestHostHeader confirms that the host header provided in the call to Dial is
|
|
|
-// sent to the server.
|
|
|
-func TestHostHeader(t *testing.T) {
|
|
|
- s := newServer(t)
|
|
|
- defer s.Close()
|
|
|
+type testLogWriter struct {
|
|
|
+ t *testing.T
|
|
|
+}
|
|
|
|
|
|
- specifiedHost := make(chan string, 1)
|
|
|
- origHandler := s.Server.Config.Handler
|
|
|
+func (w testLogWriter) Write(p []byte) (int, error) {
|
|
|
+ w.t.Logf("%s", p)
|
|
|
+ return len(p), nil
|
|
|
+}
|
|
|
|
|
|
- // Capture the request Host header.
|
|
|
- s.Server.Config.Handler = http.HandlerFunc(
|
|
|
- func(w http.ResponseWriter, r *http.Request) {
|
|
|
- specifiedHost <- r.Host
|
|
|
- origHandler.ServeHTTP(w, r)
|
|
|
- })
|
|
|
+// TestHost tests handling of host names and confirms that it matches net/http.
|
|
|
+func TestHost(t *testing.T) {
|
|
|
|
|
|
- ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
|
|
- if err != nil {
|
|
|
- t.Fatalf("Dial: %v", err)
|
|
|
- }
|
|
|
- defer ws.Close()
|
|
|
+ upgrader := Upgrader{}
|
|
|
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ if IsWebSocketUpgrade(r) {
|
|
|
+ c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ c.Close()
|
|
|
+ } else {
|
|
|
+ w.Header().Set("X-Test-Host", r.Host)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ server := httptest.NewServer(handler)
|
|
|
+ defer server.Close()
|
|
|
+
|
|
|
+ tlsServer := httptest.NewTLSServer(handler)
|
|
|
+ defer tlsServer.Close()
|
|
|
+
|
|
|
+ addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
|
|
|
+ wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
|
|
|
+ httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
|
|
|
+
|
|
|
+ // Avoid log noise from net/http server by logging to testing.T
|
|
|
+ server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
|
|
|
+ tlsServer.Config.ErrorLog = server.Config.ErrorLog
|
|
|
+
|
|
|
+ cas := rootCAs(t, tlsServer)
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ fail bool // true if dial / get should fail
|
|
|
+ server *httptest.Server // server to use
|
|
|
+ url string // host for request URI
|
|
|
+ header string // optional request host header
|
|
|
+ tls string // optiona host for tls ServerName
|
|
|
+ wantAddr string // expected host for dial
|
|
|
+ wantHeader string // expected request header on server
|
|
|
+ insecureSkipVerify bool
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ server: server,
|
|
|
+ url: addrs[server],
|
|
|
+ wantAddr: addrs[server],
|
|
|
+ wantHeader: addrs[server],
|
|
|
+ },
|
|
|
+ {
|
|
|
+ server: tlsServer,
|
|
|
+ url: addrs[tlsServer],
|
|
|
+ wantAddr: addrs[tlsServer],
|
|
|
+ wantHeader: addrs[tlsServer],
|
|
|
+ },
|
|
|
+
|
|
|
+ {
|
|
|
+ server: server,
|
|
|
+ url: addrs[server],
|
|
|
+ header: "badhost.com",
|
|
|
+ wantAddr: addrs[server],
|
|
|
+ wantHeader: "badhost.com",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ server: tlsServer,
|
|
|
+ url: addrs[tlsServer],
|
|
|
+ header: "badhost.com",
|
|
|
+ wantAddr: addrs[tlsServer],
|
|
|
+ wantHeader: "badhost.com",
|
|
|
+ },
|
|
|
+
|
|
|
+ {
|
|
|
+ server: server,
|
|
|
+ url: "example.com",
|
|
|
+ header: "badhost.com",
|
|
|
+ wantAddr: "example.com:80",
|
|
|
+ wantHeader: "badhost.com",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ server: tlsServer,
|
|
|
+ url: "example.com",
|
|
|
+ header: "badhost.com",
|
|
|
+ wantAddr: "example.com:443",
|
|
|
+ wantHeader: "badhost.com",
|
|
|
+ },
|
|
|
|
|
|
- if gotHost := <-specifiedHost; gotHost != "testhost" {
|
|
|
- t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
|
|
|
+ {
|
|
|
+ server: server,
|
|
|
+ url: "badhost.com",
|
|
|
+ header: "example.com",
|
|
|
+ wantAddr: "badhost.com:80",
|
|
|
+ wantHeader: "example.com",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ fail: true,
|
|
|
+ server: tlsServer,
|
|
|
+ url: "badhost.com",
|
|
|
+ header: "example.com",
|
|
|
+ wantAddr: "badhost.com:443",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ server: tlsServer,
|
|
|
+ url: "badhost.com",
|
|
|
+ insecureSkipVerify: true,
|
|
|
+ wantAddr: "badhost.com:443",
|
|
|
+ wantHeader: "badhost.com",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ server: tlsServer,
|
|
|
+ url: "badhost.com",
|
|
|
+ tls: "example.com",
|
|
|
+ wantAddr: "badhost.com:443",
|
|
|
+ wantHeader: "badhost.com",
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
- sendRecv(t, ws)
|
|
|
+ for i, tt := range tests {
|
|
|
+
|
|
|
+ tls := &tls.Config{
|
|
|
+ RootCAs: cas,
|
|
|
+ ServerName: tt.tls,
|
|
|
+ InsecureSkipVerify: tt.insecureSkipVerify,
|
|
|
+ }
|
|
|
+
|
|
|
+ var gotAddr string
|
|
|
+ dialer := Dialer{
|
|
|
+ NetDial: func(network, addr string) (net.Conn, error) {
|
|
|
+ gotAddr = addr
|
|
|
+ return net.Dial(network, addrs[tt.server])
|
|
|
+ },
|
|
|
+ TLSClientConfig: tls,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test websocket dial
|
|
|
+
|
|
|
+ h := http.Header{}
|
|
|
+ if tt.header != "" {
|
|
|
+ h.Set("Host", tt.header)
|
|
|
+ }
|
|
|
+ c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
|
|
|
+ if err == nil {
|
|
|
+ c.Close()
|
|
|
+ }
|
|
|
+
|
|
|
+ check := func(protos map[*httptest.Server]string) {
|
|
|
+ name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
|
|
|
+ if gotAddr != tt.wantAddr {
|
|
|
+ t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
|
|
|
+ }
|
|
|
+ switch {
|
|
|
+ case tt.fail && err == nil:
|
|
|
+ t.Errorf("%s: unexpected success", name)
|
|
|
+ case !tt.fail && err != nil:
|
|
|
+ t.Errorf("%s: unexpected error %v", name, err)
|
|
|
+ case !tt.fail && err == nil:
|
|
|
+ if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
|
|
|
+ t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ check(wsProtos)
|
|
|
+
|
|
|
+ // Confirm that net/http has same result
|
|
|
+
|
|
|
+ transport := &http.Transport{
|
|
|
+ Dial: dialer.NetDial,
|
|
|
+ TLSClientConfig: dialer.TLSClientConfig,
|
|
|
+ }
|
|
|
+ req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
|
|
|
+ if tt.header != "" {
|
|
|
+ req.Host = tt.header
|
|
|
+ }
|
|
|
+ client := &http.Client{Transport: transport}
|
|
|
+ resp, err = client.Do(req)
|
|
|
+ if err == nil {
|
|
|
+ resp.Body.Close()
|
|
|
+ }
|
|
|
+ transport.CloseIdleConnections()
|
|
|
+ check(httpProtos)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func TestDialCompression(t *testing.T) {
|
|
|
@@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
|
|
|
s := newTLSServer(t)
|
|
|
defer s.Close()
|
|
|
|
|
|
- certs := x509.NewCertPool()
|
|
|
- for _, c := range s.TLS.Certificates {
|
|
|
- roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
|
|
- if err != nil {
|
|
|
- t.Fatalf("error parsing server's root cert: %v", err)
|
|
|
- }
|
|
|
- for _, root := range roots {
|
|
|
- certs.AddCert(root)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
d := cstDialer
|
|
|
- d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
|
|
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
|
|
|
|
|
ws, _, err := d.DialContext(ctx, s.URL, nil)
|
|
|
if err != nil {
|
|
|
@@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
|
|
|
s := newTLSServer(t)
|
|
|
defer s.Close()
|
|
|
|
|
|
- certs := x509.NewCertPool()
|
|
|
- for _, c := range s.TLS.Certificates {
|
|
|
- roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
|
|
- if err != nil {
|
|
|
- t.Fatalf("error parsing server's root cert: %v", err)
|
|
|
- }
|
|
|
- for _, root := range roots {
|
|
|
- certs.AddCert(root)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
d := cstDialer
|
|
|
- d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
|
|
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
|
|
|
|
|
ws, _, err := d.DialContext(ctx, s.URL, nil)
|
|
|
if err != nil {
|