cors.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. // Copyright 2015 The etcd Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package cors handles cross-origin HTTP requests (CORS).
  15. package cors
  16. import (
  17. "fmt"
  18. "net/http"
  19. "net/url"
  20. "sort"
  21. "strings"
  22. )
  23. type CORSInfo map[string]bool
  24. // Set implements the flag.Value interface to allow users to define a list of CORS origins
  25. func (ci *CORSInfo) Set(s string) error {
  26. m := make(map[string]bool)
  27. for _, v := range strings.Split(s, ",") {
  28. v = strings.TrimSpace(v)
  29. if v == "" {
  30. continue
  31. }
  32. if v != "*" {
  33. if _, err := url.Parse(v); err != nil {
  34. return fmt.Errorf("Invalid CORS origin: %s", err)
  35. }
  36. }
  37. m[v] = true
  38. }
  39. *ci = CORSInfo(m)
  40. return nil
  41. }
  42. func (ci *CORSInfo) String() string {
  43. o := make([]string, 0)
  44. for k := range *ci {
  45. o = append(o, k)
  46. }
  47. sort.StringSlice(o).Sort()
  48. return strings.Join(o, ",")
  49. }
  50. // OriginAllowed determines whether the server will allow a given CORS origin.
  51. func (c CORSInfo) OriginAllowed(origin string) bool {
  52. return c["*"] || c[origin]
  53. }
  54. type CORSHandler struct {
  55. Handler http.Handler
  56. Info *CORSInfo
  57. }
  58. // addHeader adds the correct cors headers given an origin
  59. func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
  60. w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
  61. w.Header().Add("Access-Control-Allow-Origin", origin)
  62. w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization")
  63. }
  64. // ServeHTTP adds the correct CORS headers based on the origin and returns immediately
  65. // with a 200 OK if the method is OPTIONS.
  66. func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  67. // Write CORS header.
  68. if h.Info.OriginAllowed("*") {
  69. h.addHeader(w, "*")
  70. } else if origin := req.Header.Get("Origin"); h.Info.OriginAllowed(origin) {
  71. h.addHeader(w, origin)
  72. }
  73. if req.Method == "OPTIONS" {
  74. w.WriteHeader(http.StatusOK)
  75. return
  76. }
  77. h.Handler.ServeHTTP(w, req)
  78. }