// cors_test.go
// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. package cors
  15. import (
  16. "net/http"
  17. "net/http/httptest"
  18. "reflect"
  19. "testing"
  20. )
  21. func TestCORSInfo(t *testing.T) {
  22. tests := []struct {
  23. s string
  24. winfo CORSInfo
  25. ws string
  26. }{
  27. {"", CORSInfo{}, ""},
  28. {"http://127.0.0.1", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
  29. {"*", CORSInfo{"*": true}, "*"},
  30. // with space around
  31. {" http://127.0.0.1 ", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
  32. // multiple addrs
  33. {
  34. "http://127.0.0.1,http://127.0.0.2",
  35. CORSInfo{"http://127.0.0.1": true, "http://127.0.0.2": true},
  36. "http://127.0.0.1,http://127.0.0.2",
  37. },
  38. }
  39. for i, tt := range tests {
  40. info := CORSInfo{}
  41. if err := info.Set(tt.s); err != nil {
  42. t.Errorf("#%d: set error = %v, want nil", i, err)
  43. }
  44. if !reflect.DeepEqual(info, tt.winfo) {
  45. t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo)
  46. }
  47. if g := info.String(); g != tt.ws {
  48. t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws)
  49. }
  50. }
  51. }
  52. func TestCORSInfoOriginAllowed(t *testing.T) {
  53. tests := []struct {
  54. set string
  55. origin string
  56. wallowed bool
  57. }{
  58. {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true},
  59. {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true},
  60. {"http://127.0.0.1,http://127.0.0.2", "*", false},
  61. {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false},
  62. {"*", "*", true},
  63. {"*", "http://127.0.0.1", true},
  64. }
  65. for i, tt := range tests {
  66. info := CORSInfo{}
  67. if err := info.Set(tt.set); err != nil {
  68. t.Errorf("#%d: set error = %v, want nil", i, err)
  69. }
  70. if g := info.OriginAllowed(tt.origin); g != tt.wallowed {
  71. t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed)
  72. }
  73. }
  74. }
  75. func TestCORSHandler(t *testing.T) {
  76. info := &CORSInfo{}
  77. if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil {
  78. t.Fatalf("unexpected set error: %v", err)
  79. }
  80. h := &CORSHandler{
  81. Handler: http.NotFoundHandler(),
  82. Info: info,
  83. }
  84. header := func(origin string) http.Header {
  85. return http.Header{
  86. "Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"},
  87. "Access-Control-Allow-Origin": []string{origin},
  88. "Access-Control-Allow-Headers": []string{"accept, content-type, authorization"},
  89. }
  90. }
  91. tests := []struct {
  92. method string
  93. origin string
  94. wcode int
  95. wheader http.Header
  96. }{
  97. {"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")},
  98. {"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")},
  99. {"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}},
  100. {"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")},
  101. }
  102. for i, tt := range tests {
  103. rr := httptest.NewRecorder()
  104. req := &http.Request{
  105. Method: tt.method,
  106. Header: http.Header{"Origin": []string{tt.origin}},
  107. }
  108. h.ServeHTTP(rr, req)
  109. if rr.Code != tt.wcode {
  110. t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode)
  111. }
  112. // it is set by http package, and there is no need to test it
  113. rr.HeaderMap.Del("Content-Type")
  114. rr.HeaderMap.Del("X-Content-Type-Options")
  115. if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) {
  116. t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader)
  117. }
  118. }
  119. }