Bläddra i källkod

Merge pull request #90 from wolfeidau/add_proxy_support

Proxy support for websocket clients.
Gary Burd 10 år sedan
förälder
incheckning
527637c0f3
2 ändrade filer med 98 tillägg och 5 borttagningar
  1. 59 5
      client.go
  2. 39 0
      client_server_test.go

+ 59 - 5
client.go

@@ -5,6 +5,7 @@
 package websocket
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/tls"
 	"errors"
@@ -49,6 +50,12 @@ type Dialer struct {
 	// NetDial is nil, net.Dial is used.
 	NetDial func(network, addr string) (net.Conn, error)
 
+	// Proxy specifies a function to return a proxy for a given
+	// Request. If the function returns a non-nil error, the
+	// request is aborted with the provided error.
+	// If Proxy is nil or returns a nil *URL, no proxy is used.
+	Proxy func(*http.Request) (*url.URL, error)
+
 	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
 	// If nil, the default configuration is used.
 	TLSClientConfig *tls.Config
@@ -110,9 +117,12 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
 	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
 		hostNoPort = hostNoPort[:i]
 	} else {
-		if u.Scheme == "wss" {
+		switch u.Scheme {
+		case "wss":
 			hostPort += ":443"
-		} else {
+		case "https":
+			hostPort += ":443"
+		default:
 			hostPort += ":80"
 		}
 	}
@@ -120,7 +130,9 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
 }
 
 // DefaultDialer is a dialer with all fields set to the default zero values.
-var DefaultDialer = &Dialer{}
+var DefaultDialer = &Dialer{
+	Proxy: http.ProxyFromEnvironment,
+}
 
 // Dial creates a new client connection. Use requestHeader to specify the
 // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
@@ -134,7 +146,9 @@ var DefaultDialer = &Dialer{}
 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
 
 	if d == nil {
-		d = &Dialer{}
+		d = &Dialer{
+			Proxy: http.ProxyFromEnvironment,
+		}
 	}
 
 	challengeKey, err := generateChallengeKey()
@@ -194,6 +208,22 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 
 	hostPort, hostNoPort := hostPortNoPort(u)
 
+	var proxyURL *url.URL
+	// Check wether the proxy method has been configured
+	if d.Proxy != nil {
+		proxyURL, err = d.Proxy(req)
+	}
+	if err != nil {
+		return nil, nil, err
+	}
+
+	var targetHostPort string
+	if proxyURL != nil {
+		targetHostPort, _ = hostPortNoPort(proxyURL)
+	} else {
+		targetHostPort = hostPort
+	}
+
 	var deadline time.Time
 	if d.HandshakeTimeout != 0 {
 		deadline = time.Now().Add(d.HandshakeTimeout)
@@ -205,7 +235,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		netDial = netDialer.Dial
 	}
 
-	netConn, err := netDial("tcp", hostPort)
+	netConn, err := netDial("tcp", targetHostPort)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -220,6 +250,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		return nil, nil, err
 	}
 
+	if proxyURL != nil {
+		connectReq := &http.Request{
+			Method: "CONNECT",
+			URL:    &url.URL{Opaque: hostPort},
+			Host:   hostPort,
+			Header: make(http.Header),
+		}
+
+		connectReq.Write(netConn)
+
+		// Read response.
+		// Okay to use and discard buffered reader here, because
+		// TLS server will not speak until spoken to.
+		br := bufio.NewReader(netConn)
+		resp, err := http.ReadResponse(br, connectReq)
+		if err != nil {
+			return nil, nil, err
+		}
+		if resp.StatusCode != 200 {
+			f := strings.SplitN(resp.Status, " ", 2)
+			return nil, nil, errors.New(f[1])
+		}
+	}
+
 	if u.Scheme == "https" {
 		cfg := d.TLSClientConfig
 		if cfg == nil {

+ 39 - 0
client_server_test.go

@@ -123,6 +123,45 @@ func sendRecv(t *testing.T, ws *Conn) {
 	}
 }
 
+func TestProxyDial(t *testing.T) {
+
+	s := newServer(t)
+	defer s.Close()
+
+	surl, _ := url.Parse(s.URL)
+
+	cstDialer.Proxy = http.ProxyURL(surl)
+
+	connect := false
+	origHandler := s.Server.Config.Handler
+
+	// Capture the request Host header.
+	s.Server.Config.Handler = http.HandlerFunc(
+		func(w http.ResponseWriter, r *http.Request) {
+			if r.Method == "CONNECT" {
+				connect = true
+				w.WriteHeader(200)
+				return
+			}
+
+			if !connect {
+				t.Log("connect not recieved")
+				http.Error(w, "connect not recieved", 405)
+				return
+			}
+			origHandler.ServeHTTP(w, r)
+		})
+
+	ws, _, err := cstDialer.Dial(s.URL, nil)
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+	defer ws.Close()
+	sendRecv(t, ws)
+
+	cstDialer.Proxy = http.ProxyFromEnvironment
+}
+
 func TestDial(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()