cors_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/gin-gonic/gin"
  9. "github.com/stretchr/testify/assert"
  10. )
  11. func init() {
  12. gin.SetMode(gin.TestMode)
  13. }
  14. func newTestRouter(config Config) *gin.Engine {
  15. router := gin.New()
  16. router.Use(New(config))
  17. router.GET("/", func(c *gin.Context) {
  18. c.String(http.StatusOK, "get")
  19. })
  20. router.POST("/", func(c *gin.Context) {
  21. c.String(http.StatusOK, "post")
  22. })
  23. router.PATCH("/", func(c *gin.Context) {
  24. c.String(http.StatusOK, "patch")
  25. })
  26. return router
  27. }
  28. func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
  29. req, _ := http.NewRequest(method, "/", nil)
  30. if len(origin) > 0 {
  31. req.Header.Set("Origin", origin)
  32. }
  33. w := httptest.NewRecorder()
  34. r.ServeHTTP(w, req)
  35. return w
  36. }
  37. func performRequestWithHeaders(r http.Handler, method, origin string, headers map[string]string) *httptest.ResponseRecorder {
  38. req, _ := http.NewRequest(method, "/", nil)
  39. for k, v := range headers {
  40. req.Header.Set(k, v)
  41. }
  42. if len(origin) > 0 {
  43. req.Header.Set("Origin", origin)
  44. }
  45. w := httptest.NewRecorder()
  46. r.ServeHTTP(w, req)
  47. return w
  48. }
  49. func TestConfigAddAllow(t *testing.T) {
  50. config := Config{}
  51. config.AddAllowMethods("POST")
  52. config.AddAllowMethods("GET", "PUT")
  53. config.AddExposeHeaders()
  54. config.AddAllowHeaders("Some", " cool")
  55. config.AddAllowHeaders("header")
  56. config.AddExposeHeaders()
  57. config.AddExposeHeaders()
  58. config.AddExposeHeaders("exposed", "header")
  59. config.AddExposeHeaders("hey")
  60. assert.Equal(t, config.AllowMethods, []string{"POST", "GET", "PUT"})
  61. assert.Equal(t, config.AllowHeaders, []string{"Some", " cool", "header"})
  62. assert.Equal(t, config.ExposeHeaders, []string{"exposed", "header", "hey"})
  63. }
  64. func TestBadConfig(t *testing.T) {
  65. assert.Panics(t, func() { New(Config{}) })
  66. assert.Panics(t, func() {
  67. New(Config{
  68. AllowAllOrigins: true,
  69. AllowOrigins: []string{"http://google.com"},
  70. })
  71. })
  72. assert.Panics(t, func() {
  73. New(Config{
  74. AllowAllOrigins: true,
  75. AllowOriginFunc: func(origin string) bool { return false },
  76. })
  77. })
  78. assert.Panics(t, func() {
  79. New(Config{
  80. AllowOrigins: []string{"google.com"},
  81. })
  82. })
  83. }
  84. func TestNormalize(t *testing.T) {
  85. values := normalize([]string{
  86. "http-Access ", "Post", "POST", " poSt ",
  87. "HTTP-Access", "",
  88. })
  89. assert.Equal(t, values, []string{"http-access", "post", ""})
  90. values = normalize(nil)
  91. assert.Nil(t, values)
  92. values = normalize([]string{})
  93. assert.Equal(t, values, []string{})
  94. }
  95. func TestConvert(t *testing.T) {
  96. methods := []string{"Get", "GET", "get"}
  97. headers := []string{"X-CSRF-TOKEN", "X-CSRF-Token", "x-csrf-token"}
  98. assert.Equal(t, []string{"GET", "GET", "GET"}, convert(methods, strings.ToUpper))
  99. assert.Equal(t, []string{"X-Csrf-Token", "X-Csrf-Token", "X-Csrf-Token"}, convert(headers, http.CanonicalHeaderKey))
  100. }
  101. func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) {
  102. header := generateNormalHeaders(Config{
  103. AllowAllOrigins: false,
  104. })
  105. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
  106. assert.Equal(t, header.Get("Vary"), "Origin")
  107. assert.Len(t, header, 1)
  108. header = generateNormalHeaders(Config{
  109. AllowAllOrigins: true,
  110. })
  111. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
  112. assert.Equal(t, header.Get("Vary"), "")
  113. assert.Len(t, header, 1)
  114. }
  115. func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) {
  116. header := generateNormalHeaders(Config{
  117. AllowCredentials: true,
  118. })
  119. assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
  120. assert.Equal(t, header.Get("Vary"), "Origin")
  121. assert.Len(t, header, 2)
  122. }
  123. func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) {
  124. header := generateNormalHeaders(Config{
  125. ExposeHeaders: []string{"X-user", "xPassword"},
  126. })
  127. assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "X-User,Xpassword")
  128. assert.Equal(t, header.Get("Vary"), "Origin")
  129. assert.Len(t, header, 2)
  130. }
  131. func TestGeneratePreflightHeaders(t *testing.T) {
  132. header := generatePreflightHeaders(Config{
  133. AllowAllOrigins: false,
  134. })
  135. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
  136. assert.Equal(t, header.Get("Vary"), "Origin")
  137. assert.Len(t, header, 1)
  138. header = generateNormalHeaders(Config{
  139. AllowAllOrigins: true,
  140. })
  141. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
  142. assert.Equal(t, header.Get("Vary"), "")
  143. assert.Len(t, header, 1)
  144. }
  145. func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) {
  146. header := generatePreflightHeaders(Config{
  147. AllowCredentials: true,
  148. })
  149. assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
  150. assert.Equal(t, header.Get("Vary"), "Origin")
  151. assert.Len(t, header, 2)
  152. }
  153. func TestGeneratePreflightHeaders_AllowMethods(t *testing.T) {
  154. header := generatePreflightHeaders(Config{
  155. AllowMethods: []string{"GET ", "post", "PUT", " put "},
  156. })
  157. assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "GET,POST,PUT")
  158. assert.Equal(t, header.Get("Vary"), "Origin")
  159. assert.Len(t, header, 2)
  160. }
  161. func TestGeneratePreflightHeaders_AllowHeaders(t *testing.T) {
  162. header := generatePreflightHeaders(Config{
  163. AllowHeaders: []string{"X-user", "Content-Type"},
  164. })
  165. assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "X-User,Content-Type")
  166. assert.Equal(t, header.Get("Vary"), "Origin")
  167. assert.Len(t, header, 2)
  168. }
  169. func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
  170. header := generatePreflightHeaders(Config{
  171. MaxAge: 12 * time.Hour,
  172. })
  173. assert.Equal(t, header.Get("Access-Control-Max-Age"), "43200") // 12*60*60
  174. assert.Equal(t, header.Get("Vary"), "Origin")
  175. assert.Len(t, header, 2)
  176. }
  177. func TestValidateOrigin(t *testing.T) {
  178. cors := newCors(Config{
  179. AllowAllOrigins: true,
  180. })
  181. assert.True(t, cors.validateOrigin("http://google.com"))
  182. assert.True(t, cors.validateOrigin("https://google.com"))
  183. assert.True(t, cors.validateOrigin("example.com"))
  184. assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
  185. cors = newCors(Config{
  186. AllowOrigins: []string{"https://google.com", "https://github.com"},
  187. AllowOriginFunc: func(origin string) bool {
  188. return (origin == "http://news.ycombinator.com")
  189. },
  190. AllowBrowserExtensions: true,
  191. })
  192. assert.False(t, cors.validateOrigin("http://google.com"))
  193. assert.True(t, cors.validateOrigin("https://google.com"))
  194. assert.True(t, cors.validateOrigin("https://github.com"))
  195. assert.True(t, cors.validateOrigin("http://news.ycombinator.com"))
  196. assert.False(t, cors.validateOrigin("http://example.com"))
  197. assert.False(t, cors.validateOrigin("google.com"))
  198. assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
  199. cors = newCors(Config{
  200. AllowOrigins: []string{"https://google.com", "https://github.com"},
  201. })
  202. assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
  203. assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
  204. assert.False(t, cors.validateOrigin("wss://socket-connection"))
  205. cors = newCors(Config{
  206. AllowOrigins: []string{"chrome-extension://*", "safari-extension://my-extension-*-app", "*.some-domain.com"},
  207. AllowBrowserExtensions: true,
  208. AllowWildcard: true,
  209. })
  210. assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
  211. assert.True(t, cors.validateOrigin("chrome-extension://another-one"))
  212. assert.True(t, cors.validateOrigin("safari-extension://my-extension-one-app"))
  213. assert.True(t, cors.validateOrigin("safari-extension://my-extension-two-app"))
  214. assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow"))
  215. assert.True(t, cors.validateOrigin("http://api.some-domain.com"))
  216. assert.False(t, cors.validateOrigin("http://api.another-domain.com"))
  217. cors = newCors(Config{
  218. AllowOrigins: []string{"file://safe-file.js", "wss://some-session-layer-connection"},
  219. AllowFiles: true,
  220. AllowWebSockets: true,
  221. })
  222. assert.True(t, cors.validateOrigin("file://safe-file.js"))
  223. assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
  224. assert.True(t, cors.validateOrigin("wss://some-session-layer-connection"))
  225. assert.False(t, cors.validateOrigin("ws://not-what-we-expected"))
  226. cors = newCors(Config{
  227. AllowOrigins: []string{"*"},
  228. })
  229. assert.True(t, cors.validateOrigin("http://google.com"))
  230. assert.True(t, cors.validateOrigin("https://google.com"))
  231. assert.True(t, cors.validateOrigin("example.com"))
  232. assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
  233. }
  234. func TestPassesAllowOrigins(t *testing.T) {
  235. router := newTestRouter(Config{
  236. AllowOrigins: []string{"http://google.com"},
  237. AllowMethods: []string{" GeT ", "get", "post", "PUT ", "Head", "POST"},
  238. AllowHeaders: []string{"Content-type", "timeStamp "},
  239. ExposeHeaders: []string{"Data", "x-User"},
  240. AllowCredentials: false,
  241. MaxAge: 12 * time.Hour,
  242. AllowOriginFunc: func(origin string) bool {
  243. return origin == "http://github.com"
  244. },
  245. })
  246. // no CORS request, origin == ""
  247. w := performRequest(router, "GET", "")
  248. assert.Equal(t, "get", w.Body.String())
  249. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  250. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  251. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  252. // no CORS request, origin == host
  253. w = performRequestWithHeaders(router, "GET", "http://facebook.com", map[string]string{"Host": "facebook.com"})
  254. assert.Equal(t, "get", w.Body.String())
  255. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  256. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  257. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  258. // allowed CORS request
  259. w = performRequest(router, "GET", "http://google.com")
  260. assert.Equal(t, "get", w.Body.String())
  261. assert.Equal(t, "http://google.com", w.Header().Get("Access-Control-Allow-Origin"))
  262. assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
  263. assert.Equal(t, "Data,X-User", w.Header().Get("Access-Control-Expose-Headers"))
  264. w = performRequest(router, "GET", "http://github.com")
  265. assert.Equal(t, "get", w.Body.String())
  266. assert.Equal(t, "http://github.com", w.Header().Get("Access-Control-Allow-Origin"))
  267. assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
  268. assert.Equal(t, "Data,X-User", w.Header().Get("Access-Control-Expose-Headers"))
  269. // deny CORS request
  270. w = performRequest(router, "GET", "https://google.com")
  271. assert.Equal(t, http.StatusForbidden, w.Code)
  272. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  273. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  274. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  275. // allowed CORS prefligh request
  276. w = performRequest(router, "OPTIONS", "http://github.com")
  277. assert.Equal(t, http.StatusNoContent, w.Code)
  278. assert.Equal(t, "http://github.com", w.Header().Get("Access-Control-Allow-Origin"))
  279. assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
  280. assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods"))
  281. assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
  282. assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))
  283. // deny CORS prefligh request
  284. w = performRequest(router, "OPTIONS", "http://example.com")
  285. assert.Equal(t, http.StatusForbidden, w.Code)
  286. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  287. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  288. assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"))
  289. assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
  290. assert.Empty(t, w.Header().Get("Access-Control-Max-Age"))
  291. }
  292. func TestPassesAllowAllOrigins(t *testing.T) {
  293. router := newTestRouter(Config{
  294. AllowAllOrigins: true,
  295. AllowMethods: []string{" Patch ", "get", "post", "POST"},
  296. AllowHeaders: []string{"Content-type", " testheader "},
  297. ExposeHeaders: []string{"Data2", "x-User2"},
  298. AllowCredentials: false,
  299. MaxAge: 10 * time.Hour,
  300. })
  301. // no CORS request, origin == ""
  302. w := performRequest(router, "GET", "")
  303. assert.Equal(t, "get", w.Body.String())
  304. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  305. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  306. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  307. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  308. // allowed CORS request
  309. w = performRequest(router, "POST", "example.com")
  310. assert.Equal(t, "post", w.Body.String())
  311. assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
  312. assert.Equal(t, "Data2,X-User2", w.Header().Get("Access-Control-Expose-Headers"))
  313. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  314. assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
  315. // allowed CORS prefligh request
  316. w = performRequest(router, "OPTIONS", "https://facebook.com")
  317. assert.Equal(t, http.StatusNoContent, w.Code)
  318. assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
  319. assert.Equal(t, "PATCH,GET,POST", w.Header().Get("Access-Control-Allow-Methods"))
  320. assert.Equal(t, "Content-Type,Testheader", w.Header().Get("Access-Control-Allow-Headers"))
  321. assert.Equal(t, "36000", w.Header().Get("Access-Control-Max-Age"))
  322. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  323. }
  324. func TestWildcard(t *testing.T) {
  325. router := newTestRouter(Config{
  326. AllowOrigins: []string{"https://*.github.com", "https://api.*", "http://*", "https://facebook.com", "*.golang.org"},
  327. AllowMethods: []string{"GET"},
  328. AllowWildcard: true,
  329. })
  330. w := performRequest(router, "GET", "https://gist.github.com")
  331. assert.Equal(t, 200, w.Code)
  332. w = performRequest(router, "GET", "https://api.github.com/v1/users")
  333. assert.Equal(t, 200, w.Code)
  334. w = performRequest(router, "GET", "https://giphy.com/")
  335. assert.Equal(t, 403, w.Code)
  336. w = performRequest(router, "GET", "http://hard-to-find-http-example.com")
  337. assert.Equal(t, 200, w.Code)
  338. w = performRequest(router, "GET", "https://facebook.com")
  339. assert.Equal(t, 200, w.Code)
  340. w = performRequest(router, "GET", "https://something.golang.org")
  341. assert.Equal(t, 200, w.Code)
  342. w = performRequest(router, "GET", "https://something.go.org")
  343. assert.Equal(t, 403, w.Code)
  344. router = newTestRouter(Config{
  345. AllowOrigins: []string{"https://github.com", "https://facebook.com"},
  346. AllowMethods: []string{"GET"},
  347. })
  348. w = performRequest(router, "GET", "https://gist.github.com")
  349. assert.Equal(t, 403, w.Code)
  350. w = performRequest(router, "GET", "https://github.com")
  351. assert.Equal(t, 200, w.Code)
  352. }