Browse Source

refactor(cors): Break apart CORS data and middleware

Brian Waldon 12 years ago
parent
commit
a93d60be90
6 changed files with 43 additions and 43 deletions
  1. 8 3
      etcd.go
  2. 17 17
      server/cors.go
  3. 1 0
      server/peer_server.go
  4. 2 2
      server/peer_server_handlers.go
  5. 12 21
      server/server.go
  6. 3 0
      tests/server_utils.go

+ 8 - 3
etcd.go

@@ -98,6 +98,12 @@ func main() {
 		}
 		}
 	}
 	}
 
 
+	// Retrieve CORS configuration
+	corsInfo, err := server.NewCORSInfo(config.CorsOrigins)
+	if err != nil {
+		log.Fatal("CORS:", err)
+	}
+
 	// Create etcd key-value store and registry.
 	// Create etcd key-value store and registry.
 	store := store.New()
 	store := store.New()
 	registry := server.NewRegistry(store)
 	registry := server.NewRegistry(store)
@@ -113,6 +119,7 @@ func main() {
 		ElectionTimeout: time.Duration(config.Peer.ElectionTimeout) * time.Millisecond,
 		ElectionTimeout: time.Duration(config.Peer.ElectionTimeout) * time.Millisecond,
 		MaxClusterSize: config.MaxClusterSize,
 		MaxClusterSize: config.MaxClusterSize,
 		RetryTimes: config.MaxRetryAttempts,
 		RetryTimes: config.MaxRetryAttempts,
+		CORS: corsInfo,
 	}
 	}
 	ps := server.NewPeerServer(psConfig, &peerTLSConfig, &info.RaftTLS, registry, store, &mb)
 	ps := server.NewPeerServer(psConfig, &peerTLSConfig, &info.RaftTLS, registry, store, &mb)
 
 
@@ -121,11 +128,9 @@ func main() {
 		Name: info.Name,
 		Name: info.Name,
 		URL: info.EtcdURL,
 		URL: info.EtcdURL,
 		BindAddr: info.EtcdListenHost,
 		BindAddr: info.EtcdListenHost,
+		CORS: corsInfo,
 	}
 	}
 	s := server.New(sConfig, &tlsConfig, &info.EtcdTLS, ps, registry, store, &mb)
 	s := server.New(sConfig, &tlsConfig, &info.EtcdTLS, ps, registry, store, &mb)
-	if err := s.AllowOrigins(config.CorsOrigins); err != nil {
-		panic(err)
-	}
 
 
 	if config.Trace() {
 	if config.Trace() {
 		s.EnableTracing()
 		s.EnableTracing()

+ 17 - 17
server/cors.go

@@ -20,50 +20,50 @@ import (
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
-
-	"github.com/gorilla/mux"
 )
 )
 
 
-type corsHandler struct {
-	router      *mux.Router
-	corsOrigins map[string]bool
+type corsInfo struct {
+	origins map[string]bool
 }
 }
 
 
-// AllowOrigins sets a comma-delimited list of origins that are allowed.
-func (s *corsHandler) AllowOrigins(origins []string) error {
+func NewCORSInfo(origins []string) (*corsInfo, error) {
 	// Construct a lookup of all origins.
 	// Construct a lookup of all origins.
 	m := make(map[string]bool)
 	m := make(map[string]bool)
 	for _, v := range origins {
 	for _, v := range origins {
 		if v != "*" {
 		if v != "*" {
 			if _, err := url.Parse(v); err != nil {
 			if _, err := url.Parse(v); err != nil {
-				return fmt.Errorf("Invalid CORS origin: %s", err)
+				return nil, fmt.Errorf("Invalid CORS origin: %s", err)
 			}
 			}
 		}
 		}
 		m[v] = true
 		m[v] = true
 	}
 	}
-	s.corsOrigins = m
 
 
-	return nil
+	return &corsInfo{m}, nil
 }
 }
 
 
 // OriginAllowed determines whether the server will allow a given CORS origin.
 // OriginAllowed determines whether the server will allow a given CORS origin.
-func (c *corsHandler) OriginAllowed(origin string) bool {
-	return c.corsOrigins["*"] || c.corsOrigins[origin]
+func (c *corsInfo) OriginAllowed(origin string) bool {
+	return c.origins["*"] || c.origins[origin]
+}
+
+type corsHTTPMiddleware struct {
+	next   http.Handler
+	info   *corsInfo
 }
 }
 
 
 // addHeader adds the correct cors headers given an origin
 // addHeader adds the correct cors headers given an origin
-func (h *corsHandler) addHeader(w http.ResponseWriter, origin string) {
+func (h *corsHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) {
 	w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
 	w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
 	w.Header().Add("Access-Control-Allow-Origin", origin)
 	w.Header().Add("Access-Control-Allow-Origin", origin)
 }
 }
 
 
 // ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
 // ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
 // with a 200 OK if the method is OPTIONS.
 // with a 200 OK if the method is OPTIONS.
-func (h *corsHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+func (h *corsHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	// Write CORS header.
 	// Write CORS header.
-	if h.OriginAllowed("*") {
+	if h.info.OriginAllowed("*") {
 		h.addHeader(w, "*")
 		h.addHeader(w, "*")
-	} else if origin := req.Header.Get("Origin"); h.OriginAllowed(origin) {
+	} else if origin := req.Header.Get("Origin"); h.info.OriginAllowed(origin) {
 		h.addHeader(w, origin)
 		h.addHeader(w, origin)
 	}
 	}
 
 
@@ -72,5 +72,5 @@ func (h *corsHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 		return
 		return
 	}
 	}
 
 
-	h.router.ServeHTTP(w, req)
+	h.next.ServeHTTP(w, req)
 }
 }

+ 1 - 0
server/peer_server.go

@@ -35,6 +35,7 @@ type PeerServerConfig struct {
 	ElectionTimeout  time.Duration
 	ElectionTimeout  time.Duration
 	MaxClusterSize   int
 	MaxClusterSize   int
 	RetryTimes       int
 	RetryTimes       int
+	CORS             *corsInfo
 }
 }
 
 
 type PeerServer struct {
 type PeerServer struct {

+ 2 - 2
server/peer_server_handlers.go

@@ -150,9 +150,9 @@ func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request)
 	command := &JoinCommand{}
 	command := &JoinCommand{}
 
 
 	// Write CORS header.
 	// Write CORS header.
-	if ps.server.OriginAllowed("*") {
+	if ps.Config.CORS.OriginAllowed("*") {
 		w.Header().Add("Access-Control-Allow-Origin", "*")
 		w.Header().Add("Access-Control-Allow-Origin", "*")
-	} else if ps.server.OriginAllowed(req.Header.Get("Origin")) {
+	} else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) {
 		w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin"))
 		w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin"))
 	}
 	}
 
 

+ 12 - 21
server/server.go

@@ -26,27 +26,28 @@ type ServerConfig struct {
 	Name     string
 	Name     string
 	URL      string
 	URL      string
 	BindAddr string
 	BindAddr string
+	CORS     *corsInfo
 }
 }
 
 
 // This is the default implementation of the Server interface.
 // This is the default implementation of the Server interface.
 type Server struct {
 type Server struct {
 	http.Server
 	http.Server
-	Config      ServerConfig
-	peerServer  *PeerServer
-	registry    *Registry
-	listener    net.Listener
-	store       store.Store
-	tlsConf     *TLSConfig
-	tlsInfo     *TLSInfo
-	router      *mux.Router
-	corsHandler *corsHandler
+	Config         ServerConfig
+	peerServer     *PeerServer
+	registry       *Registry
+	listener       net.Listener
+	store          store.Store
+	tlsConf        *TLSConfig
+	tlsInfo        *TLSInfo
+	router         *mux.Router
+	corsMiddleware *corsHTTPMiddleware
 	metrics     *metrics.Bucket
 	metrics     *metrics.Bucket
 }
 }
 
 
 // Creates a new Server.
 // Creates a new Server.
 func New(sConfig ServerConfig, tlsConf *TLSConfig, tlsInfo *TLSInfo, peerServer *PeerServer, registry *Registry, store store.Store, mb *metrics.Bucket) *Server {
 func New(sConfig ServerConfig, tlsConf *TLSConfig, tlsInfo *TLSInfo, peerServer *PeerServer, registry *Registry, store store.Store, mb *metrics.Bucket) *Server {
 	r := mux.NewRouter()
 	r := mux.NewRouter()
-	cors := &corsHandler{router: r}
+	cors := &corsHTTPMiddleware{r, sConfig.CORS}
 
 
 	s := &Server{
 	s := &Server{
 		Config: sConfig,
 		Config: sConfig,
@@ -61,7 +62,7 @@ func New(sConfig ServerConfig, tlsConf *TLSConfig, tlsInfo *TLSInfo, peerServer
 		tlsInfo:     tlsInfo,
 		tlsInfo:     tlsInfo,
 		peerServer:  peerServer,
 		peerServer:  peerServer,
 		router:      r,
 		router:      r,
-		corsHandler: cors,
+		corsMiddleware: cors,
 		metrics:     mb,
 		metrics:     mb,
 	}
 	}
 
 
@@ -326,16 +327,6 @@ func (s *Server) Dispatch(c raft.Command, w http.ResponseWriter, req *http.Reque
 	}
 	}
 }
 }
 
 
-// OriginAllowed determines whether the server will allow a given CORS origin.
-func (s *Server) OriginAllowed(origin string) bool {
-	return s.corsHandler.OriginAllowed(origin)
-}
-
-// AllowOrigins sets a comma-delimited list of origins that are allowed.
-func (s *Server) AllowOrigins(origins []string) error {
-	return s.corsHandler.AllowOrigins(origins)
-}
-
 // Handler to return the current version of etcd.
 // Handler to return the current version of etcd.
 func (s *Server) GetVersionHandler(w http.ResponseWriter, req *http.Request) error {
 func (s *Server) GetVersionHandler(w http.ResponseWriter, req *http.Request) error {
 	w.WriteHeader(http.StatusOK)
 	w.WriteHeader(http.StatusOK)

+ 3 - 0
tests/server_utils.go

@@ -25,6 +25,7 @@ func RunServer(f func(*server.Server)) {
 
 
 	store := store.New()
 	store := store.New()
 	registry := server.NewRegistry(store)
 	registry := server.NewRegistry(store)
+	corsInfo, _ := server.NewCORSInfo([]string{})
 
 
 	psConfig := server.PeerServerConfig{
 	psConfig := server.PeerServerConfig{
 		Name: testName,
 		Name: testName,
@@ -35,6 +36,7 @@ func RunServer(f func(*server.Server)) {
 		HeartbeatTimeout: testHeartbeatTimeout,
 		HeartbeatTimeout: testHeartbeatTimeout,
 		ElectionTimeout: testElectionTimeout,
 		ElectionTimeout: testElectionTimeout,
 		MaxClusterSize: 9,
 		MaxClusterSize: 9,
+		CORS: corsInfo,
 	}
 	}
 	ps := server.NewPeerServer(psConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, registry, store, nil)
 	ps := server.NewPeerServer(psConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, registry, store, nil)
 
 
@@ -42,6 +44,7 @@ func RunServer(f func(*server.Server)) {
 		Name: testName,
 		Name: testName,
 		URL: "http://"+testClientURL,
 		URL: "http://"+testClientURL,
 		BindAddr: testClientURL,
 		BindAddr: testClientURL,
+		CORS: corsInfo,
 	}
 	}
 	s := server.New(sConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, ps, registry, store, nil)
 	s := server.New(sConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, ps, registry, store, nil)