utils.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package cors
  2. import (
  3. "net/http"
  4. "strconv"
  5. "strings"
  6. "time"
  7. )
  8. type converter func(string) string
  9. func generateNormalHeaders(c Config) http.Header {
  10. headers := make(http.Header)
  11. if c.AllowCredentials {
  12. headers.Set("Access-Control-Allow-Credentials", "true")
  13. }
  14. if len(c.ExposeHeaders) > 0 {
  15. exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
  16. headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
  17. }
  18. if c.AllowAllOrigins {
  19. headers.Set("Access-Control-Allow-Origin", "*")
  20. } else {
  21. headers.Set("Vary", "Origin")
  22. }
  23. return headers
  24. }
  25. func generatePreflightHeaders(c Config) http.Header {
  26. headers := make(http.Header)
  27. if c.AllowCredentials {
  28. headers.Set("Access-Control-Allow-Credentials", "true")
  29. }
  30. if len(c.AllowMethods) > 0 {
  31. allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
  32. value := strings.Join(allowMethods, ",")
  33. headers.Set("Access-Control-Allow-Methods", value)
  34. }
  35. if len(c.AllowHeaders) > 0 {
  36. allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
  37. value := strings.Join(allowHeaders, ",")
  38. headers.Set("Access-Control-Allow-Headers", value)
  39. }
  40. if c.MaxAge > time.Duration(0) {
  41. value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
  42. headers.Set("Access-Control-Max-Age", value)
  43. }
  44. if c.AllowAllOrigins {
  45. headers.Set("Access-Control-Allow-Origin", "*")
  46. } else {
  47. // Always set Vary headers
  48. // see https://github.com/rs/cors/issues/10,
  49. // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
  50. headers.Add("Vary", "Origin")
  51. headers.Add("Vary", "Access-Control-Request-Method")
  52. headers.Add("Vary", "Access-Control-Request-Headers")
  53. }
  54. return headers
  55. }
  56. func normalize(values []string) []string {
  57. if values == nil {
  58. return nil
  59. }
  60. distinctMap := make(map[string]bool, len(values))
  61. normalized := make([]string, 0, len(values))
  62. for _, value := range values {
  63. value = strings.TrimSpace(value)
  64. value = strings.ToLower(value)
  65. if _, seen := distinctMap[value]; !seen {
  66. normalized = append(normalized, value)
  67. distinctMap[value] = true
  68. }
  69. }
  70. return normalized
  71. }
  72. func convert(s []string, c converter) []string {
  73. var out []string
  74. for _, i := range s {
  75. out = append(out, c(i))
  76. }
  77. return out
  78. }