|
|
@@ -96,7 +96,9 @@ type Dialer struct {
|
|
|
|
|
|
var errMalformedURL = errors.New("malformed ws or wss URL")
|
|
|
|
|
|
-func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
|
|
|
+// parseURL parses the URL. The url.Parse function is not used here because
|
|
|
+// url.Parse mangles the path.
|
|
|
+func parseURL(s string) (*url.URL, error) {
|
|
|
// From the RFC:
|
|
|
//
|
|
|
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
|
|
@@ -106,33 +108,41 @@ func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
|
|
|
// not provide a way for applications to work around percent deocding in
|
|
|
// the net/url parser.
|
|
|
|
|
|
+ var u url.URL
|
|
|
switch {
|
|
|
- case strings.HasPrefix(u, "ws://"):
|
|
|
- u = u[len("ws://"):]
|
|
|
- case strings.HasPrefix(u, "wss://"):
|
|
|
- u = u[len("wss://"):]
|
|
|
- useTLS = true
|
|
|
+ case strings.HasPrefix(s, "ws://"):
|
|
|
+ u.Scheme = "ws"
|
|
|
+ s = s[len("ws://"):]
|
|
|
+ case strings.HasPrefix(s, "wss://"):
|
|
|
+ u.Scheme = "wss"
|
|
|
+ s = s[len("wss://"):]
|
|
|
default:
|
|
|
- return false, "", "", "", errMalformedURL
|
|
|
+ return nil, errMalformedURL
|
|
|
}
|
|
|
|
|
|
- hostPort := u
|
|
|
- opaque = "/"
|
|
|
- if i := strings.Index(u, "/"); i >= 0 {
|
|
|
- hostPort = u[:i]
|
|
|
- opaque = u[i:]
|
|
|
+ u.Host = s
|
|
|
+ u.Opaque = "/"
|
|
|
+ if i := strings.Index(s, "/"); i >= 0 {
|
|
|
+ u.Host = s[:i]
|
|
|
+ u.Opaque = s[i:]
|
|
|
}
|
|
|
|
|
|
- host = hostPort
|
|
|
- port = ":80"
|
|
|
- if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") {
|
|
|
- host = hostPort[:i]
|
|
|
- port = hostPort[i:]
|
|
|
- } else if useTLS {
|
|
|
- port = ":443"
|
|
|
- }
|
|
|
+ return &u, nil
|
|
|
+}
|
|
|
|
|
|
- return useTLS, host, port, opaque, nil
|
|
|
+func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
|
|
|
+ hostPort = u.Host
|
|
|
+ hostNoPort = u.Host
|
|
|
+ if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
|
|
|
+ hostNoPort = hostNoPort[:i]
|
|
|
+ } else {
|
|
|
+ if u.Scheme == "wss" {
|
|
|
+ hostPort += ":443"
|
|
|
+ } else {
|
|
|
+ hostPort += ":80"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return hostPort, hostNoPort
|
|
|
}
|
|
|
|
|
|
// DefaultDialer is a dialer with all fields set to the default zero values.
|
|
|
@@ -147,12 +157,13 @@ var DefaultDialer *Dialer
|
|
|
// non-nil *http.Response so that callers can handle redirects, authentication,
|
|
|
// etc.
|
|
|
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
|
|
-
|
|
|
- useTLS, host, port, opaque, err := parseURL(urlStr)
|
|
|
+ u, err := parseURL(urlStr)
|
|
|
if err != nil {
|
|
|
return nil, nil, err
|
|
|
}
|
|
|
|
|
|
+ hostPort, hostNoPort := hostPortNoPort(u)
|
|
|
+
|
|
|
if d == nil {
|
|
|
d = &Dialer{}
|
|
|
}
|
|
|
@@ -168,7 +179,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|
|
netDial = netDialer.Dial
|
|
|
}
|
|
|
|
|
|
- netConn, err := netDial("tcp", host+port)
|
|
|
+ netConn, err := netDial("tcp", hostPort)
|
|
|
if err != nil {
|
|
|
return nil, nil, err
|
|
|
}
|
|
|
@@ -183,14 +194,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|
|
return nil, nil, err
|
|
|
}
|
|
|
|
|
|
- if useTLS {
|
|
|
+ if u.Scheme == "wss" {
|
|
|
cfg := d.TLSClientConfig
|
|
|
if cfg == nil {
|
|
|
- cfg = &tls.Config{ServerName: host}
|
|
|
+ cfg = &tls.Config{ServerName: hostNoPort}
|
|
|
} else if cfg.ServerName == "" {
|
|
|
shallowCopy := *cfg
|
|
|
cfg = &shallowCopy
|
|
|
- cfg.ServerName = host
|
|
|
+ cfg.ServerName = hostNoPort
|
|
|
}
|
|
|
tlsConn := tls.Client(netConn, cfg)
|
|
|
netConn = tlsConn
|
|
|
@@ -223,10 +234,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|
|
requestHeader = h
|
|
|
}
|
|
|
|
|
|
- conn, resp, err := NewClient(
|
|
|
- netConn,
|
|
|
- &url.URL{Host: host + port, Opaque: opaque},
|
|
|
- requestHeader, readBufferSize, writeBufferSize)
|
|
|
+ conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize)
|
|
|
if err != nil {
|
|
|
return nil, resp, err
|
|
|
}
|