cors_test.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. "gopkg.in/gin-gonic/gin.v1"
  10. )
  11. func init() {
  12. gin.SetMode(gin.TestMode)
  13. }
  14. func newTestRouter(config Config) *gin.Engine {
  15. router := gin.New()
  16. router.Use(New(config))
  17. router.GET("/", func(c *gin.Context) {
  18. c.String(200, "get")
  19. })
  20. router.POST("/", func(c *gin.Context) {
  21. c.String(200, "post")
  22. })
  23. router.PATCH("/", func(c *gin.Context) {
  24. c.String(200, "patch")
  25. })
  26. return router
  27. }
  // performRequest sends a body-less request for "/" with the given method to r
  // and returns the recorder that captured the response. The Origin header is
  // set only when origin is non-empty, so passing "" simulates a non-CORS
  // request.
  //
  // It panics with a descriptive message when the request cannot be built (for
  // example, an invalid method) instead of dereferencing a nil request later.
  func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
  	req, err := http.NewRequest(method, "/", nil)
  	if err != nil {
  		panic("performRequest: building " + method + " request: " + err.Error())
  	}
  	if origin != "" {
  		req.Header.Set("Origin", origin)
  	}

  	rec := httptest.NewRecorder()
  	r.ServeHTTP(rec, req)
  	return rec
  }
  37. func TestConfigAddAllow(t *testing.T) {
  38. config := Config{}
  39. config.AddAllowMethods("POST")
  40. config.AddAllowMethods("GET", "PUT")
  41. config.AddExposeHeaders()
  42. config.AddAllowHeaders("Some", " cool")
  43. config.AddAllowHeaders("header")
  44. config.AddExposeHeaders()
  45. config.AddExposeHeaders()
  46. config.AddExposeHeaders("exposed", "header")
  47. config.AddExposeHeaders("hey")
  48. assert.Equal(t, config.AllowMethods, []string{"POST", "GET", "PUT"})
  49. assert.Equal(t, config.AllowHeaders, []string{"Some", " cool", "header"})
  50. assert.Equal(t, config.ExposeHeaders, []string{"exposed", "header", "hey"})
  51. }
  52. func TestBadConfig(t *testing.T) {
  53. assert.Panics(t, func() { New(Config{}) })
  54. assert.Panics(t, func() {
  55. New(Config{
  56. AllowAllOrigins: true,
  57. AllowOrigins: []string{"http://google.com"},
  58. })
  59. })
  60. assert.Panics(t, func() {
  61. New(Config{
  62. AllowAllOrigins: true,
  63. AllowOriginFunc: func(origin string) bool { return false },
  64. })
  65. })
  66. assert.Panics(t, func() {
  67. New(Config{
  68. AllowOrigins: []string{"google.com"},
  69. })
  70. })
  71. }
  72. func TestNormalize(t *testing.T) {
  73. values := normalize([]string{
  74. "http-Access ", "Post", "POST", " poSt ",
  75. "HTTP-Access", "",
  76. })
  77. assert.Equal(t, values, []string{"http-access", "post", ""})
  78. values = normalize(nil)
  79. assert.Nil(t, values)
  80. values = normalize([]string{})
  81. assert.Equal(t, values, []string{})
  82. }
  83. func TestConvert(t *testing.T) {
  84. methods := []string{"Get", "GET", "get"}
  85. headers := []string{"X-CSRF-TOKEN", "X-CSRF-Token", "x-csrf-token"}
  86. assert.Equal(t, []string{"GET", "GET", "GET"}, convert(methods, strings.ToUpper))
  87. assert.Equal(t, []string{"X-Csrf-Token", "X-Csrf-Token", "X-Csrf-Token"}, convert(headers, http.CanonicalHeaderKey))
  88. }
  89. func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) {
  90. header := generateNormalHeaders(Config{
  91. AllowAllOrigins: false,
  92. })
  93. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
  94. assert.Equal(t, header.Get("Vary"), "Origin")
  95. assert.Len(t, header, 1)
  96. header = generateNormalHeaders(Config{
  97. AllowAllOrigins: true,
  98. })
  99. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
  100. assert.Equal(t, header.Get("Vary"), "")
  101. assert.Len(t, header, 1)
  102. }
  103. func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) {
  104. header := generateNormalHeaders(Config{
  105. AllowCredentials: true,
  106. })
  107. assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
  108. assert.Equal(t, header.Get("Vary"), "Origin")
  109. assert.Len(t, header, 2)
  110. }
  111. func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) {
  112. header := generateNormalHeaders(Config{
  113. ExposeHeaders: []string{"X-user", "xPassword"},
  114. })
  115. assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "X-User,Xpassword")
  116. assert.Equal(t, header.Get("Vary"), "Origin")
  117. assert.Len(t, header, 2)
  118. }
  119. func TestGeneratePreflightHeaders(t *testing.T) {
  120. header := generatePreflightHeaders(Config{
  121. AllowAllOrigins: false,
  122. })
  123. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
  124. assert.Equal(t, header.Get("Vary"), "Origin")
  125. assert.Len(t, header, 1)
  126. header = generateNormalHeaders(Config{
  127. AllowAllOrigins: true,
  128. })
  129. assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
  130. assert.Equal(t, header.Get("Vary"), "")
  131. assert.Len(t, header, 1)
  132. }
  133. func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) {
  134. header := generatePreflightHeaders(Config{
  135. AllowCredentials: true,
  136. })
  137. assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
  138. assert.Equal(t, header.Get("Vary"), "Origin")
  139. assert.Len(t, header, 2)
  140. }
  141. func TestGeneratePreflightHeaders_AllowedMethods(t *testing.T) {
  142. header := generatePreflightHeaders(Config{
  143. AllowMethods: []string{"GET ", "post", "PUT", " put "},
  144. })
  145. assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "GET,POST,PUT")
  146. assert.Equal(t, header.Get("Vary"), "Origin")
  147. assert.Len(t, header, 2)
  148. }
  149. func TestGeneratePreflightHeaders_AllowedHeaders(t *testing.T) {
  150. header := generatePreflightHeaders(Config{
  151. AllowHeaders: []string{"X-user", "Content-Type"},
  152. })
  153. assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "X-User,Content-Type")
  154. assert.Equal(t, header.Get("Vary"), "Origin")
  155. assert.Len(t, header, 2)
  156. }
  157. func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
  158. header := generatePreflightHeaders(Config{
  159. MaxAge: 12 * time.Hour,
  160. })
  161. assert.Equal(t, header.Get("Access-Control-Max-Age"), "43200") // 12*60*60
  162. assert.Equal(t, header.Get("Vary"), "Origin")
  163. assert.Len(t, header, 2)
  164. }
  165. func TestValidateOrigin(t *testing.T) {
  166. cors := newCors(Config{
  167. AllowAllOrigins: true,
  168. })
  169. assert.True(t, cors.validateOrigin("http://google.com"))
  170. assert.True(t, cors.validateOrigin("https://google.com"))
  171. assert.True(t, cors.validateOrigin("example.com"))
  172. cors = newCors(Config{
  173. AllowOrigins: []string{"https://google.com", "https://github.com"},
  174. AllowOriginFunc: func(origin string) bool {
  175. return (origin == "http://news.ycombinator.com")
  176. },
  177. })
  178. assert.False(t, cors.validateOrigin("http://google.com"))
  179. assert.True(t, cors.validateOrigin("https://google.com"))
  180. assert.True(t, cors.validateOrigin("https://github.com"))
  181. assert.True(t, cors.validateOrigin("http://news.ycombinator.com"))
  182. assert.False(t, cors.validateOrigin("http://example.com"))
  183. assert.False(t, cors.validateOrigin("google.com"))
  184. }
  185. func TestPassesAllowedOrigins(t *testing.T) {
  186. router := newTestRouter(Config{
  187. AllowOrigins: []string{"http://google.com"},
  188. AllowMethods: []string{" GeT ", "get", "post", "PUT ", "Head", "POST"},
  189. AllowHeaders: []string{"Content-type", "timeStamp "},
  190. ExposeHeaders: []string{"Data", "x-User"},
  191. AllowCredentials: true,
  192. MaxAge: 12 * time.Hour,
  193. AllowOriginFunc: func(origin string) bool {
  194. return origin == "http://github.com"
  195. },
  196. })
  197. // no CORS request, origin == ""
  198. w := performRequest(router, "GET", "")
  199. assert.Equal(t, w.Body.String(), "get")
  200. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  201. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  202. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  203. // allowed CORS request
  204. w = performRequest(router, "GET", "http://google.com")
  205. assert.Equal(t, w.Body.String(), "get")
  206. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://google.com")
  207. assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
  208. assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "Data,X-User")
  209. // deny CORS request
  210. w = performRequest(router, "GET", "https://google.com")
  211. assert.Equal(t, w.Code, 403)
  212. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  213. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  214. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  215. // allowed CORS prefligh request
  216. w = performRequest(router, "OPTIONS", "http://github.com")
  217. assert.Equal(t, w.Code, 200)
  218. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://github.com")
  219. assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
  220. assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "GET,POST,PUT,HEAD")
  221. assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "Content-Type,Timestamp")
  222. assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "43200")
  223. // deny CORS prefligh request
  224. w = performRequest(router, "OPTIONS", "http://example.com")
  225. assert.Equal(t, w.Code, 403)
  226. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  227. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  228. assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"))
  229. assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
  230. assert.Empty(t, w.Header().Get("Access-Control-Max-Age"))
  231. }
  232. func TestPassesAllowedAllOrigins(t *testing.T) {
  233. router := newTestRouter(Config{
  234. AllowAllOrigins: true,
  235. AllowMethods: []string{" Patch ", "get", "post", "POST"},
  236. AllowHeaders: []string{"Content-type", " testheader "},
  237. ExposeHeaders: []string{"Data2", "x-User2"},
  238. AllowCredentials: false,
  239. MaxAge: 10 * time.Hour,
  240. })
  241. // no CORS request, origin == ""
  242. w := performRequest(router, "GET", "")
  243. assert.Equal(t, w.Body.String(), "get")
  244. assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
  245. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  246. assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
  247. // allowed CORS request
  248. w = performRequest(router, "POST", "example.com")
  249. assert.Equal(t, w.Body.String(), "post")
  250. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*")
  251. assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "Data2,X-User2")
  252. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  253. // allowed CORS prefligh request
  254. w = performRequest(router, "OPTIONS", "https://facebook.com")
  255. assert.Equal(t, w.Code, 200)
  256. assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*")
  257. assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "PATCH,GET,POST")
  258. assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "Content-Type,Testheader")
  259. assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "36000")
  260. assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
  261. }