cors.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. // Copyright 2015 CoreOS, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package cors
  import (
  	"fmt"
  	"net/http"
  	"net/url"
  	"sort"
  	"strings"
  )
  // CORSInfo is the set of origins permitted to make cross-origin requests.
  // Keys are origin strings exactly as they were passed to Set; the special
  // key "*" allows every origin. *CORSInfo implements the flag.Value interface
  // so that users can define a list of CORS origins on the command line.
  type CORSInfo map[string]bool

  // Set replaces the contents of ci with the comma-separated origins in s.
  // Whitespace around each entry is trimmed and empty entries are skipped.
  // Every entry other than "*" must be accepted by url.Parse; on the first
  // entry that is not, an error is returned and ci is left unmodified.
  func (ci *CORSInfo) Set(s string) error {
  	m := make(map[string]bool)
  	for _, v := range strings.Split(s, ",") {
  		v = strings.TrimSpace(v)
  		if v == "" {
  			continue
  		}
  		if v != "*" {
  			if _, err := url.Parse(v); err != nil {
  				return fmt.Errorf("Invalid CORS origin: %s", err)
  			}
  		}
  		m[v] = true
  	}
  	// Assign only after every entry validated, so a failed Set never leaves
  	// a half-populated list behind.
  	*ci = CORSInfo(m)
  	return nil
  }

  // String returns every key of the map joined by commas. The keys are sorted
  // because map iteration order is randomized and the flag package prints this
  // value in usage output, which should be stable between calls. A nil
  // receiver yields the empty string instead of panicking.
  func (ci *CORSInfo) String() string {
  	if ci == nil {
  		return ""
  	}
  	o := make([]string, 0, len(*ci))
  	for k := range *ci {
  		o = append(o, k)
  	}
  	sort.Strings(o)
  	return strings.Join(o, ",")
  }

  // OriginAllowed determines whether the server will allow a given CORS origin.
  // It reports true when the wildcard "*" is enabled or when origin itself is
  // enabled in the set.
  func (ci CORSInfo) OriginAllowed(origin string) bool {
  	return ci["*"] || ci[origin]
  }
  51. type CORSHandler struct {
  52. Handler http.Handler
  53. Info *CORSInfo
  54. }
  55. // addHeader adds the correct cors headers given an origin
  56. func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
  57. w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
  58. w.Header().Add("Access-Control-Allow-Origin", origin)
  59. w.Header().Add("Access-Control-Allow-Headers", "accept, content-type")
  60. }
  61. // ServeHTTP adds the correct CORS headers based on the origin and returns immediately
  62. // with a 200 OK if the method is OPTIONS.
  63. func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  64. // Write CORS header.
  65. if h.Info.OriginAllowed("*") {
  66. h.addHeader(w, "*")
  67. } else if origin := req.Header.Get("Origin"); h.Info.OriginAllowed(origin) {
  68. h.addHeader(w, origin)
  69. }
  70. if req.Method == "OPTIONS" {
  71. w.WriteHeader(http.StatusOK)
  72. return
  73. }
  74. h.Handler.ServeHTTP(w, req)
  75. }