cors.go 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. // Copyright 2015 CoreOS, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package cors handles cross-origin HTTP requests (CORS).
  15. package cors
  import (
  	"fmt"
  	"net/http"
  	"net/url"
  	"sort"
  	"strings"
  )
  // CORSInfo is the set of origins allowed to make cross-origin requests; an
  // origin is allowed when it maps to true. The special entry "*" allows every
  // origin. *CORSInfo implements flag.Value through Set and String, so the set
  // can be filled from a comma-separated command-line flag.
  type CORSInfo map[string]bool

  // Set implements flag.Value. It parses s as a comma-separated list of
  // origins, trimming whitespace around each entry and skipping empty ones.
  // Every entry other than "*" must be accepted by url.Parse. url.Parse also
  // accepts relative references, so this is a loose syntax check rather than
  // full origin validation. On success the receiver is replaced by the parsed
  // set, and previous entries are discarded. On error the receiver is left
  // untouched.
  func (ci *CORSInfo) Set(s string) error {
  	parsed := make(map[string]bool)
  	for _, field := range strings.Split(s, ",") {
  		origin := strings.TrimSpace(field)
  		switch {
  		case origin == "":
  			// Tolerate stray commas such as "a,,b" or a trailing ",".
  			continue
  		case origin != "*":
  			// The wildcard is not a URL; everything else must parse as one.
  			if _, err := url.Parse(origin); err != nil {
  				return fmt.Errorf("Invalid CORS origin: %s", err)
  			}
  		}
  		parsed[origin] = true
  	}
  	*ci = CORSInfo(parsed)
  	return nil
  }
  41. func (ci *CORSInfo) String() string {
  42. o := make([]string, 0)
  43. for k := range *ci {
  44. o = append(o, k)
  45. }
  46. return strings.Join(o, ",")
  47. }
  48. // OriginAllowed determines whether the server will allow a given CORS origin.
  49. func (c CORSInfo) OriginAllowed(origin string) bool {
  50. return c["*"] || c[origin]
  51. }
  52. type CORSHandler struct {
  53. Handler http.Handler
  54. Info *CORSInfo
  55. }
  56. // addHeader adds the correct cors headers given an origin
  57. func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
  58. w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
  59. w.Header().Add("Access-Control-Allow-Origin", origin)
  60. w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization")
  61. }
  62. // ServeHTTP adds the correct CORS headers based on the origin and returns immediately
  63. // with a 200 OK if the method is OPTIONS.
  64. func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  65. // Write CORS header.
  66. if h.Info.OriginAllowed("*") {
  67. h.addHeader(w, "*")
  68. } else if origin := req.Header.Get("Origin"); h.Info.OriginAllowed(origin) {
  69. h.addHeader(w, origin)
  70. }
  71. if req.Method == "OPTIONS" {
  72. w.WriteHeader(http.StatusOK)
  73. return
  74. }
  75. h.Handler.ServeHTTP(w, req)
  76. }