瀏覽代碼

websocket: use net.Dialer to open tcp connection

This change adds a Dialer field to websocket.Config struct. If this
value is set the Dialer will be used. If it's nil, DialConfig will
create an empty Dialer that will maintain the original behavior.

Because before Go 1.3 there was no crypto/tls.DialWithDialer function,
the Dialer will be ignored when opening TLS connections in these
versions.

Fixes golang/go#9198.

Change-Id: If8b5c3c47019a3d367c436e3e60eb54bf0276184
Reviewed-on: https://go-review.googlesource.com/12191
Reviewed-by: Russ Cox <rsc@golang.org>
Cezar Sa Espinola 10 年之前
父節點
當前提交
819f4c5391
共有 6 個文件被更改,包括 127 次插入11 次删除
  1. 4 11
      websocket/client.go
  2. 26 0
      websocket/dial.go
  3. 29 0
      websocket/dialdialer.go
  4. 45 0
      websocket/dialdialer_test.go
  5. 3 0
      websocket/websocket.go
  6. 20 0
      websocket/websocket_test.go

+ 4 - 11
websocket/client.go

@@ -6,7 +6,6 @@ package websocket
 
 import (
 	"bufio"
-	"crypto/tls"
 	"io"
 	"net"
 	"net/http"
@@ -87,20 +86,14 @@ func DialConfig(config *Config) (ws *Conn, err error) {
 	if config.Origin == nil {
 		return nil, &DialError{config, ErrBadWebSocketOrigin}
 	}
-	switch config.Location.Scheme {
-	case "ws":
-		client, err = net.Dial("tcp", parseAuthority(config.Location))
-
-	case "wss":
-		client, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig)
-
-	default:
-		err = ErrBadScheme
+	dialer := config.Dialer
+	if dialer == nil {
+		dialer = &net.Dialer{}
 	}
+	client, err = dialWithDialer(dialer, config)
 	if err != nil {
 		goto Error
 	}
-
 	ws, err = NewClient(config, client)
 	if err != nil {
 		client.Close()

+ 26 - 0
websocket/dial.go

@@ -0,0 +1,26 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.3
+
+package websocket
+
+import (
+	"crypto/tls"
+	"net"
+)
+
+func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
+	switch config.Location.Scheme {
+	case "ws":
+		conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
+
+	case "wss":
+		conn, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig)
+
+	default:
+		err = ErrBadScheme
+	}
+	return
+}

+ 29 - 0
websocket/dialdialer.go

@@ -0,0 +1,29 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.3
+
+// We only compile this with Go 1.3+ because previously tls.DialWithDialer
+// wasn't available. The dial.go file is used for previous Go versions.
+
+package websocket
+
+import (
+	"crypto/tls"
+	"net"
+)
+
+func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
+	switch config.Location.Scheme {
+	case "ws":
+		conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
+
+	case "wss":
+		conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
+
+	default:
+		err = ErrBadScheme
+	}
+	return
+}

+ 45 - 0
websocket/dialdialer_test.go

@@ -0,0 +1,45 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.3
+
+package websocket
+
+import (
+	"crypto/tls"
+	"fmt"
+	"log"
+	"net"
+	"net/http/httptest"
+	"testing"
+	"time"
+)
+
+// This test depend on Go 1.3+ because in earlier versions the Dialer won't be
+// used in TLS connections and a timeout won't be triggered.
+func TestDialConfigTLSWithDialer(t *testing.T) {
+	tlsServer := httptest.NewTLSServer(nil)
+	tlsServerAddr := tlsServer.Listener.Addr().String()
+	log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
+	defer tlsServer.Close()
+	config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
+	config.Dialer = &net.Dialer{
+		Deadline: time.Now().Add(-time.Minute),
+	}
+	config.TlsConfig = &tls.Config{
+		InsecureSkipVerify: true,
+	}
+	_, err := DialConfig(config)
+	dialerr, ok := err.(*DialError)
+	if !ok {
+		t.Fatalf("DialError expected, got %#v", err)
+	}
+	neterr, ok := dialerr.Err.(*net.OpError)
+	if !ok {
+		t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
+	}
+	if !neterr.Timeout() {
+		t.Fatalf("expected timeout error, got %#v", neterr)
+	}
+}

+ 3 - 0
websocket/websocket.go

@@ -86,6 +86,9 @@ type Config struct {
 	// Additional header fields to be sent in WebSocket opening handshake.
 	Header http.Header
 
+	// Dialer used when opening websocket connections.
+	Dialer *net.Dialer
+
 	handshakeData map[string]string
 }
 

+ 20 - 0
websocket/websocket_test.go

@@ -357,6 +357,26 @@ func TestDialConfigBadVersion(t *testing.T) {
 	}
 }
 
+func TestDialConfigWithDialer(t *testing.T) {
+	once.Do(startServer)
+	config := newConfig(t, "/echo")
+	config.Dialer = &net.Dialer{
+		Deadline: time.Now().Add(-time.Minute),
+	}
+	_, err := DialConfig(config)
+	dialerr, ok := err.(*DialError)
+	if !ok {
+		t.Fatalf("DialError expected, got %#v", err)
+	}
+	neterr, ok := dialerr.Err.(*net.OpError)
+	if !ok {
+		t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
+	}
+	if !neterr.Timeout() {
+		t.Fatalf("expected timeout error, got %#v", neterr)
+	}
+}
+
 func TestSmallBuffer(t *testing.T) {
 	// http://code.google.com/p/go/issues/detail?id=1145
 	// Read should be able to handle reading a fragment of a frame.