Browse Source

Merge pull request #2079 from yichengq/291

pkg/cors: add tests
Yicheng Qin 11 years ago
parent
commit
51005d32c7
2 changed files with 126 additions and 4 deletions
  1. 0 4
      pkg/cors/cors.go
  2. 126 0
      pkg/cors/cors_test.go

+ 0 - 4
pkg/cors/cors.go

@@ -73,10 +73,6 @@ func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
 // ServeHTTP adds the correct CORS headers based on the origin and returns immediately
 // with a 200 OK if the method is OPTIONS.
 func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-	// It is important to flush before leaving the goroutine.
-	// Or it may miss the latest info written.
-	defer w.(http.Flusher).Flush()
-
 	// Write CORS header.
 	if h.Info.OriginAllowed("*") {
 		h.addHeader(w, "*")

+ 126 - 0
pkg/cors/cors_test.go

@@ -0,0 +1,126 @@
+/*
+   Copyright 2014 CoreOS, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package cors
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"reflect"
+	"testing"
+)
+
+func TestCORSInfo(t *testing.T) {
+	tests := []struct {
+		s     string
+		winfo CORSInfo
+		ws    string
+	}{
+		{"", CORSInfo{}, ""},
+		{"http://127.0.0.1", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
+		{"*", CORSInfo{"*": true}, "*"},
+		// with space around
+		{" http://127.0.0.1 ", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
+		// multiple addrs
+		{
+			"http://127.0.0.1,http://127.0.0.2",
+			CORSInfo{"http://127.0.0.1": true, "http://127.0.0.2": true},
+			"http://127.0.0.1,http://127.0.0.2",
+		},
+	}
+	for i, tt := range tests {
+		info := CORSInfo{}
+		if err := info.Set(tt.s); err != nil {
+			t.Errorf("#%d: set error = %v, want nil", i, err)
+		}
+		if !reflect.DeepEqual(info, tt.winfo) {
+			t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo)
+		}
+		if g := info.String(); g != tt.ws {
+			t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws)
+		}
+	}
+}
+
+func TestCORSInfoOriginAllowed(t *testing.T) {
+	tests := []struct {
+		set      string
+		origin   string
+		wallowed bool
+	}{
+		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true},
+		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true},
+		{"http://127.0.0.1,http://127.0.0.2", "*", false},
+		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false},
+		{"*", "*", true},
+		{"*", "http://127.0.0.1", true},
+	}
+	for i, tt := range tests {
+		info := CORSInfo{}
+		if err := info.Set(tt.set); err != nil {
+			t.Errorf("#%d: set error = %v, want nil", i, err)
+		}
+		if g := info.OriginAllowed(tt.origin); g != tt.wallowed {
+			t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed)
+		}
+	}
+}
+
+func TestCORSHandler(t *testing.T) {
+	info := &CORSInfo{}
+	if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil {
+		t.Fatalf("unexpected set error: %v", err)
+	}
+	h := &CORSHandler{
+		Handler: http.NotFoundHandler(),
+		Info:    info,
+	}
+
+	header := func(origin string) http.Header {
+		return http.Header{
+			"Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"},
+			"Access-Control-Allow-Origin":  []string{origin},
+			"Access-Control-Allow-Headers": []string{"accept, content-type"},
+		}
+	}
+	tests := []struct {
+		method  string
+		origin  string
+		wcode   int
+		wheader http.Header
+	}{
+		{"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")},
+		{"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")},
+		{"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}},
+		{"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")},
+	}
+	for i, tt := range tests {
+		rr := httptest.NewRecorder()
+		req := &http.Request{
+			Method: tt.method,
+			Header: http.Header{"Origin": []string{tt.origin}},
+		}
+		h.ServeHTTP(rr, req)
+		if rr.Code != tt.wcode {
+			t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode)
+		}
+		// it is set by http package, and there is no need to test it
+		rr.HeaderMap.Del("Content-Type")
+		if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) {
+			t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader)
+		}
+	}
+}