|
|
@@ -40,6 +40,19 @@ func performRequest(r http.Handler, method, origin string) *httptest.ResponseRec
|
|
|
return w
|
|
|
}
|
|
|
|
|
|
+func performRequestWithHeaders(r http.Handler, method, origin string, headers map[string]string) *httptest.ResponseRecorder {
|
|
|
+ req, _ := http.NewRequest(method, "/", nil)
|
|
|
+ for k, v := range headers {
|
|
|
+ req.Header.Set(k, v)
|
|
|
+ }
|
|
|
+ if len(origin) > 0 {
|
|
|
+ req.Header.Set("Origin", origin)
|
|
|
+ }
|
|
|
+ w := httptest.NewRecorder()
|
|
|
+ r.ServeHTTP(w, req)
|
|
|
+ return w
|
|
|
+}
|
|
|
+
|
|
|
func TestConfigAddAllow(t *testing.T) {
|
|
|
config := Config{}
|
|
|
config.AddAllowMethods("POST")
|
|
|
@@ -231,6 +244,13 @@ func TestPassesAllowedOrigins(t *testing.T) {
|
|
|
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
|
|
|
assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
|
|
|
|
|
|
+ // no CORS request, origin == host
|
|
|
+ w = performRequestWithHeaders(router, "GET", "http://facebook.com", map[string]string{"Host": "facebook.com"})
|
|
|
+ assert.Equal(t, "get", w.Body.String())
|
|
|
+ assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
|
|
|
+ assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
|
|
|
+ assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
|
|
|
+
|
|
|
// allowed CORS request
|
|
|
w = performRequest(router, "GET", "http://google.com")
|
|
|
assert.Equal(t, "get", w.Body.String())
|