فهرست منبع

websocket: use net.Dialer to open tcp connection

This change adds a Dialer field to websocket.Config struct. If this
value is set the Dialer will be used. If it's nil, DialConfig will
create an empty Dialer that will maintain the original behavior.

Because before Go 1.3 there was no crypto/tls.DialWithDialer function,
the Dialer will be ignored when opening TLS connections in these
versions.

Fixes golang/go#9198.

Change-Id: If8b5c3c47019a3d367c436e3e60eb54bf0276184
Reviewed-on: https://go-review.googlesource.com/12191
Reviewed-by: Russ Cox <rsc@golang.org>
Cezar Sa Espinola 10 سال پیش
والد
کامیت
819f4c5391
6فایلهای تغییر یافته به همراه127 افزوده شده و 11 حذف شده
  1. 4 11
      websocket/client.go
  2. 26 0
      websocket/dial.go
  3. 29 0
      websocket/dialdialer.go
  4. 45 0
      websocket/dialdialer_test.go
  5. 3 0
      websocket/websocket.go
  6. 20 0
      websocket/websocket_test.go

+ 4 - 11
websocket/client.go

@@ -6,7 +6,6 @@ package websocket
 
 import (
 	"bufio"
-	"crypto/tls"
 	"io"
 	"net"
 	"net/http"
@@ -87,20 +86,14 @@ func DialConfig(config *Config) (ws *Conn, err error) {
 	if config.Origin == nil {
 		return nil, &DialError{config, ErrBadWebSocketOrigin}
 	}
-	switch config.Location.Scheme {
-	case "ws":
-		client, err = net.Dial("tcp", parseAuthority(config.Location))
-
-	case "wss":
-		client, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig)
-
-	default:
-		err = ErrBadScheme
+	dialer := config.Dialer
+	if dialer == nil {
+		dialer = &net.Dialer{}
 	}
+	client, err = dialWithDialer(dialer, config)
 	if err != nil {
 		goto Error
 	}
-
 	ws, err = NewClient(config, client)
 	if err != nil {
 		client.Close()

+ 26 - 0
websocket/dial.go

@@ -0,0 +1,26 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.3
+
+package websocket
+
+import (
+	"crypto/tls"
+	"net"
+)
+
+func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
+	switch config.Location.Scheme {
+	case "ws":
+		conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
+
+	case "wss":
+		conn, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig)
+
+	default:
+		err = ErrBadScheme
+	}
+	return
+}

+ 29 - 0
websocket/dialdialer.go

@@ -0,0 +1,29 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.3
+
+// We only compile this with Go 1.3+ because previously tls.DialWithDialer
+// wasn't available. The dial.go file is used for previous Go versions.
+
+package websocket
+
+import (
+	"crypto/tls"
+	"net"
+)
+
+func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
+	switch config.Location.Scheme {
+	case "ws":
+		conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
+
+	case "wss":
+		conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
+
+	default:
+		err = ErrBadScheme
+	}
+	return
+}

+ 45 - 0
websocket/dialdialer_test.go

@@ -0,0 +1,45 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.3
+
+package websocket
+
+import (
+	"crypto/tls"
+	"fmt"
+	"log"
+	"net"
+	"net/http/httptest"
+	"testing"
+	"time"
+)
+
+// This test depend on Go 1.3+ because in earlier versions the Dialer won't be
+// used in TLS connections and a timeout won't be triggered.
+func TestDialConfigTLSWithDialer(t *testing.T) {
+	tlsServer := httptest.NewTLSServer(nil)
+	tlsServerAddr := tlsServer.Listener.Addr().String()
+	log.Print("Test TLS WebSocket server listening on ", tlsServerAddr)
+	defer tlsServer.Close()
+	config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost")
+	config.Dialer = &net.Dialer{
+		Deadline: time.Now().Add(-time.Minute),
+	}
+	config.TlsConfig = &tls.Config{
+		InsecureSkipVerify: true,
+	}
+	_, err := DialConfig(config)
+	dialerr, ok := err.(*DialError)
+	if !ok {
+		t.Fatalf("DialError expected, got %#v", err)
+	}
+	neterr, ok := dialerr.Err.(*net.OpError)
+	if !ok {
+		t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
+	}
+	if !neterr.Timeout() {
+		t.Fatalf("expected timeout error, got %#v", neterr)
+	}
+}

+ 3 - 0
websocket/websocket.go

@@ -86,6 +86,9 @@ type Config struct {
 	// Additional header fields to be sent in WebSocket opening handshake.
 	Header http.Header
 
+	// Dialer used when opening websocket connections.
+	Dialer *net.Dialer
+
 	handshakeData map[string]string
 }
 

+ 20 - 0
websocket/websocket_test.go

@@ -357,6 +357,26 @@ func TestDialConfigBadVersion(t *testing.T) {
 	}
 }
 
+func TestDialConfigWithDialer(t *testing.T) {
+	once.Do(startServer)
+	config := newConfig(t, "/echo")
+	config.Dialer = &net.Dialer{
+		Deadline: time.Now().Add(-time.Minute),
+	}
+	_, err := DialConfig(config)
+	dialerr, ok := err.(*DialError)
+	if !ok {
+		t.Fatalf("DialError expected, got %#v", err)
+	}
+	neterr, ok := dialerr.Err.(*net.OpError)
+	if !ok {
+		t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
+	}
+	if !neterr.Timeout() {
+		t.Fatalf("expected timeout error, got %#v", neterr)
+	}
+}
+
 func TestSmallBuffer(t *testing.T) {
 	// http://code.google.com/p/go/issues/detail?id=1145
 	// Read should be able to handle reading a fragment of a frame.