Browse Source

feat(server): introduce a cors handler

Introduce a handler that lives under the gorilla mux and adds the
correct headers based on the request and always returns 200 OK when
there OPTIONS is called on a URL.

This fixes the ability to DELETE from the dashboard on peer X when peer
Y is the leader. As a side effect it reveals some bugs in the dashboard
though notably:

- Due to the RTT immediatly refreshing the dashboard doesn't work and
  deleted keys are still there

- For some reason PUTS from peer X are creating directories and not
  keys.
Brandon Philips 12 years ago
parent
commit
a1ec895b91
2 changed files with 94 additions and 30 deletions
  1. 78 0
      server/cors_handler.go
  2. 16 30
      server/server.go

+ 78 - 0
server/cors_handler.go

@@ -0,0 +1,78 @@
+/*
+Copyright 2013 CoreOS Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package server
+
+import (
+	"fmt"
+	"net/http"
+	"net/url"
+
+	"github.com/gorilla/mux"
+)
+
+type corsHandler struct {
+	router *mux.Router
+	corsOrigins map[string]bool
+}
+
+// AllowOrigins sets a comma-delimited list of origins that are allowed.
+func (s *corsHandler) AllowOrigins(origins []string) error {
+	// Construct a lookup of all origins.
+	m := make(map[string]bool)
+	for _, v := range origins {
+		if v != "*" {
+			if _, err := url.Parse(v); err != nil {
+				return fmt.Errorf("Invalid CORS origin: %s", err)
+			}
+		}
+		m[v] = true
+	}
+	s.corsOrigins = m
+
+	return nil
+}
+
+// OriginAllowed determines whether the server will allow a given CORS origin.
+func (c *corsHandler) OriginAllowed(origin string) bool {
+	return c.corsOrigins["*"] || c.corsOrigins[origin]
+}
+
+// addHeader adds the correct cors headers given an origin
+func (h *corsHandler) addHeader(w http.ResponseWriter, origin string) {
+	w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
+	w.Header().Add("Access-Control-Allow-Origin", origin)
+}
+
+// ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
+// with a 200 OK if the method is OPTIONS.
+func (h *corsHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	// Write CORS header.
+	if h.OriginAllowed("*") {
+		h.addHeader(w, "*")
+	} else if origin := req.Header.Get("Origin"); h.OriginAllowed(origin) {
+		h.addHeader(w, origin)
+	}
+
+	if req.Method == "OPTIONS" {
+		w.WriteHeader(http.StatusOK)
+		return
+	}
+
+	h.router.ServeHTTP(w, req)
+}
+
+

+ 16 - 30
server/server.go

@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"net"
 	"net/http"
-	"net/url"
 	"strings"
 	"time"
 
@@ -32,14 +31,18 @@ type Server struct {
 	url         string
 	tlsConf     *TLSConfig
 	tlsInfo     *TLSInfo
-	corsOrigins map[string]bool
+	router      *mux.Router
+	corsHandler *corsHandler
 }
 
 // Creates a new Server.
 func New(name string, urlStr string, bindAddr string, tlsConf *TLSConfig, tlsInfo *TLSInfo, peerServer *PeerServer, registry *Registry, store store.Store) *Server {
+	r := mux.NewRouter()
+	cors := &corsHandler{router: r}
+
 	s := &Server{
 		Server: http.Server{
-			Handler:   mux.NewRouter(),
+			Handler:   cors,
 			TLSConfig: &tlsConf.Server,
 			Addr:      bindAddr,
 		},
@@ -50,6 +53,8 @@ func New(name string, urlStr string, bindAddr string, tlsConf *TLSConfig, tlsInf
 		tlsConf:    tlsConf,
 		tlsInfo:    tlsInfo,
 		peerServer: peerServer,
+		router:     r,
+		corsHandler: cors,
 	}
 
 	// Install the routes.
@@ -124,7 +129,7 @@ func (s *Server) installV2() {
 }
 
 func (s *Server) installMod() {
-	r := s.Handler.(*mux.Router)
+	r := s.router
 	r.PathPrefix("/mod").Handler(http.StripPrefix("/mod", mod.HttpHandler()))
 }
 
@@ -144,20 +149,13 @@ func (s *Server) handleFuncV2(path string, f func(http.ResponseWriter, *http.Req
 
 // Adds a server handler to the router.
 func (s *Server) handleFunc(path string, f func(http.ResponseWriter, *http.Request) error) *mux.Route {
-	r := s.Handler.(*mux.Router)
+	r := s.router
 
 	// Wrap the standard HandleFunc interface to pass in the server reference.
 	return r.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
 		// Log request.
 		log.Debugf("[recv] %s %s %s [%s]", req.Method, s.url, req.URL.Path, req.RemoteAddr)
 
-		// Write CORS header.
-		if s.OriginAllowed("*") {
-			w.Header().Add("Access-Control-Allow-Origin", "*")
-		} else if origin := req.Header.Get("Origin"); s.OriginAllowed(origin) {
-			w.Header().Add("Access-Control-Allow-Origin", origin)
-		}
-
 		// Execute handler function and return error if necessary.
 		if err := f(w, req); err != nil {
 			if etcdErr, ok := err.(*etcdErr.Error); ok {
@@ -302,26 +300,14 @@ func (s *Server) Dispatch(c raft.Command, w http.ResponseWriter, req *http.Reque
 	}
 }
 
-// Sets a comma-delimited list of origins that are allowed.
-func (s *Server) AllowOrigins(origins []string) error {
-	// Construct a lookup of all origins.
-	m := make(map[string]bool)
-	for _, v := range origins {
-		if v != "*" {
-			if _, err := url.Parse(v); err != nil {
-				return fmt.Errorf("Invalid CORS origin: %s", err)
-			}
-		}
-		m[v] = true
-	}
-	s.corsOrigins = m
-
-	return nil
+// OriginAllowed determines whether the server will allow a given CORS origin.
+func (s *Server) OriginAllowed(origin string) bool {
+	return s.corsHandler.OriginAllowed(origin)
 }
 
-// Determines whether the server will allow a given CORS origin.
-func (s *Server) OriginAllowed(origin string) bool {
-	return s.corsOrigins["*"] || s.corsOrigins[origin]
+// AllowOrigins sets a comma-delimited list of origins that are allowed.
+func (s *Server) AllowOrigins(origins []string) error {
+	return s.corsHandler.AllowOrigins(origins)
 }
 
 // Handler to return the current version of etcd.