فهرست منبع

Improve client host header handling.

- Set request host header to substring of the URL. Do not add default
  port to string.
- Do not include port when verifying TLS host name.
Gary Burd 11 سال پیش
والد
کامیت
db7a2a1679
2فایلهای تغییر یافته به همراه102 افزوده شده و 31 حذف شده
  1. 39 31
      client.go
  2. 63 0
      client_test.go

+ 39 - 31
client.go

@@ -96,7 +96,9 @@ type Dialer struct {
 
 var errMalformedURL = errors.New("malformed ws or wss URL")
 
-func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
+// parseURL parses the URL. The url.Parse function is not used here because
+// url.Parse mangles the path.
+func parseURL(s string) (*url.URL, error) {
 	// From the RFC:
 	//
 	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
@@ -106,33 +108,41 @@ func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
 	// not provide a way for applications to work around percent deocding in
 	// the net/url parser.
 
+	var u url.URL
 	switch {
-	case strings.HasPrefix(u, "ws://"):
-		u = u[len("ws://"):]
-	case strings.HasPrefix(u, "wss://"):
-		u = u[len("wss://"):]
-		useTLS = true
+	case strings.HasPrefix(s, "ws://"):
+		u.Scheme = "ws"
+		s = s[len("ws://"):]
+	case strings.HasPrefix(s, "wss://"):
+		u.Scheme = "wss"
+		s = s[len("wss://"):]
 	default:
-		return false, "", "", "", errMalformedURL
+		return nil, errMalformedURL
 	}
 
-	hostPort := u
-	opaque = "/"
-	if i := strings.Index(u, "/"); i >= 0 {
-		hostPort = u[:i]
-		opaque = u[i:]
+	u.Host = s
+	u.Opaque = "/"
+	if i := strings.Index(s, "/"); i >= 0 {
+		u.Host = s[:i]
+		u.Opaque = s[i:]
 	}
 
-	host = hostPort
-	port = ":80"
-	if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") {
-		host = hostPort[:i]
-		port = hostPort[i:]
-	} else if useTLS {
-		port = ":443"
-	}
+	return &u, nil
+}
 
-	return useTLS, host, port, opaque, nil
+func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
+	hostPort = u.Host
+	hostNoPort = u.Host
+	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
+		hostNoPort = hostNoPort[:i]
+	} else {
+		if u.Scheme == "wss" {
+			hostPort += ":443"
+		} else {
+			hostPort += ":80"
+		}
+	}
+	return hostPort, hostNoPort
 }
 
 // DefaultDialer is a dialer with all fields set to the default zero values.
@@ -147,12 +157,13 @@ var DefaultDialer *Dialer
 // non-nil *http.Response so that callers can handle redirects, authentication,
 // etc.
 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
-
-	useTLS, host, port, opaque, err := parseURL(urlStr)
+	u, err := parseURL(urlStr)
 	if err != nil {
 		return nil, nil, err
 	}
 
+	hostPort, hostNoPort := hostPortNoPort(u)
+
 	if d == nil {
 		d = &Dialer{}
 	}
@@ -168,7 +179,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		netDial = netDialer.Dial
 	}
 
-	netConn, err := netDial("tcp", host+port)
+	netConn, err := netDial("tcp", hostPort)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -183,14 +194,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		return nil, nil, err
 	}
 
-	if useTLS {
+	if u.Scheme == "wss" {
 		cfg := d.TLSClientConfig
 		if cfg == nil {
-			cfg = &tls.Config{ServerName: host}
+			cfg = &tls.Config{ServerName: hostNoPort}
 		} else if cfg.ServerName == "" {
 			shallowCopy := *cfg
 			cfg = &shallowCopy
-			cfg.ServerName = host
+			cfg.ServerName = hostNoPort
 		}
 		tlsConn := tls.Client(netConn, cfg)
 		netConn = tlsConn
@@ -223,10 +234,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		requestHeader = h
 	}
 
-	conn, resp, err := NewClient(
-		netConn,
-		&url.URL{Host: host + port, Opaque: opaque},
-		requestHeader, readBufferSize, writeBufferSize)
+	conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize)
 	if err != nil {
 		return nil, resp, err
 	}

+ 63 - 0
client_test.go

@@ -0,0 +1,63 @@
+// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+	"net/url"
+	"reflect"
+	"testing"
+)
+
+var parseURLTests = []struct {
+	s string
+	u *url.URL
+}{
+	{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
+	{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
+	{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}},
+	{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
+	{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
+	{"ss://example.com/a/b", nil},
+}
+
+func TestParseURL(t *testing.T) {
+	for _, tt := range parseURLTests {
+		u, err := parseURL(tt.s)
+		if tt.u != nil && err != nil {
+			t.Errorf("parseURL(%q) returned error %v", tt.s, err)
+			continue
+		}
+		if tt.u == nil && err == nil {
+			t.Errorf("parseURL(%q) did not return error", tt.s)
+			continue
+		}
+		if !reflect.DeepEqual(u, tt.u) {
+			t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u)
+			continue
+		}
+	}
+}
+
+var hostPortNoPortTests = []struct {
+	u                    *url.URL
+	hostPort, hostNoPort string
+}{
+	{&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
+	{&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
+	{&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
+	{&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
+}
+
+func TestHostPortNoPort(t *testing.T) {
+	for _, tt := range hostPortNoPortTests {
+		hostPort, hostNoPort := hostPortNoPort(tt.u)
+		if hostPort != tt.hostPort {
+			t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
+		}
+		if hostNoPort != tt.hostNoPort {
+			t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
+		}
+	}
+}