cors_test.go 12 KB


  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/gin-gonic/gin"
  9. "github.com/stretchr/testify/assert"
  10. )
  11. func init() {
  12. gin.SetMode(gin.TestMode)
  13. }
  // newTestRouter returns a fresh gin engine with the CORS middleware built
  // from config installed globally, and a "/" route for GET, POST and PATCH.
  // Each route replies 200 with the lower-case method name as the body, so
  // tests can tell whether the request reached the handler.
  func newTestRouter(config Config) *gin.Engine {
  	engine := gin.New()
  	engine.Use(New(config))

  	routes := []struct {
  		method string
  		body   string
  	}{
  		{http.MethodGet, "get"},
  		{http.MethodPost, "post"},
  		{http.MethodPatch, "patch"},
  	}
  	for _, rt := range routes {
  		reply := rt.body // per-iteration copy for the closure
  		engine.Handle(rt.method, "/", func(c *gin.Context) {
  			c.String(http.StatusOK, reply)
  		})
  	}
  	return engine
  }
  // performRequest sends a request with the given method to "/" on r and
  // returns the recorded response. The Origin header is set only when origin
  // is non-empty; an empty origin therefore models a non-CORS request.
  //
  // It is performRequestWithHeaders with no extra headers. Delegating keeps a
  // single place that builds requests (and handles request-construction errors).
  func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
  	return performRequestWithHeaders(r, method, origin, nil)
  }
  // performRequestWithHeaders sends a request with the given method to "/" on
  // r, with every entry of headers set on the request, and returns the
  // recorded response. A nil or empty headers map is allowed.
  //
  // When origin is non-empty it is applied after headers, so it overrides any
  // "Origin" entry in the map.
  //
  // It panics if the request cannot be built (e.g. an invalid method token).
  // That can only be a bug in the test itself, and the helper has no
  // *testing.T to report through.
  func performRequestWithHeaders(r http.Handler, method, origin string, headers map[string]string) *httptest.ResponseRecorder {
  	req, err := http.NewRequest(method, "/", nil)
  	if err != nil {
  		panic(err)
  	}
  	for k, v := range headers {
  		req.Header.Set(k, v)
  	}
  	// For incoming requests net/http promotes the Host header to
  	// Request.Host; the handler sees req.Host, not Header["Host"]. Mirror that
  	// so "same host" scenarios are simulated faithfully.
  	// NOTE(review): the header is also left in the map so code that reads it
  	// from there keeps working; a real server would have removed it.
  	if host := req.Header.Get("Host"); host != "" {
  		req.Host = host
  	}
  	if origin != "" {
  		req.Header.Set("Origin", origin)
  	}
  	w := httptest.NewRecorder()
  	r.ServeHTTP(w, req)
  	return w
  }
  // TestConfigAddAllow checks that the Add* helpers append their arguments
  // verbatim, with no trimming or case changes (" cool" survives as-is). It
  // also checks that calling them with no arguments leaves the list unchanged.
  func TestConfigAddAllow(t *testing.T) {
  	config := Config{}
  	config.AddAllowMethods("POST")
  	config.AddAllowMethods("GET", "PUT")
  	config.AddExposeHeaders()
  	config.AddAllowHeaders("Some", " cool")
  	config.AddAllowHeaders("header")
  	config.AddExposeHeaders()
  	config.AddExposeHeaders()
  	config.AddExposeHeaders("exposed", "header")
  	config.AddExposeHeaders("hey")

  	// testify's assert.Equal takes (expected, actual); keep that order so
  	// failure messages label the values correctly.
  	assert.Equal(t, []string{"POST", "GET", "PUT"}, config.AllowMethods)
  	assert.Equal(t, []string{"Some", " cool", "header"}, config.AllowHeaders)
  	assert.Equal(t, []string{"exposed", "header", "hey"}, config.ExposeHeaders)
  }
  // TestBadConfig checks that New panics for each rejected configuration:
  //   - no origin source at all (zero Config);
  //   - AllowAllOrigins combined with an explicit AllowOrigins list;
  //   - AllowAllOrigins combined with an AllowOriginFunc;
  //   - an origin given without a scheme ("google.com").
  func TestBadConfig(t *testing.T) {
  	rejected := []Config{
  		{},
  		{
  			AllowAllOrigins: true,
  			AllowOrigins:    []string{"http://google.com"},
  		},
  		{
  			AllowAllOrigins: true,
  			AllowOriginFunc: func(origin string) bool { return false },
  		},
  		{
  			AllowOrigins: []string{"google.com"},
  		},
  	}
  	for i := range rejected {
  		cfg := rejected[i] // per-iteration copy captured by the closure
  		assert.Panics(t, func() { New(cfg) })
  	}
  }
  // TestNormalize checks the expected results of normalize. Values are
  // trimmed and lower-cased, and duplicates are dropped keeping first-seen
  // order. An empty string is kept as a value. A nil input stays nil and an
  // empty slice stays empty (non-nil).
  func TestNormalize(t *testing.T) {
  	values := normalize([]string{
  		"http-Access ", "Post", "POST", " poSt ",
  		"HTTP-Access", "",
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Equal(t, []string{"http-access", "post", ""}, values)

  	values = normalize(nil)
  	assert.Nil(t, values)

  	values = normalize([]string{})
  	assert.Equal(t, []string{}, values)
  }
  // TestConvert checks that convert maps the given function over every
  // element. Methods in mixed case all become "GET" via strings.ToUpper.
  // Header names in mixed case all become the canonical MIME form via
  // http.CanonicalHeaderKey.
  func TestConvert(t *testing.T) {
  	thrice := func(s string) []string { return []string{s, s, s} }

  	upper := convert([]string{"Get", "GET", "get"}, strings.ToUpper)
  	assert.Equal(t, thrice("GET"), upper)

  	canonical := convert([]string{"X-CSRF-TOKEN", "X-CSRF-Token", "x-csrf-token"}, http.CanonicalHeaderKey)
  	assert.Equal(t, thrice("X-Csrf-Token"), canonical)
  }
  // TestGenerateNormalHeaders_AllowAllOrigins checks the normal (non-preflight)
  // header set for both values of AllowAllOrigins.
  //   - false: no Allow-Origin is pre-generated and only "Vary: Origin" is set.
  //   - true: "Access-Control-Allow-Origin: *" is set and there is no Vary.
  func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) {
  	header := generateNormalHeaders(Config{
  		AllowAllOrigins: false,
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Empty(t, header.Get("Access-Control-Allow-Origin"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 1)

  	header = generateNormalHeaders(Config{
  		AllowAllOrigins: true,
  	})
  	assert.Equal(t, "*", header.Get("Access-Control-Allow-Origin"))
  	assert.Empty(t, header.Get("Vary"))
  	assert.Len(t, header, 1)
  }
  // TestGenerateNormalHeaders_AllowCredentials checks that AllowCredentials
  // adds "Access-Control-Allow-Credentials: true" next to "Vary: Origin".
  func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) {
  	header := generateNormalHeaders(Config{
  		AllowCredentials: true,
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Equal(t, "true", header.Get("Access-Control-Allow-Credentials"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 2)
  }
  // TestGenerateNormalHeaders_ExposedHeaders checks that ExposeHeaders are
  // canonicalized ("xPassword" -> "Xpassword") and comma-joined into
  // Access-Control-Expose-Headers.
  func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) {
  	header := generateNormalHeaders(Config{
  		ExposeHeaders: []string{"X-user", "xPassword"},
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Equal(t, "X-User,Xpassword", header.Get("Access-Control-Expose-Headers"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 2)
  }
  // TestGeneratePreflightHeaders checks the preflight header set for both
  // values of AllowAllOrigins.
  //   - false: no Allow-Origin is pre-generated and Vary starts with "Origin".
  //   - true: "Access-Control-Allow-Origin: *" is set and there is no Vary.
  func TestGeneratePreflightHeaders(t *testing.T) {
  	header := generatePreflightHeaders(Config{
  		AllowAllOrigins: false,
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Empty(t, header.Get("Access-Control-Allow-Origin"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 1)

  	// This previously called generateNormalHeaders (copy-paste slip), which
  	// left the preflight + AllowAllOrigins case untested.
  	header = generatePreflightHeaders(Config{
  		AllowAllOrigins: true,
  	})
  	assert.Equal(t, "*", header.Get("Access-Control-Allow-Origin"))
  	assert.Empty(t, header.Get("Vary"))
  	assert.Len(t, header, 1)
  }
  // TestGeneratePreflightHeaders_AllowCredentials checks that AllowCredentials
  // adds "Access-Control-Allow-Credentials: true" to the preflight headers.
  func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) {
  	header := generatePreflightHeaders(Config{
  		AllowCredentials: true,
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Equal(t, "true", header.Get("Access-Control-Allow-Credentials"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 2)
  }
  // TestGeneratePreflightHeaders_AllowedMethods checks that AllowMethods are
  // trimmed, upper-cased, de-duplicated (" put " folds into "PUT") and
  // comma-joined into Access-Control-Allow-Methods.
  func TestGeneratePreflightHeaders_AllowedMethods(t *testing.T) {
  	header := generatePreflightHeaders(Config{
  		AllowMethods: []string{"GET ", "post", "PUT", " put "},
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Equal(t, "GET,POST,PUT", header.Get("Access-Control-Allow-Methods"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 2)
  }
  // TestGeneratePreflightHeaders_AllowedHeaders checks that AllowHeaders are
  // canonicalized ("X-user" -> "X-User") and comma-joined into
  // Access-Control-Allow-Headers.
  func TestGeneratePreflightHeaders_AllowedHeaders(t *testing.T) {
  	header := generatePreflightHeaders(Config{
  		AllowHeaders: []string{"X-user", "Content-Type"},
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Equal(t, "X-User,Content-Type", header.Get("Access-Control-Allow-Headers"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 2)
  }
  // TestGeneratePreflightHeaders_MaxAge checks that MaxAge is emitted in whole
  // seconds: 12h -> "43200" (12*60*60).
  func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
  	header := generatePreflightHeaders(Config{
  		MaxAge: 12 * time.Hour,
  	})
  	// assert.Equal takes (expected, actual).
  	assert.Equal(t, "43200", header.Get("Access-Control-Max-Age"))
  	assert.Equal(t, "Origin", header.Get("Vary"))
  	assert.Len(t, header, 2)
  }
  // TestValidateOrigin exercises validateOrigin under several configurations:
  // allow-all, an explicit list plus a matcher func, scheme-restricted lists,
  // and lists that opt in to browser-extension, file and websocket origins.
  func TestValidateOrigin(t *testing.T) {
  	// Allow-all accepts any origin string, even scheme-less and extension ones.
  	policy := newCors(Config{
  		AllowAllOrigins: true,
  	})
  	assert.True(t, policy.validateOrigin("http://google.com"))
  	assert.True(t, policy.validateOrigin("https://google.com"))
  	assert.True(t, policy.validateOrigin("example.com"))
  	assert.True(t, policy.validateOrigin("chrome-extension://random-extension-id"))

  	// Explicit list plus matcher func. Extensions are enabled, but an extension
  	// origin that is not listed is still rejected.
  	policy = newCors(Config{
  		AllowOrigins: []string{"https://google.com", "https://github.com"},
  		AllowOriginFunc: func(origin string) bool {
  			return origin == "http://news.ycombinator.com"
  		},
  		AllowBrowserExtensions: true,
  	})
  	assert.False(t, policy.validateOrigin("http://google.com"))
  	assert.True(t, policy.validateOrigin("https://google.com"))
  	assert.True(t, policy.validateOrigin("https://github.com"))
  	assert.True(t, policy.validateOrigin("http://news.ycombinator.com"))
  	assert.False(t, policy.validateOrigin("http://example.com"))
  	assert.False(t, policy.validateOrigin("google.com"))
  	assert.False(t, policy.validateOrigin("chrome-extension://random-extension-id"))

  	// Without the opt-in flags, extension/file/websocket origins are rejected.
  	policy = newCors(Config{
  		AllowOrigins: []string{"https://google.com", "https://github.com"},
  	})
  	assert.False(t, policy.validateOrigin("chrome-extension://random-extension-id"))
  	assert.False(t, policy.validateOrigin("file://some-dangerous-file.js"))
  	assert.False(t, policy.validateOrigin("wss://socket-connection"))

  	// Listed extension origins pass once extensions are allowed; unlisted do not.
  	policy = newCors(Config{
  		AllowOrigins:           []string{"chrome-extension://random-extension-id", "safari-extension://another-ext-id"},
  		AllowBrowserExtensions: true,
  	})
  	assert.True(t, policy.validateOrigin("chrome-extension://random-extension-id"))
  	assert.True(t, policy.validateOrigin("safari-extension://another-ext-id"))
  	assert.False(t, policy.validateOrigin("moz-extension://ext-id-we-not-allow"))

  	// Listed file and websocket origins pass once those schemes are allowed.
  	policy = newCors(Config{
  		AllowOrigins:    []string{"file://safe-file.js", "wss://some-session-layer-connection"},
  		AllowFiles:      true,
  		AllowWebSockets: true,
  	})
  	assert.True(t, policy.validateOrigin("file://safe-file.js"))
  	assert.False(t, policy.validateOrigin("file://some-dangerous-file.js"))
  	assert.True(t, policy.validateOrigin("wss://some-session-layer-connection"))
  	assert.False(t, policy.validateOrigin("ws://not-what-we-expected"))
  }
  // TestPassesAllowedOrigins drives the middleware end to end with an explicit
  // origin list plus an AllowOriginFunc. It covers non-CORS requests, allowed
  // and denied simple requests, and allowed and denied preflight requests.
  func TestPassesAllowedOrigins(t *testing.T) {
  	router := newTestRouter(Config{
  		AllowOrigins:     []string{"http://google.com"},
  		AllowMethods:     []string{" GeT ", "get", "post", "PUT ", "Head", "POST"},
  		AllowHeaders:     []string{"Content-type", "timeStamp "},
  		ExposeHeaders:    []string{"Data", "x-User"},
  		AllowCredentials: false,
  		MaxAge:           12 * time.Hour,
  		AllowOriginFunc: func(origin string) bool {
  			return origin == "http://github.com"
  		},
  	})

  	// Not a CORS request: no Origin header.
  	w := performRequest(router, http.MethodGet, "")
  	assert.Equal(t, "get", w.Body.String())
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))

  	// Not a CORS request: Origin matches the request's own host.
  	w = performRequestWithHeaders(router, http.MethodGet, "http://facebook.com", map[string]string{"Host": "facebook.com"})
  	assert.Equal(t, "get", w.Body.String())
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))

  	// Allowed CORS request via AllowOrigins.
  	w = performRequest(router, http.MethodGet, "http://google.com")
  	assert.Equal(t, "get", w.Body.String())
  	assert.Equal(t, "http://google.com", w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Equal(t, "Data,X-User", w.Header().Get("Access-Control-Expose-Headers"))

  	// Allowed CORS request via AllowOriginFunc.
  	w = performRequest(router, http.MethodGet, "http://github.com")
  	assert.Equal(t, "get", w.Body.String())
  	assert.Equal(t, "http://github.com", w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Equal(t, "Data,X-User", w.Header().Get("Access-Control-Expose-Headers"))

  	// Denied CORS request: the scheme differs from the allowed origin.
  	w = performRequest(router, http.MethodGet, "https://google.com")
  	assert.Equal(t, http.StatusForbidden, w.Code)
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))

  	// Allowed CORS preflight request.
  	w = performRequest(router, http.MethodOptions, "http://github.com")
  	assert.Equal(t, http.StatusOK, w.Code)
  	assert.Equal(t, "http://github.com", w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods"))
  	assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
  	assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))

  	// Denied CORS preflight request.
  	w = performRequest(router, http.MethodOptions, "http://example.com")
  	assert.Equal(t, http.StatusForbidden, w.Code)
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
  	assert.Empty(t, w.Header().Get("Access-Control-Max-Age"))
  }
  // TestPassesAllowedAllOrigins drives the middleware end to end with
  // AllowAllOrigins. A non-CORS request gets no CORS headers. Both simple and
  // preflight CORS requests get "Access-Control-Allow-Origin: *" and no
  // credentials header.
  func TestPassesAllowedAllOrigins(t *testing.T) {
  	router := newTestRouter(Config{
  		AllowAllOrigins:  true,
  		AllowMethods:     []string{" Patch ", "get", "post", "POST"},
  		AllowHeaders:     []string{"Content-type", " testheader "},
  		ExposeHeaders:    []string{"Data2", "x-User2"},
  		AllowCredentials: false,
  		MaxAge:           10 * time.Hour,
  	})

  	// Not a CORS request: no Origin header.
  	w := performRequest(router, http.MethodGet, "")
  	assert.Equal(t, "get", w.Body.String())
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))

  	// Allowed CORS request: any origin, even without a scheme.
  	w = performRequest(router, http.MethodPost, "example.com")
  	assert.Equal(t, "post", w.Body.String())
  	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Equal(t, "Data2,X-User2", w.Header().Get("Access-Control-Expose-Headers"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))

  	// Allowed CORS preflight request.
  	w = performRequest(router, http.MethodOptions, "https://facebook.com")
  	assert.Equal(t, http.StatusOK, w.Code)
  	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
  	assert.Equal(t, "PATCH,GET,POST", w.Header().Get("Access-Control-Allow-Methods"))
  	assert.Equal(t, "Content-Type,Testheader", w.Header().Get("Access-Control-Allow-Headers"))
  	assert.Equal(t, "36000", w.Header().Get("Access-Control-Max-Age"))
  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  }