瀏覽代碼

Refactor client handshake

- To take advantage of the Host header cleanup in the net/http
  Request.Write method, use a net/http Request to write the handshake to
  the wire.
- Move code from the deprecated NewClientConn function to Dialer.Dial.
  This change makes it easier to add proxy support to Dialer.Dial. Add
  comment noting that NewClientConn is deprecated.
- Update the code so that parseURL can be replaced with net/url Parse.
  We need to wait until we can require 1.5 before making the swap.
Gary Burd 10 年之前
父節點
當前提交
5ed2f4547d
共有 2 個文件被更改,包括 94 次插入93 次删除
  1. 91 86
      client.go
  2. 3 7
      client_server_test.go

+ 91 - 86
client.go

@@ -30,50 +30,17 @@ var ErrBadHandshake = errors.New("websocket: bad handshake")
 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
 // non-nil *http.Response so that callers can handle redirects, authentication,
 // etc.
+//
+// Deprecated: Use Dialer instead.
 func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
-	challengeKey, err := generateChallengeKey()
-	if err != nil {
-		return nil, nil, err
+	d := Dialer{
+		ReadBufferSize:  readBufSize,
+		WriteBufferSize: writeBufSize,
+		NetDial: func(net, addr string) (net.Conn, error) {
+			return netConn, nil
+		},
 	}
-	acceptKey := computeAcceptKey(challengeKey)
-
-	c = newConn(netConn, false, readBufSize, writeBufSize)
-	p := c.writeBuf[:0]
-	p = append(p, "GET "...)
-	p = append(p, u.RequestURI()...)
-	p = append(p, " HTTP/1.1\r\nHost: "...)
-	p = append(p, u.Host...)
-	// "Upgrade" is capitalized for servers that do not use case insensitive
-	// comparisons on header tokens.
-	p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...)
-	p = append(p, challengeKey...)
-	p = append(p, "\r\n"...)
-	for k, vs := range requestHeader {
-		for _, v := range vs {
-			p = append(p, k...)
-			p = append(p, ": "...)
-			p = append(p, v...)
-			p = append(p, "\r\n"...)
-		}
-	}
-	p = append(p, "\r\n"...)
-
-	if _, err := netConn.Write(p); err != nil {
-		return nil, nil, err
-	}
-
-	resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
-	if err != nil {
-		return nil, nil, err
-	}
-	if resp.StatusCode != 101 ||
-		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
-		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
-		resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
-		return nil, resp, ErrBadHandshake
-	}
-	c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
-	return c, resp, nil
+	return d.Dial(u.String(), requestHeader)
 }
 
 // A Dialer contains options for connecting to WebSocket server.
@@ -99,17 +66,15 @@ type Dialer struct {
 
 var errMalformedURL = errors.New("malformed ws or wss URL")
 
-// parseURL parses the URL. The url.Parse function is not used here because
-// url.Parse mangles the path.
+// parseURL parses the URL.
+//
+// This function is a replacement for the standard library url.Parse function.
+// In Go 1.4 and earlier, url.Parse loses information from the path.
 func parseURL(s string) (*url.URL, error) {
 	// From the RFC:
 	//
 	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
 	// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
-	//
-	// We don't use the net/url parser here because the dialer interface does
-	// not provide a way for applications to work around percent deocding in
-	// the net/url parser.
 
 	var u url.URL
 	switch {
@@ -131,7 +96,8 @@ func parseURL(s string) (*url.URL, error) {
 	}
 
 	if strings.Contains(u.Host, "@") {
-		// WebSocket URIs do not contain user information.
+		// Don't bother parsing user information because user information is
+		// not allowed in websocket URIs.
 		return nil, errMalformedURL
 	}
 
@@ -166,16 +132,67 @@ var DefaultDialer = &Dialer{}
 // etcetera. The response body may not contain the entire response and does not
 // need to be closed by the application.
 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
+
+	if d == nil {
+		d = &Dialer{}
+	}
+
+	challengeKey, err := generateChallengeKey()
+	if err != nil {
+		return nil, nil, err
+	}
+
 	u, err := parseURL(urlStr)
 	if err != nil {
 		return nil, nil, err
 	}
 
-	hostPort, hostNoPort := hostPortNoPort(u)
+	switch u.Scheme {
+	case "ws":
+		u.Scheme = "http"
+	case "wss":
+		u.Scheme = "https"
+	default:
+		return nil, nil, errMalformedURL
+	}
 
-	if d == nil {
-		d = &Dialer{}
+	if u.User != nil {
+		// User name and password are not allowed in websocket URIs.
+		return nil, nil, errMalformedURL
+	}
+
+	req := &http.Request{
+		Method:     "GET",
+		URL:        u,
+		Proto:      "HTTP/1.1",
+		ProtoMajor: 1,
+		ProtoMinor: 1,
+		Header:     make(http.Header),
+		Host:       u.Host,
+	}
+
+	// Set the request headers using the capitalization for names and values in
+	// RFC examples. Although the capitalization shouldn't matter, there are
+	// servers that depend on it. The Header.Set method is not used because the
+	// method canonicalizes the header names.
+	req.Header["Upgrade"] = []string{"websocket"}
+	req.Header["Connection"] = []string{"Upgrade"}
+	req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
+	req.Header["Sec-WebSocket-Version"] = []string{"13"}
+	if len(d.Subprotocols) > 0 {
+		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
 	}
+	for k, vs := range requestHeader {
+		if k == "Host" {
+			if len(vs) > 0 {
+				req.Host = vs[0]
+			}
+		} else {
+			req.Header[k] = vs
+		}
+	}
+
+	hostPort, hostNoPort := hostPortNoPort(u)
 
 	var deadline time.Time
 	if d.HandshakeTimeout != 0 {
@@ -203,7 +220,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		return nil, nil, err
 	}
 
-	if u.Scheme == "wss" {
+	if u.Scheme == "https" {
 		cfg := d.TLSClientConfig
 		if cfg == nil {
 			cfg = &tls.Config{ServerName: hostNoPort}
@@ -224,45 +241,33 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		}
 	}
 
-	if len(d.Subprotocols) > 0 {
-		h := http.Header{}
-		for k, v := range requestHeader {
-			h[k] = v
-		}
-		h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", "))
-		requestHeader = h
-	}
-
-	if len(requestHeader["Host"]) > 0 {
-		// This can be used to supply a Host: header which is different from
-		// the dial address.
-		u.Host = requestHeader.Get("Host")
+	conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
 
-		// Drop "Host" header
-		h := http.Header{}
-		for k, v := range requestHeader {
-			if k == "Host" {
-				continue
-			}
-			h[k] = v
-		}
-		requestHeader = h
+	if err := req.Write(netConn); err != nil {
+		return nil, nil, err
 	}
 
-	conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
-
+	resp, err := http.ReadResponse(conn.br, req)
 	if err != nil {
-		if err == ErrBadHandshake {
-			// Before closing the network connection on return from this
-			// function, slurp up some of the response to aid application
-			// debugging.
-			buf := make([]byte, 1024)
-			n, _ := io.ReadFull(resp.Body, buf)
-			resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
-		}
-		return nil, resp, err
+		return nil, nil, err
+	}
+	if resp.StatusCode != 101 ||
+		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
+		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
+		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
+		// Before closing the network connection on return from this
+		// function, slurp up some of the response to aid application
+		// debugging.
+		buf := make([]byte, 1024)
+		n, _ := io.ReadFull(resp.Body, buf)
+		resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
+		return nil, resp, ErrBadHandshake
+	} else {
+		resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
 	}
 
+	conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
+
 	netConn.SetDeadline(time.Time{})
 	netConn = nil // to avoid close in defer.
 	return conn, resp, nil

+ 3 - 7
client_server_test.go

@@ -289,8 +289,8 @@ func TestRespOnBadHandshake(t *testing.T) {
 	}
 }
 
-// If the Host header is specified in `Dial()`, the server must receive it as
-// the `Host:` header.
+// TestHostHeader confirms that the host header provided in the call to Dial is
+// sent to the server.
 func TestHostHeader(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()
@@ -305,16 +305,12 @@ func TestHostHeader(t *testing.T) {
 			origHandler.ServeHTTP(w, r)
 		})
 
-	ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
+	ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
 	if err != nil {
 		t.Fatalf("Dial: %v", err)
 	}
 	defer ws.Close()
 
-	if resp.StatusCode != http.StatusSwitchingProtocols {
-		t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode)
-	}
-
 	if gotHost := <-specifiedHost; gotHost != "testhost" {
 		t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
 	}