Forráskód Böngészése

Relax default origin test.

Update the default origin test to treat no origin specified as OK. If
the client can create a request without the origin set, then the client
can also create a request with an arbitrary origin.
Gary Burd 11 éve
szülő
commit
0f32413e5e
2 módosított fájl, 16 hozzáadás és 8 törlés
  1. 9 1
      client_server_test.go
  2. 7 7
      server.go

+ 9 - 1
client_server_test.go

@@ -137,7 +137,6 @@ type dialHandler struct {
 var dialUpgrader = &Upgrader{
 	ReadBufferSize:  1024,
 	WriteBufferSize: 1024,
-	CheckOrigin:     func(r *http.Request) bool { return true },
 }
 
 func (t dialHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -242,3 +241,12 @@ func TestDialBadScheme(t *testing.T) {
 		t.Fatalf("Dial() did not return error")
 	}
 }
+
+func TestDialBadOrigin(t *testing.T) {
+	s := httptest.NewServer(dialHandler{t})
+	defer s.Close()
+	_, _, err := DefaultDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
+	if err == nil {
+		t.Fatalf("Dial() did not return error")
+	}
+}

+ 7 - 7
server.go

@@ -48,8 +48,8 @@ type Upgrader struct {
 	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
 
 	// CheckOrigin returns true if the request Origin header is acceptable. If
-	// CheckOrigin is nil, the host in the Origin header must match the host of
-	// the request.
+	// CheckOrigin is nil, the host in the Origin header must not be set or
+	// must match the host of the request.
 	CheckOrigin func(r *http.Request) bool
 }
 
@@ -63,13 +63,13 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
 	return nil, err
 }
 
-// checkSameOrigin returns true if the origin is equal to the request host.
+// checkSameOrigin returns true if the origin is not set or is equal to the request host.
 func checkSameOrigin(r *http.Request) bool {
-	origin := r.Header.Get("Origin")
-	if origin == "" {
-		return false
+	origin := r.Header["Origin"]
+	if len(origin) == 0 {
+		return true
 	}
-	u, err := url.Parse(origin)
+	u, err := url.Parse(origin[0])
 	if err != nil {
 		return false
 	}