cors_test.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. "gopkg.in/gin-gonic/gin.v1"
  9. )
  10. func init() {
  11. gin.SetMode(gin.TestMode)
  12. }
  13. func newTestRouter(config Config) *gin.Engine {
  14. router := gin.New()
  15. router.Use(New(config))
  16. router.GET("/", func(c *gin.Context) {
  17. c.String(200, "get")
  18. })
  19. router.POST("/", func(c *gin.Context) {
  20. c.String(200, "post")
  21. })
  22. router.PATCH("/", func(c *gin.Context) {
  23. c.String(200, "patch")
  24. })
  25. return router
  26. }
  // performRequest sends a body-less request with the given method to path "/"
// of r and returns the recorded response.
//
// When origin is non-empty it is sent as the Origin header. When it is empty,
// no Origin header is set at all; the tests use that as a non-CORS request.
//
// It panics if the request cannot be constructed (for example, an invalid
// method). Previously the error was discarded, and a nil *http.Request was
// handed to the handler.
func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
	req, err := http.NewRequest(method, "/", nil)
	if err != nil {
		panic(err)
	}
	if origin != "" {
		req.Header.Set("Origin", origin)
	}

	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	return w
}
  36. func TestConfigAddAllow(t *testing.T) {
  37. config := Config{}
  38. config.AddAllowMethods("POST")
  39. config.AddAllowMethods("GET", "PUT")
  40. config.AddExposeHeaders()
  41. config.AddAllowHeaders("Some", " cool")
  42. config.AddAllowHeaders("header")
  43. config.AddExposeHeaders()
  44. config.AddExposeHeaders()
  45. config.AddExposeHeaders("exposed", "header")
  46. config.AddExposeHeaders("hey")
  47. assert.Equal(t, config.AllowMethods, []string{"POST", "GET", "PUT"})
  48. assert.Equal(t, config.AllowHeaders, []string{"Some", " cool", "header"})
  49. assert.Equal(t, config.ExposeHeaders, []string{"exposed", "header", "hey"})
  50. }
  51. func TestBadConfig(t *testing.T) {
  52. assert.Panics(t, func() { New(Config{}) })
  53. assert.Panics(t, func() {
  54. New(Config{
  55. AllowAllOrigins: true,
  56. AllowOrigins: []string{"http://google.com"},
  57. })
  58. })
  59. assert.Panics(t, func() {
  60. New(Config{
  61. AllowAllOrigins: true,
  62. AllowOriginFunc: func(origin string) bool { return false },
  63. })
  64. })
  65. assert.Panics(t, func() {
  66. New(Config{
  67. AllowOrigins: []string{"google.com"},
  68. })
  69. })
  70. }
  71. func TestNormalize(t *testing.T) {
  72. values := normalize([]string{
  73. "http-Access ", "Post", "POST", " poSt ",
  74. "HTTP-Access", "",
  75. })
  76. assert.Equal(t, values, []string{"http-access", "post", ""})
  77. values = normalize(nil)
  78. assert.Nil(t, values)
  79. values = normalize([]string{})
  80. assert.Equal(t, values, []string{})
  81. }
  82. func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) {
  83. header := generateNormalHeaders(Config{
  84. AllowAllOrigins: false,
  85. })
  86. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
  87. assert.Equal(t, header.Get("Vary"), "Origin")
  88. assert.Len(t, header, 1)
  89. header = generateNormalHeaders(Config{
  90. AllowAllOrigins: true,
  91. })
  92. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
  93. assert.Equal(t, header.Get("Vary"), "")
  94. assert.Len(t, header, 1)
  95. }
  96. func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) {
  97. header := generateNormalHeaders(Config{
  98. AllowCredentials: true,
  99. })
  100. assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
  101. assert.Equal(t, header.Get("Vary"), "Origin")
  102. assert.Len(t, header, 2)
  103. }
  104. func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) {
  105. header := generateNormalHeaders(Config{
  106. ExposeHeaders: []string{"X-user", "xPassword"},
  107. })
  108. assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "x-user,xpassword")
  109. assert.Equal(t, header.Get("Vary"), "Origin")
  110. assert.Len(t, header, 2)
  111. }
  112. func TestGeneratePreflightHeaders(t *testing.T) {
  113. header := generatePreflightHeaders(Config{
  114. AllowAllOrigins: false,
  115. })
  116. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
  117. assert.Equal(t, header.Get("Vary"), "Origin")
  118. assert.Len(t, header, 1)
  119. header = generateNormalHeaders(Config{
  120. AllowAllOrigins: true,
  121. })
  122. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
  123. assert.Equal(t, header.Get("Vary"), "")
  124. assert.Len(t, header, 1)
  125. }
  126. func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) {
  127. header := generatePreflightHeaders(Config{
  128. AllowCredentials: true,
  129. })
  130. assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
  131. assert.Equal(t, header.Get("Vary"), "Origin")
  132. assert.Len(t, header, 2)
  133. }
  134. func TestGeneratePreflightHeaders_AllowedMethods(t *testing.T) {
  135. header := generatePreflightHeaders(Config{
  136. AllowMethods: []string{"GET ", "post", "PUT", " put "},
  137. })
  138. assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "get,post,put")
  139. assert.Equal(t, header.Get("Vary"), "Origin")
  140. assert.Len(t, header, 2)
  141. }
  142. func TestGeneratePreflightHeaders_AllowedHeaders(t *testing.T) {
  143. header := generatePreflightHeaders(Config{
  144. AllowHeaders: []string{"X-user", "Content-Type"},
  145. })
  146. assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "x-user,content-type")
  147. assert.Equal(t, header.Get("Vary"), "Origin")
  148. assert.Len(t, header, 2)
  149. }
  150. func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
  151. header := generatePreflightHeaders(Config{
  152. MaxAge: 12 * time.Hour,
  153. })
  154. assert.Equal(t, header.Get("Access-Control-Max-Age"), "43200") // 12*60*60
  155. assert.Equal(t, header.Get("Vary"), "Origin")
  156. assert.Len(t, header, 2)
  157. }
  158. func TestValidateOrigin(t *testing.T) {
  159. cors := newCors(Config{
  160. AllowAllOrigins: true,
  161. })
  162. assert.True(t, cors.validateOrigin("http://google.com"))
  163. assert.True(t, cors.validateOrigin("https://google.com"))
  164. assert.True(t, cors.validateOrigin("example.com"))
  165. cors = newCors(Config{
  166. AllowOrigins: []string{"https://google.com", "https://github.com"},
  167. AllowOriginFunc: func(origin string) bool {
  168. return (origin == "http://news.ycombinator.com")
  169. },
  170. })
  171. assert.False(t, cors.validateOrigin("http://google.com"))
  172. assert.True(t, cors.validateOrigin("https://google.com"))
  173. assert.True(t, cors.validateOrigin("https://github.com"))
  174. assert.True(t, cors.validateOrigin("http://news.ycombinator.com"))
  175. assert.False(t, cors.validateOrigin("http://example.com"))
  176. assert.False(t, cors.validateOrigin("google.com"))
  177. }
  178. func TestPassesAllowedOrigins(t *testing.T) {
  179. router := newTestRouter(Config{
  180. AllowOrigins: []string{"http://google.com"},
  181. AllowMethods: []string{" GeT ", "get", "post", "PUT ", "Head", "POST"},
  182. AllowHeaders: []string{"Content-type", "timeStamp "},
  183. ExposeHeaders: []string{"Data", "x-User"},
  184. AllowCredentials: true,
  185. MaxAge: 12 * time.Hour,
  186. AllowOriginFunc: func(origin string) bool {
  187. return origin == "http://github.com"
  188. },
  189. })
  190. // no CORS request, origin == ""
  191. w := performRequest(router, "GET", "")
  192. assert.Equal(t, w.Body.String(), "get")
  193. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  194. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  195. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  196. // allowed CORS request
  197. w = performRequest(router, "GET", "http://google.com")
  198. assert.Equal(t, w.Body.String(), "get")
  199. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://google.com")
  200. assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
  201. assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data,x-user")
  202. // deny CORS request
  203. w = performRequest(router, "GET", "https://google.com")
  204. assert.Equal(t, w.Code, 403)
  205. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  206. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  207. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  208. // allowed CORS prefligh request
  209. w = performRequest(router, "OPTIONS", "http://github.com")
  210. assert.Equal(t, w.Code, 200)
  211. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://github.com")
  212. assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
  213. assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "get,post,put,head")
  214. assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type,timestamp")
  215. assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "43200")
  216. // deny CORS prefligh request
  217. w = performRequest(router, "OPTIONS", "http://example.com")
  218. assert.Equal(t, w.Code, 403)
  219. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  220. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  221. assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"))
  222. assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
  223. assert.Empty(t, w.Header().Get("Access-Control-Max-Age"))
  224. }
  225. func TestPassesAllowedAllOrigins(t *testing.T) {
  226. router := newTestRouter(Config{
  227. AllowAllOrigins: true,
  228. AllowMethods: []string{" Patch ", "get", "post", "POST"},
  229. AllowHeaders: []string{"Content-type", " testheader "},
  230. ExposeHeaders: []string{"Data2", "x-User2"},
  231. AllowCredentials: false,
  232. MaxAge: 10 * time.Hour,
  233. })
  234. // no CORS request, origin == ""
  235. w := performRequest(router, "GET", "")
  236. assert.Equal(t, w.Body.String(), "get")
  237. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  238. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  239. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  240. // allowed CORS request
  241. w = performRequest(router, "POST", "example.com")
  242. assert.Equal(t, w.Body.String(), "post")
  243. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*")
  244. assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data2,x-user2")
  245. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  246. // allowed CORS prefligh request
  247. w = performRequest(router, "OPTIONS", "https://facebook.com")
  248. assert.Equal(t, w.Code, 200)
  249. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*")
  250. assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "patch,get,post")
  251. assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type,testheader")
  252. assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "36000")
  253. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  254. }