Selaa lähdekoodia

Return response body on bad handshake.

The Dialer.Dial method returns an *http.Response on a bad handshake.
This CL updates the Dial method to include up to 1024 bytes of the
response body in the returned *http.Response. Applications may find the
response body helpful when debugging bad handshakes.

Fixes issue #62.
Gary Burd 10 vuotta sitten
vanhempi
commit
b2fa8f6d58
2 muutettua tiedostoa jossa 62 lisäystä ja 8 poistoa
  1. 14 1
      client.go
  2. 48 7
      client_server_test.go

+ 14 - 1
client.go

@@ -5,8 +5,11 @@
 package websocket
 
 import (
+	"bytes"
 	"crypto/tls"
 	"errors"
+	"io"
+	"io/ioutil"
 	"net"
 	"net/http"
 	"net/url"
@@ -155,7 +158,8 @@ var DefaultDialer *Dialer
 //
 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
 // non-nil *http.Response so that callers can handle redirects, authentication,
-// etc.
+// etcetera. The response body may not contain the entire response and does not
+// need to be closed by the application.
 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
 	u, err := parseURL(urlStr)
 	if err != nil {
@@ -225,7 +229,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 	}
 
 	conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
+
 	if err != nil {
+		if err == ErrBadHandshake {
+			// Before closing the network connection on return from this
+			// function, slurp up some of the response to aid application
+			// debugging.
+			buf := make([]byte, 1024)
+			n, _ := io.ReadFull(resp.Body, buf)
+			resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
+		}
 		return nil, resp, err
 	}
 

+ 48 - 7
client_server_test.go

@@ -8,11 +8,13 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"io"
+	"io/ioutil"
 	"net"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
 	"reflect"
+	"strings"
 	"testing"
 	"time"
 )
@@ -34,22 +36,22 @@ var cstDialer = Dialer{
 
 type cstHandler struct{ *testing.T }
 
-type Server struct {
+type cstServer struct {
 	*httptest.Server
 	URL string
 }
 
-func newServer(t *testing.T) *Server {
-	var s Server
+func newServer(t *testing.T) *cstServer {
+	var s cstServer
 	s.Server = httptest.NewServer(cstHandler{t})
-	s.URL = "ws" + s.Server.URL[len("http"):]
+	s.URL = makeWsProto(s.Server.URL)
 	return &s
 }
 
-func newTLSServer(t *testing.T) *Server {
-	var s Server
+func newTLSServer(t *testing.T) *cstServer {
+	var s cstServer
 	s.Server = httptest.NewTLSServer(cstHandler{t})
-	s.URL = "ws" + s.Server.URL[len("http"):]
+	s.URL = makeWsProto(s.Server.URL)
 	return &s
 }
 
@@ -97,6 +99,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+func makeWsProto(s string) string {
+	return "ws" + strings.TrimPrefix(s, "http")
+}
+
 func sendRecv(t *testing.T, ws *Conn) {
 	const message = "Hello World!"
 	if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
@@ -157,6 +163,7 @@ func TestDialTLS(t *testing.T) {
 }
 
 func xTestDialTLSBadCert(t *testing.T) {
+	// This test is deactivated because of noisy logging from the net/http package.
 	s := newTLSServer(t)
 	defer s.Close()
 
@@ -247,3 +254,37 @@ func TestHandshake(t *testing.T) {
 	}
 	sendRecv(t, ws)
 }
+
+func TestRespOnBadHandshake(t *testing.T) {
+	const expectedStatus = http.StatusGone
+	const expectedBody = "This is the response body."
+
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(expectedStatus)
+		io.WriteString(w, expectedBody)
+	}))
+	defer s.Close()
+
+	ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
+	if err == nil {
+		ws.Close()
+		t.Fatalf("Dial: nil")
+	}
+
+	if resp == nil {
+		t.Fatalf("resp=nil, err=%v", err)
+	}
+
+	if resp.StatusCode != expectedStatus {
+		t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
+	}
+
+	p, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("ReadFull(resp.Body) returned error %v", err)
+	}
+
+	if string(p) != expectedBody {
+		t.Errorf("resp.Body=%s, want %s", p, expectedBody)
+	}
+}