|
@@ -8,11 +8,13 @@ import (
|
|
|
"crypto/tls"
|
|
"crypto/tls"
|
|
|
"crypto/x509"
|
|
"crypto/x509"
|
|
|
"io"
|
|
"io"
|
|
|
|
|
+ "io/ioutil"
|
|
|
"net"
|
|
"net"
|
|
|
"net/http"
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
"net/http/httptest"
|
|
|
"net/url"
|
|
"net/url"
|
|
|
"reflect"
|
|
"reflect"
|
|
|
|
|
+ "strings"
|
|
|
"testing"
|
|
"testing"
|
|
|
"time"
|
|
"time"
|
|
|
)
|
|
)
|
|
@@ -34,22 +36,22 @@ var cstDialer = Dialer{
|
|
|
|
|
|
|
|
type cstHandler struct{ *testing.T }
|
|
type cstHandler struct{ *testing.T }
|
|
|
|
|
|
|
|
-type Server struct {
|
|
|
|
|
|
|
+type cstServer struct {
|
|
|
*httptest.Server
|
|
*httptest.Server
|
|
|
URL string
|
|
URL string
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func newServer(t *testing.T) *Server {
|
|
|
|
|
- var s Server
|
|
|
|
|
|
|
+func newServer(t *testing.T) *cstServer {
|
|
|
|
|
+ var s cstServer
|
|
|
s.Server = httptest.NewServer(cstHandler{t})
|
|
s.Server = httptest.NewServer(cstHandler{t})
|
|
|
- s.URL = "ws" + s.Server.URL[len("http"):]
|
|
|
|
|
|
|
+ s.URL = makeWsProto(s.Server.URL)
|
|
|
return &s
|
|
return &s
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func newTLSServer(t *testing.T) *Server {
|
|
|
|
|
- var s Server
|
|
|
|
|
|
|
+func newTLSServer(t *testing.T) *cstServer {
|
|
|
|
|
+ var s cstServer
|
|
|
s.Server = httptest.NewTLSServer(cstHandler{t})
|
|
s.Server = httptest.NewTLSServer(cstHandler{t})
|
|
|
- s.URL = "ws" + s.Server.URL[len("http"):]
|
|
|
|
|
|
|
+ s.URL = makeWsProto(s.Server.URL)
|
|
|
return &s
|
|
return &s
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -97,6 +99,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func makeWsProto(s string) string {
|
|
|
|
|
+ return "ws" + strings.TrimPrefix(s, "http")
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func sendRecv(t *testing.T, ws *Conn) {
|
|
func sendRecv(t *testing.T, ws *Conn) {
|
|
|
const message = "Hello World!"
|
|
const message = "Hello World!"
|
|
|
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
|
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
|
@@ -157,6 +163,7 @@ func TestDialTLS(t *testing.T) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func xTestDialTLSBadCert(t *testing.T) {
|
|
func xTestDialTLSBadCert(t *testing.T) {
|
|
|
|
|
+ // This test is deactivated because of noisy logging from the net/http package.
|
|
|
s := newTLSServer(t)
|
|
s := newTLSServer(t)
|
|
|
defer s.Close()
|
|
defer s.Close()
|
|
|
|
|
|
|
@@ -247,3 +254,37 @@ func TestHandshake(t *testing.T) {
|
|
|
}
|
|
}
|
|
|
sendRecv(t, ws)
|
|
sendRecv(t, ws)
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+func TestRespOnBadHandshake(t *testing.T) {
|
|
|
|
|
+ const expectedStatus = http.StatusGone
|
|
|
|
|
+ const expectedBody = "This is the response body."
|
|
|
|
|
+
|
|
|
|
|
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
+ w.WriteHeader(expectedStatus)
|
|
|
|
|
+ io.WriteString(w, expectedBody)
|
|
|
|
|
+ }))
|
|
|
|
|
+ defer s.Close()
|
|
|
|
|
+
|
|
|
|
|
+ ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
|
|
|
|
|
+ if err == nil {
|
|
|
|
|
+ ws.Close()
|
|
|
|
|
+ t.Fatalf("Dial: nil")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if resp == nil {
|
|
|
|
|
+ t.Fatalf("resp=nil, err=%v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if resp.StatusCode != expectedStatus {
|
|
|
|
|
+ t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ p, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if string(p) != expectedBody {
|
|
|
|
|
+ t.Errorf("resp.Body=%s, want %s", p, expectedBody)
|
|
|
|
|
+ }
|
|
|
|
|
+}
|