Gary Burd 12 лет назад
Родитель
Сommit
87accaef66
2 измененных файлов с 294 добавлено и 17 удалено
  1. 152 1
      client.go
  2. 142 16
      client_server_test.go

+ 152 - 1
client.go

@@ -5,11 +5,13 @@
 package websocket
 
 import (
+	"crypto/tls"
 	"errors"
 	"net"
 	"net/http"
 	"net/url"
 	"strings"
+	"time"
 )
 
 // ErrBadHandshake is returned when the server response to opening handshake is
@@ -18,7 +20,7 @@ var ErrBadHandshake = errors.New("websocket: bad handshake")
 
 // NewClient creates a new client connection using the given net connection.
 // The URL u specifies the host and request URI. Use requestHeader to specify
-// the origin (Origin), subprotocols (Set-WebSocket-Protocol) and cookies
+// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
 // (Cookie). Use the response.Header to get the selected subprotocol
 // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
 //
@@ -67,3 +69,152 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
 	}
 	return c, resp, nil
 }
+
+type Dialer struct {
+	// NetDial specifies the dial function for creating TCP connections. If
+	// NetDial is nil, net.Dial is used.
+	NetDial func(network, addr string) (net.Conn, error)
+
+	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
+	// If nil, the default configuration is used.
+	TLSClientConfig *tls.Config
+
+	// HandshakeTimeout specifies the duration for the handshake to complete.
+	HandshakeTimeout time.Duration
+
+	// Input and output buffer sizes. If the buffer size is zero, then a
+	// default value of 4096 is used.
+	ReadBufferSize, WriteBufferSize int
+}
+
+var errMalformedURL = errors.New("malformed ws or wss URL")
+
+func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
+	// From the RFC:
+	//
+	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
+	// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
+	//
+	// We don't use the net/url parser here because the dialer interface does
+	// not provide a way for applications to work around percent deocding in
+	// the net/url parser.
+
+	switch {
+	case strings.HasPrefix(u, "ws://"):
+		u = u[len("ws://"):]
+	case strings.HasPrefix(u, "wss://"):
+		u = u[len("wss://"):]
+		useTLS = true
+	default:
+		return false, "", "", "", errMalformedURL
+	}
+
+	hostPort := u
+	opaque = "/"
+	if i := strings.Index(u, "/"); i >= 0 {
+		hostPort = u[:i]
+		opaque = u[i:]
+	}
+
+	host = hostPort
+	port = ":80"
+	if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") {
+		host = hostPort[:i]
+		port = hostPort[i:]
+	} else if useTLS {
+		port = ":443"
+	}
+
+	return useTLS, host, port, opaque, nil
+}
+
+var DefaultDialer *Dialer
+
+// Dial creates a new client connection. Use requestHeader to specify the
+// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
+// Use the response.Header to get the selected subprotocol
+// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
+//
+// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
+// non-nil *http.Response so that callers can handle redirects, authentication,
+// etc.
+func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
+
+	useTLS, host, port, opaque, err := parseURL(urlStr)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	if d == nil {
+		d = &Dialer{}
+	}
+
+	var deadline time.Time
+	if d.HandshakeTimeout != 0 {
+		deadline = time.Now().Add(d.HandshakeTimeout)
+	}
+
+	netDial := d.NetDial
+	if netDial == nil {
+		netDialer := &net.Dialer{Deadline: deadline}
+		netDial = netDialer.Dial
+	}
+
+	netConn, err := netDial("tcp", host+port)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	defer func() {
+		if netConn != nil {
+			netConn.Close()
+		}
+	}()
+
+	if err := netConn.SetDeadline(deadline); err != nil {
+		return nil, nil, err
+	}
+
+	if useTLS {
+		cfg := d.TLSClientConfig
+		if cfg == nil {
+			cfg = &tls.Config{ServerName: host}
+		} else if cfg.ServerName == "" {
+			shallowCopy := *cfg
+			cfg = &shallowCopy
+			cfg.ServerName = host
+		}
+		tlsConn := tls.Client(netConn, cfg)
+		netConn = tlsConn
+		if err := tlsConn.Handshake(); err != nil {
+			return nil, nil, err
+		}
+		if !cfg.InsecureSkipVerify {
+			if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+				return nil, nil, err
+			}
+		}
+	}
+
+	readBufferSize := d.ReadBufferSize
+	if readBufferSize == 0 {
+		readBufferSize = 4096
+	}
+
+	writeBufferSize := d.WriteBufferSize
+	if writeBufferSize == 0 {
+		writeBufferSize = 4096
+	}
+
+	conn, resp, err := NewClient(
+		netConn,
+		&url.URL{Host: host + port, Opaque: opaque},
+		requestHeader, readBufferSize, writeBufferSize)
+	if err != nil {
+		return nil, resp, err
+	}
+
+	netConn.SetDeadline(time.Time{})
+	netConn = nil // to avoid close in defer.
+	return conn, resp, nil
+}

+ 142 - 16
client_server_test.go

@@ -2,9 +2,11 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package websocket_test
+package websocket
 
 import (
+	"crypto/tls"
+	"crypto/x509"
 	"io"
 	"io/ioutil"
 	"net"
@@ -13,15 +15,13 @@ import (
 	"net/url"
 	"testing"
 	"time"
-
-	"github.com/gorilla/websocket"
 )
 
-type wsHandler struct {
+type handshakeHandler struct {
 	*testing.T
 }
 
-func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func (t handshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	if r.Method != "GET" {
 		http.Error(w, "Method not allowed", 405)
 		t.Logf("bad method: %s", r.Method)
@@ -32,8 +32,8 @@ func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		t.Logf("bad origin: %s", r.Header.Get("Origin"))
 		return
 	}
-	ws, err := websocket.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}, 1024, 1024)
-	if _, ok := err.(websocket.HandshakeError); ok {
+	ws, err := Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}, 1024, 1024)
+	if _, ok := err.(HandshakeError); ok {
 		t.Logf("bad handshake: %v", err)
 		http.Error(w, "Not a websocket handshake", 400)
 		return
@@ -50,9 +50,6 @@ func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			}
 			return
 		}
-		if op == websocket.PongMessage {
-			continue
-		}
 		w, err := ws.NextWriter(op)
 		if err != nil {
 			t.Logf("NextWriter: %v", err)
@@ -69,15 +66,15 @@ func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func TestClientServer(t *testing.T) {
-	s := httptest.NewServer(wsHandler{t})
+func TestHandshake(t *testing.T) {
+	s := httptest.NewServer(handshakeHandler{t})
 	defer s.Close()
 	u, _ := url.Parse(s.URL)
 	c, err := net.Dial("tcp", u.Host)
 	if err != nil {
 		t.Fatalf("Dial: %v", err)
 	}
-	ws, resp, err := websocket.NewClient(c, u, http.Header{"Origin": {s.URL}}, 1024, 1024)
+	ws, resp, err := NewClient(c, u, http.Header{"Origin": {s.URL}}, 1024, 1024)
 	if err != nil {
 		t.Fatalf("NewClient: %v", err)
 	}
@@ -93,7 +90,7 @@ func TestClientServer(t *testing.T) {
 		t.Error("Set-Cookie not received from the server.")
 	}
 
-	w, _ := ws.NextWriter(websocket.TextMessage)
+	w, _ := ws.NextWriter(TextMessage)
 	io.WriteString(w, "HELLO")
 	w.Close()
 	ws.SetReadDeadline(time.Now().Add(1 * time.Second))
@@ -101,8 +98,8 @@ func TestClientServer(t *testing.T) {
 	if err != nil {
 		t.Fatalf("NextReader: %v", err)
 	}
-	if op != websocket.TextMessage {
-		t.Fatalf("op=%d, want %d", op, websocket.TextMessage)
+	if op != TextMessage {
+		t.Fatalf("op=%d, want %d", op, TextMessage)
 	}
 	b, err := ioutil.ReadAll(r)
 	if err != nil {
@@ -112,3 +109,132 @@ func TestClientServer(t *testing.T) {
 		t.Fatalf("message=%s, want %s", b, "HELLO")
 	}
 }
+
+type dialHandler struct {
+	*testing.T
+}
+
+func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ws, err := Upgrade(w, r, nil, 1024, 1024)
+	if _, ok := err.(HandshakeError); ok {
+		t.Logf("bad handshake: %v", err)
+		http.Error(w, "Not a websocket handshake", 400)
+		return
+	} else if err != nil {
+		t.Logf("upgrade error: %v", err)
+		return
+	}
+	defer ws.Close()
+	for {
+		mt, p, err := ws.ReadMessage()
+		if err != nil {
+			if err != io.EOF {
+				t.Logf("ReadMessage: %v", err)
+			}
+			return
+		}
+		if err := ws.WriteMessage(mt, p); err != nil {
+			t.Logf("WriteMessage: %v", err)
+			return
+		}
+	}
+}
+
+func sendRecv(t *testing.T, ws *Conn) {
+	const message = "Hello World!"
+	if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
+		t.Fatalf("WriteMessage: %v", err)
+	}
+	_, p, err := ws.ReadMessage()
+	if err != nil {
+		t.Fatalf("ReadMessage: %v", err)
+	}
+	if string(p) != message {
+		t.Fatalf("message=%s, want %s", p, message)
+	}
+}
+
+func httpToWs(u string) string {
+	return "ws" + u[len("http"):]
+}
+
+func TestDial(t *testing.T) {
+	s := httptest.NewServer(dialHandler{t})
+	defer s.Close()
+	ws, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
+	if err != nil {
+		t.Fatalf("Dial() returned error %v", err)
+	}
+	defer ws.Close()
+	sendRecv(t, ws)
+}
+
+func TestDialTLS(t *testing.T) {
+	s := httptest.NewTLSServer(dialHandler{t})
+	defer s.Close()
+
+	certs := x509.NewCertPool()
+	for _, c := range s.TLS.Certificates {
+		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+		if err != nil {
+			t.Fatalf("error parsing server's root cert: %v", err)
+		}
+		for _, root := range roots {
+			certs.AddCert(root)
+		}
+	}
+
+	u, _ := url.Parse(s.URL)
+	d := &Dialer{
+		NetDial:         func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) },
+		TLSClientConfig: &tls.Config{RootCAs: certs},
+	}
+	ws, _, err := d.Dial("wss://example.com/", nil)
+	if err != nil {
+		t.Fatalf("Dial() returned error %v", err)
+	}
+	defer ws.Close()
+	sendRecv(t, ws)
+}
+
+func TestDialTLSBadCert(t *testing.T) {
+	s := httptest.NewTLSServer(dialHandler{t})
+	defer s.Close()
+	_, _, err := DefaultDialer.Dial(httpToWs(s.URL), nil)
+	if err == nil {
+		t.Fatalf("Dial() did not return error")
+	}
+}
+
+func TestDialTLSNoVerify(t *testing.T) {
+	s := httptest.NewTLSServer(dialHandler{t})
+	defer s.Close()
+	d := &Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
+	ws, _, err := d.Dial(httpToWs(s.URL), nil)
+	if err != nil {
+		t.Fatalf("Dial() returned error %v", err)
+	}
+	defer ws.Close()
+	sendRecv(t, ws)
+}
+
+func TestDialTimeout(t *testing.T) {
+	s := httptest.NewServer(dialHandler{t})
+	defer s.Close()
+	d := &Dialer{
+		HandshakeTimeout: -1,
+	}
+	_, _, err := d.Dial(httpToWs(s.URL), nil)
+	if err == nil {
+		t.Fatalf("Dial() did not return error")
+	}
+}
+
+func TestDialBadScheme(t *testing.T) {
+	s := httptest.NewServer(dialHandler{t})
+	defer s.Close()
+	_, _, err := DefaultDialer.Dial(s.URL, nil)
+	if err == nil {
+		t.Fatalf("Dial() did not return error")
+	}
+}