|
|
@@ -13,6 +13,7 @@ import (
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
"net/url"
|
|
|
+ "reflect"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"testing"
|
|
|
@@ -450,3 +451,39 @@ func TestClose(t *testing.T) {
|
|
|
t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+var originTests = []struct {
|
|
|
+ req *http.Request
|
|
|
+ origin *url.URL
|
|
|
+}{
|
|
|
+ {
|
|
|
+ req: &http.Request{
|
|
|
+ Header: http.Header{
|
|
|
+ "Origin": []string{"http://www.example.com"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ origin: &url.URL{
|
|
|
+ Scheme: "http",
|
|
|
+ Host: "www.example.com",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ req: &http.Request{},
|
|
|
+ },
|
|
|
+}
|
|
|
+
|
|
|
+func TestOrigin(t *testing.T) {
|
|
|
+ conf := newConfig(t, "/echo")
|
|
|
+ conf.Version = ProtocolVersionHybi13
|
|
|
+ for i, tt := range originTests {
|
|
|
+ origin, err := Origin(conf, tt.req)
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(origin, tt.origin) {
|
|
|
+ t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|