Browse Source

feat(etcd_handlers): enable CORS

When developing or using web frontends for etcd it will be necessary to
enable Cross-Origin Resource Sharing. Add a flag that lets the user
enable this feature via a whitelist.
Brandon Philips 12 years ago
parent
commit
2f5015552e
2 changed files with 46 additions and 0 deletions
  1. 27 0
      etcd.go
  2. 19 0
      etcd_handlers.go

+ 27 - 0
etcd.go

@@ -3,9 +3,11 @@ package main
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
 	"flag"
 	"flag"
+	"fmt"
 	"github.com/coreos/etcd/store"
 	"github.com/coreos/etcd/store"
 	"github.com/coreos/go-raft"
 	"github.com/coreos/go-raft"
 	"io/ioutil"
 	"io/ioutil"
+	"net/url"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -40,6 +42,9 @@ var (
 	maxClusterSize int
 	maxClusterSize int
 
 
 	cpuprofile string
 	cpuprofile string
+
+	cors     string
+	corsList map[string]bool
 )
 )
 
 
 func init() {
 func init() {
@@ -77,6 +82,8 @@ func init() {
 	flag.IntVar(&maxClusterSize, "maxsize", 9, "the max size of the cluster")
 	flag.IntVar(&maxClusterSize, "maxsize", 9, "the max size of the cluster")
 
 
 	flag.StringVar(&cpuprofile, "cpuprofile", "", "write cpu profile to file")
 	flag.StringVar(&cpuprofile, "cpuprofile", "", "write cpu profile to file")
+
+	flag.StringVar(&cors, "cors", "", "whitelist origins for cross-origin resource sharing (e.g. '*' or 'http://localhost:8001,etc')")
 }
 }
 
 
 const (
 const (
@@ -152,6 +159,8 @@ func main() {
 		raft.SetLogLevel(raft.Debug)
 		raft.SetLogLevel(raft.Debug)
 	}
 	}
 
 
+	parseCorsFlag()
+
 	if machines != "" {
 	if machines != "" {
 		cluster = strings.Split(machines, ",")
 		cluster = strings.Split(machines, ",")
 	} else if machinesFile != "" {
 	} else if machinesFile != "" {
@@ -206,3 +215,21 @@ func main() {
 	e.ListenAndServe()
 	e.ListenAndServe()
 
 
 }
 }
+
+// parseCorsFlag gathers up the cors whitelist and puts it into the corsList.
+func parseCorsFlag() {
+	if cors != "" {
+		corsList = make(map[string]bool)
+		list := strings.Split(cors, ",")
+		for _, v := range list {
+			fmt.Println(v)
+			if v != "*" {
+				_, err := url.Parse(v)
+				if err != nil {
+					panic(fmt.Sprintf("bad cors url: %s", err))
+				}
+			}
+			corsList[v] = true
+		}
+	}
+}

+ 19 - 0
etcd_handlers.go

@@ -29,7 +29,26 @@ func NewEtcdMuxer() *http.ServeMux {
 
 
 type errorHandler func(http.ResponseWriter, *http.Request) error
 type errorHandler func(http.ResponseWriter, *http.Request) error
 
 
+// addCorsHeader parses the request Origin header and loops through the user
+// provided allowed origins and sets the Access-Control-Allow-Origin header if
+// there is a match.
+func addCorsHeader(w http.ResponseWriter, r *http.Request) {
+	val, ok := corsList["*"]
+	if val && ok {
+		w.Header().Add("Access-Control-Allow-Origin", "*")
+		return
+	}
+
+	requestOrigin := r.Header.Get("Origin")
+	val, ok = corsList[requestOrigin]
+	if val && ok {
+		w.Header().Add("Access-Control-Allow-Origin", requestOrigin)
+		return
+	}
+}
+
 func (fn errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 func (fn errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	addCorsHeader(w, r)
 	if e := fn(w, r); e != nil {
 	if e := fn(w, r); e != nil {
 		if etcdErr, ok := e.(etcdErr.Error); ok {
 		if etcdErr, ok := e.(etcdErr.Error); ok {
 			debug("Return error: ", etcdErr.Error())
 			debug("Return error: ", etcdErr.Error())