// Copyright 2015 CoreOS, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cors import ( "net/http" "net/http/httptest" "reflect" "testing" ) func TestCORSInfo(t *testing.T) { tests := []struct { s string winfo CORSInfo ws string }{ {"", CORSInfo{}, ""}, {"http://127.0.0.1", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"}, {"*", CORSInfo{"*": true}, "*"}, // with space around {" http://127.0.0.1 ", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"}, // multiple addrs { "http://127.0.0.1,http://127.0.0.2", CORSInfo{"http://127.0.0.1": true, "http://127.0.0.2": true}, "http://127.0.0.1,http://127.0.0.2", }, } for i, tt := range tests { info := CORSInfo{} if err := info.Set(tt.s); err != nil { t.Errorf("#%d: set error = %v, want nil", i, err) } if !reflect.DeepEqual(info, tt.winfo) { t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo) } if g := info.String(); g != tt.ws { t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws) } } } func TestCORSInfoOriginAllowed(t *testing.T) { tests := []struct { set string origin string wallowed bool }{ {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true}, {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true}, {"http://127.0.0.1,http://127.0.0.2", "*", false}, {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false}, {"*", "*", true}, {"*", "http://127.0.0.1", true}, } for i, tt := range tests { info := CORSInfo{} if err := info.Set(tt.set); err != nil { t.Errorf("#%d: set error = %v, want nil", i, err) } if g := info.OriginAllowed(tt.origin); g != tt.wallowed { t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed) } } } func TestCORSHandler(t *testing.T) { info := &CORSInfo{} if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil { t.Fatalf("unexpected set error: %v", err) } h := &CORSHandler{ Handler: http.NotFoundHandler(), Info: info, } header := func(origin string) http.Header { return http.Header{ "Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"}, "Access-Control-Allow-Origin": []string{origin}, "Access-Control-Allow-Headers": []string{"accept, content-type"}, } } tests := []struct { method string origin string wcode int wheader http.Header }{ {"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")}, {"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")}, {"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}}, {"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")}, } for i, tt := range tests { rr := httptest.NewRecorder() req := &http.Request{ Method: tt.method, Header: http.Header{"Origin": []string{tt.origin}}, } h.ServeHTTP(rr, req) if rr.Code != tt.wcode { t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode) } // it is set by http package, and there is no need to test it rr.HeaderMap.Del("Content-Type") if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) { t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader) } } }