Browse Source

Add context in the Dialer

SALLEYRON Julien 7 years ago
parent
commit
ceae45234a
4 changed files with 230 additions and 16 deletions
  1. 70 16
      client.go
  2. 129 0
      client_server_test.go
  3. 19 0
      trace.go
  4. 12 0
      trace_17.go

+ 70 - 16
client.go

@@ -6,12 +6,14 @@ package websocket
 
 import (
 	"bytes"
+	"context"
 	"crypto/tls"
 	"errors"
 	"io"
 	"io/ioutil"
 	"net"
 	"net/http"
+	"net/http/httptrace"
 	"net/url"
 	"strings"
 	"time"
@@ -51,6 +53,10 @@ type Dialer struct {
 	// NetDial is nil, net.Dial is used.
 	NetDial func(network, addr string) (net.Conn, error)
 
+	// NetDialContext specifies the dial function for creating TCP connections. If
+	// NetDialContext is nil, net.DialContext is used.
+	NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
 	// Proxy specifies a function to return a proxy for a given
 	// Request. If the function returns a non-nil error, the
 	// request is aborted with the provided error.
@@ -95,6 +101,11 @@ type Dialer struct {
 	Jar http.CookieJar
 }
 
+// Dial creates a new client connection by calling DialContext with a background context.
+func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
+	return d.DialContext(urlStr, requestHeader, context.Background())
+}
+
 var errMalformedURL = errors.New("malformed ws or wss URL")
 
 func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
@@ -124,17 +135,18 @@ var DefaultDialer = &Dialer{
 // nilDialer is dialer to use when receiver is nil.
 var nilDialer Dialer = *DefaultDialer
 
-// Dial creates a new client connection. Use requestHeader to specify the
+// DialContext creates a new client connection. Use requestHeader to specify the
 // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
 // Use the response.Header to get the selected subprotocol
 // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
 //
+// The context will be used in the request and in the Dialer
+//
 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
 // non-nil *http.Response so that callers can handle redirects, authentication,
 // etcetera. The response body may not contain the entire response and does not
 // need to be closed by the application.
-func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
-
+func (d *Dialer) DialContext(urlStr string, requestHeader http.Header, ctx context.Context) (*Conn, *http.Response, error) {
 	if d == nil {
 		d = &nilDialer
 	}
@@ -172,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		Header:     make(http.Header),
 		Host:       u.Host,
 	}
+	req = req.WithContext(ctx)
 
 	// Set the cookies present in the cookie jar of the dialer
 	if d.Jar != nil {
@@ -215,20 +228,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
 	}
 
-	var deadline time.Time
 	if d.HandshakeTimeout != 0 {
-		deadline = time.Now().Add(d.HandshakeTimeout)
+		var cancel func()
+		ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
+		defer cancel()
 	}
 
 	// Get network dial function.
-	netDial := d.NetDial
-	if netDial == nil {
-		netDialer := &net.Dialer{Deadline: deadline}
-		netDial = netDialer.Dial
+	var netDial func(network, add string) (net.Conn, error)
+
+	if d.NetDialContext != nil {
+		netDial = func(network, addr string) (net.Conn, error) {
+			return d.NetDialContext(ctx, network, addr)
+		}
+	} else if d.NetDial != nil {
+		netDial = d.NetDial
+	} else {
+		netDialer := &net.Dialer{}
+		netDial = func(network, addr string) (net.Conn, error) {
+			return netDialer.DialContext(ctx, network, addr)
+		}
 	}
 
 	// If needed, wrap the dial function to set the connection deadline.
-	if !deadline.Equal(time.Time{}) {
+	if deadline, ok := ctx.Deadline(); ok {
 		forwardDial := netDial
 		netDial = func(network, addr string) (net.Conn, error) {
 			c, err := forwardDial(network, addr)
@@ -260,7 +283,17 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 	}
 
 	hostPort, hostNoPort := hostPortNoPort(u)
+	trace := httptrace.ContextClientTrace(ctx)
+	if trace != nil && trace.GetConn != nil {
+		trace.GetConn(hostPort)
+	}
+
 	netConn, err := netDial("tcp", hostPort)
+	if trace != nil && trace.GotConn != nil {
+		trace.GotConn(httptrace.GotConnInfo{
+			Conn: netConn,
+		})
+	}
 	if err != nil {
 		return nil, nil, err
 	}
@@ -278,13 +311,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		}
 		tlsConn := tls.Client(netConn, cfg)
 		netConn = tlsConn
-		if err := tlsConn.Handshake(); err != nil {
-			return nil, nil, err
+
+		var err error
+		if trace != nil {
+			err = doHandshakeWithTrace(trace, tlsConn, cfg)
+		} else {
+			err = doHandshake(tlsConn, cfg)
 		}
-		if !cfg.InsecureSkipVerify {
-			if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
-				return nil, nil, err
-			}
+
+		if err != nil {
+			return nil, nil, err
 		}
 	}
 
@@ -294,6 +330,12 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		return nil, nil, err
 	}
 
+	if trace != nil && trace.GotFirstResponseByte != nil {
+		if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
+			trace.GotFirstResponseByte()
+		}
+	}
+
 	resp, err := http.ReadResponse(conn.br, req)
 	if err != nil {
 		return nil, nil, err
@@ -339,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 	netConn = nil // to avoid close in defer.
 	return conn, resp, nil
 }
+
+func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
+	if err := tlsConn.Handshake(); err != nil {
+		return err
+	}
+	if !cfg.InsecureSkipVerify {
+		if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+			return err
+		}
+	}
+	return nil
+}

+ 129 - 0
client_server_test.go

@@ -6,6 +6,7 @@ package websocket
 
 import (
 	"bytes"
+	"context"
 	"crypto/tls"
 	"crypto/x509"
 	"encoding/base64"
@@ -16,6 +17,7 @@ import (
 	"net/http"
 	"net/http/cookiejar"
 	"net/http/httptest"
+	"net/http/httptrace"
 	"net/url"
 	"reflect"
 	"strings"
@@ -40,6 +42,12 @@ var cstDialer = Dialer{
 	HandshakeTimeout: 30 * time.Second,
 }
 
+var cstDialerWithoutHandshakeTimeout = Dialer{
+	Subprotocols:    []string{"p1", "p2"},
+	ReadBufferSize:  1024,
+	WriteBufferSize: 1024,
+}
+
 type cstHandler struct{ *testing.T }
 
 type cstServer struct {
@@ -403,6 +411,26 @@ func TestHandshakeTimeout(t *testing.T) {
 	ws.Close()
 }
 
+func TestHandshakeTimeoutInContext(t *testing.T) {
+	s := newServer(t)
+	defer s.Close()
+
+	d := cstDialerWithoutHandshakeTimeout
+	d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
+		netDialer := &net.Dialer{}
+		c, err := netDialer.DialContext(ctx, n, a)
+		return &requireDeadlineNetConn{c: c, t: t}, err
+	}
+
+	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
+	defer cancel()
+	ws, _, err := d.DialContext(s.URL, nil, ctx)
+	if err != nil {
+		t.Fatal("Dial:", err)
+	}
+	ws.Close()
+}
+
 func TestDialBadScheme(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()
@@ -659,3 +687,104 @@ func TestSocksProxyDial(t *testing.T) {
 	defer ws.Close()
 	sendRecv(t, ws)
 }
+
+func TestTracingDialWithContext(t *testing.T) {
+
+	var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
+	trace := &httptrace.ClientTrace{
+		WroteHeaders: func() {
+			headersWrote = true
+		},
+		WroteRequest: func(httptrace.WroteRequestInfo) {
+			requestWrote = true
+		},
+		GetConn: func(hostPort string) {
+			getConn = true
+		},
+		GotConn: func(info httptrace.GotConnInfo) {
+			gotConn = true
+		},
+		ConnectDone: func(network, addr string, err error) {
+			connectDone = true
+		},
+		GotFirstResponseByte: func() {
+			gotFirstResponseByte = true
+		},
+	}
+	ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+	s := newTLSServer(t)
+	defer s.Close()
+
+	certs := x509.NewCertPool()
+	for _, c := range s.TLS.Certificates {
+		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+		if err != nil {
+			t.Fatalf("error parsing server's root cert: %v", err)
+		}
+		for _, root := range roots {
+			certs.AddCert(root)
+		}
+	}
+
+	d := cstDialer
+	d.TLSClientConfig = &tls.Config{RootCAs: certs}
+
+	ws, _, err := d.DialContext(s.URL, nil, ctx)
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+
+	if !headersWrote {
+		t.Fatal("Headers was not written")
+	}
+	if !requestWrote {
+		t.Fatal("Request was not written")
+	}
+	if !getConn {
+		t.Fatal("getConn was not called")
+	}
+	if !gotConn {
+		t.Fatal("gotConn was not called")
+	}
+	if !connectDone {
+		t.Fatal("connectDone was not called")
+	}
+	if !gotFirstResponseByte {
+		t.Fatal("GotFirstResponseByte was not called")
+	}
+
+	defer ws.Close()
+	sendRecv(t, ws)
+}
+
+func TestEmptyTracingDialWithContext(t *testing.T) {
+
+	trace := &httptrace.ClientTrace{}
+	ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+	s := newTLSServer(t)
+	defer s.Close()
+
+	certs := x509.NewCertPool()
+	for _, c := range s.TLS.Certificates {
+		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+		if err != nil {
+			t.Fatalf("error parsing server's root cert: %v", err)
+		}
+		for _, root := range roots {
+			certs.AddCert(root)
+		}
+	}
+
+	d := cstDialer
+	d.TLSClientConfig = &tls.Config{RootCAs: certs}
+
+	ws, _, err := d.DialContext(s.URL, nil, ctx)
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+
+	defer ws.Close()
+	sendRecv(t, ws)
+}

+ 19 - 0
trace.go

@@ -0,0 +1,19 @@
+// +build go1.8
+
+package websocket
+
+import (
+	"crypto/tls"
+	"net/http/httptrace"
+)
+
+func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
+	if trace.TLSHandshakeStart != nil {
+		trace.TLSHandshakeStart()
+	}
+	err := doHandshake(tlsConn, cfg)
+	if trace.TLSHandshakeDone != nil {
+		trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
+	}
+	return err
+}

+ 12 - 0
trace_17.go

@@ -0,0 +1,12 @@
+// +build !go1.8
+
+package websocket
+
+import (
+	"crypto/tls"
+	"net/http/httptrace"
+)
+
+func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
+	return doHandshake(tlsConn, cfg)
+}