config.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package cors
  2. import (
  3. "net/http"
  4. "gopkg.in/gin-gonic/gin.v1"
  5. )
  6. type cors struct {
  7. allowAllOrigins bool
  8. allowCredentials bool
  9. allowOriginFunc func(string) bool
  10. allowOrigins []string
  11. exposeHeaders []string
  12. normalHeaders http.Header
  13. preflightHeaders http.Header
  14. }
  15. func newCors(config Config) *cors {
  16. if err := config.Validate(); err != nil {
  17. panic(err.Error())
  18. }
  19. return &cors{
  20. allowOriginFunc: config.AllowOriginFunc,
  21. allowAllOrigins: config.AllowAllOrigins,
  22. allowCredentials: config.AllowCredentials,
  23. allowOrigins: normalize(config.AllowOrigins),
  24. normalHeaders: generateNormalHeaders(config),
  25. preflightHeaders: generatePreflightHeaders(config),
  26. }
  27. }
  28. func (cors *cors) applyCors(c *gin.Context) {
  29. origin := c.Request.Header.Get("Origin")
  30. if len(origin) == 0 {
  31. // request is not a CORS request
  32. return
  33. }
  34. if !cors.validateOrigin(origin) {
  35. c.AbortWithStatus(http.StatusForbidden)
  36. return
  37. }
  38. if c.Request.Method == "OPTIONS" {
  39. cors.handlePreflight(c)
  40. defer c.AbortWithStatus(200)
  41. } else {
  42. cors.handleNormal(c)
  43. }
  44. if !cors.allowAllOrigins && !cors.allowCredentials {
  45. c.Header("Access-Control-Allow-Origin", origin)
  46. }
  47. }
  48. func (cors *cors) validateOrigin(origin string) bool {
  49. if cors.allowAllOrigins {
  50. return true
  51. }
  52. for _, value := range cors.allowOrigins {
  53. if value == origin {
  54. return true
  55. }
  56. }
  57. if cors.allowOriginFunc != nil {
  58. return cors.allowOriginFunc(origin)
  59. }
  60. return false
  61. }
  62. func (cors *cors) handlePreflight(c *gin.Context) {
  63. header := c.Writer.Header()
  64. for key, value := range cors.preflightHeaders {
  65. header[key] = value
  66. }
  67. }
  68. func (cors *cors) handleNormal(c *gin.Context) {
  69. header := c.Writer.Header()
  70. for key, value := range cors.normalHeaders {
  71. header[key] = value
  72. }
  73. }