cors_test.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/gin-gonic/gin"
  7. "github.com/stretchr/testify/assert"
  8. )
  9. func init() {
  10. gin.SetMode(gin.TestMode)
  11. }
  // performRequest builds a body-less request for the given method and path,
  // serves it through r, and returns the recorder that captured the response.
  //
  // It panics if the request cannot be constructed (for example when method is
  // not a valid HTTP method). Previously the error was discarded and a nil
  // *http.Request was handed to the handler, which hid the real cause.
  func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder {
  	req, err := http.NewRequest(method, path, nil)
  	if err != nil {
  		// Test helper: fail loudly with the underlying reason instead of
  		// letting the handler trip over a nil request.
  		panic("performRequest: " + err.Error())
  	}

  	w := httptest.NewRecorder()
  	r.ServeHTTP(w, req)
  	return w
  }
  18. func TestBadConfig(t *testing.T) {
  19. assert.Panics(t, func() { New(Config{}) })
  20. assert.Panics(t, func() {
  21. New(Config{
  22. AllowAllOrigins: true,
  23. AllowedOrigins: []string{"http://google.com"},
  24. })
  25. })
  26. assert.Panics(t, func() {
  27. New(Config{
  28. AllowAllOrigins: true,
  29. AllowOriginFunc: func(origin string) bool { return false },
  30. })
  31. })
  32. assert.Panics(t, func() {
  33. New(Config{
  34. AllowedOrigins: []string{"http://google.com"},
  35. AllowOriginFunc: func(origin string) bool { return false },
  36. })
  37. })
  38. assert.Panics(t, func() {
  39. New(Config{
  40. AllowedOrigins: []string{"google.com"},
  41. })
  42. })
  43. }
  44. func TestNormalize(t *testing.T) {
  45. values := normalize([]string{
  46. "http-access ", "post", "POST", " poSt ",
  47. "HTTP-Access", "",
  48. })
  49. assert.Equal(t, values, []string{"Http-Access", "Post", ""})
  50. values = normalize(nil)
  51. assert.Nil(t, values)
  52. values = normalize([]string{})
  53. assert.Equal(t, values, []string{})
  54. }
  55. func TestGenerateNormalHeaders(t *testing.T) {
  56. header := generateNormalHeaders(Config{
  57. AllowAllOrigins: false,
  58. })
  59. assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "")
  60. assert.Contains(t, header.Get("Vary"), "Origin")
  61. header = generateNormalHeaders(Config{
  62. AllowAllOrigins: true,
  63. })
  64. assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "*")
  65. assert.Contains(t, header.Get("Vary"), "")
  66. header = generateNormalHeaders(Config{
  67. AllowCredentials: true,
  68. })
  69. assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "true")
  70. header = generateNormalHeaders(Config{
  71. AllowCredentials: false,
  72. })
  73. assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "")
  74. header = generateNormalHeaders(Config{
  75. ExposedHeaders: []string{"x-user", "xpassword"},
  76. })
  77. assert.Contains(t, header.Get("Access-Control-Expose-Headers"), "x-user, xpassword")
  78. }
  79. //
  80. // func TestDeny0(t *testing.T) {
  81. // called := false
  82. //
  83. // router := gin.New()
  84. // router.Use(New(Config{
  85. // AllowedOrigins: []string{"http://example.com"},
  86. // }))
  87. // router.GET("/", func(c *gin.Context) {
  88. // called = true
  89. // })
  90. // w := httptest.NewRecorder()
  91. // req, _ := http.NewRequest("GET", "/", nil)
  92. // req.Header.Set("Origin", "https://example.com")
  93. // router.ServeHTTP(w, req)
  94. //
  95. // assert.True(t, called)
  96. // assert.NotContains(t, w.Header(), "Access-Control")
  97. // }
  98. //
  99. // func TestDenyAbortOnError(t *testing.T) {
  100. // called := false
  101. //
  102. // router := gin.New()
  103. // router.Use(New(Config{
  104. // AbortOnError: true,
  105. // AllowedOrigins: []string{"http://example.com"},
  106. // }))
  107. // router.GET("/", func(c *gin.Context) {
  108. // called = true
  109. // })
  110. //
  111. // w := httptest.NewRecorder()
  112. // req, _ := http.NewRequest("GET", "/", nil)
  113. // req.Header.Set("Origin", "https://example.com")
  114. // router.ServeHTTP(w, req)
  115. //
  116. // assert.False(t, called)
  117. // assert.NotContains(t, w.Header(), "Access-Control")
  118. // }
  119. //
  120. // func TestDeny2(t *testing.T) {
  121. //
  122. // }
  123. // func TestDeny3(t *testing.T) {
  124. //
  125. // }
  126. //
  127. // func TestPasses0(t *testing.T) {
  128. //
  129. // }
  130. //
  131. // func TestPasses1(t *testing.T) {
  132. //
  133. // }
  134. //
  135. // func TestPasses2(t *testing.T) {
  136. //
  137. // }