Browse Source

refactor(server): drop Serve code; rename cors object

* server/cors.go renamed to http/cors.go
* all CORS code removed from Server and PeerServer
* Server and PeerServer fulfill http.Handler, now passed to http.Serve
* non-HTTP code in PeerServer.Serve moved to PeerServer.Start
Brian Waldon 12 years ago
parent
commit
0abd860f7e
6 changed files with 50 additions and 57 deletions
  1. 8 4
      etcd.go
  2. 9 9
      http/cors.go
  3. 19 19
      server/peer_server.go
  4. 0 7
      server/peer_server_handlers.go
  5. 7 11
      server/server.go
  6. 7 7
      tests/server_utils.go

+ 8 - 4
etcd.go

@@ -26,6 +26,7 @@ import (
 
 	"github.com/coreos/raft"
 
+	ehttp "github.com/coreos/etcd/http"
 	"github.com/coreos/etcd/log"
 	"github.com/coreos/etcd/metrics"
 	"github.com/coreos/etcd/server"
@@ -102,7 +103,7 @@ func main() {
 	}
 
 	// Retrieve CORS configuration
-	corsInfo, err := server.NewCORSInfo(config.CorsOrigins)
+	corsInfo, err := ehttp.NewCORSInfo(config.CorsOrigins)
 	if err != nil {
 		log.Fatal("CORS:", err)
 	}
@@ -130,7 +131,6 @@ func main() {
 		SnapshotCount:    config.SnapshotCount,
 		MaxClusterSize:   config.MaxClusterSize,
 		RetryTimes:       config.MaxRetryAttempts,
-		CORS:             corsInfo,
 	}
 	ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
 
@@ -177,12 +177,16 @@ func main() {
 
 	ps.SetServer(s)
 
+	ps.Start(config.Snapshot, config.Peers)
+
 	// Run peer server in separate thread while the client server blocks.
 	go func() {
-		log.Fatal(ps.Serve(psListener, config.Snapshot, config.Peers))
+		log.Infof("raft server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL)
+		sHTTP := &ehttp.CORSHandler{ps, corsInfo}
+		log.Fatal(http.Serve(psListener, sHTTP))
 	}()
 
 	log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Config.Name, sListener.Addr(), s.Config.URL)
-	sHTTP := &server.CORSHTTPMiddleware{s, corsInfo}
+	sHTTP := &ehttp.CORSHandler{s, corsInfo}
 	log.Fatal(http.Serve(sListener, sHTTP))
 }

+ 9 - 9
server/cors.go → http/cors.go

@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
-package server
+package http
 
 import (
 	"fmt"
@@ -22,9 +22,9 @@ import (
 	"net/url"
 )
 
-type corsInfo map[string]bool
+type CORSInfo map[string]bool
 
-func NewCORSInfo(origins []string) (*corsInfo, error) {
+func NewCORSInfo(origins []string) (*CORSInfo, error) {
 	// Construct a lookup of all origins.
 	m := make(map[string]bool)
 	for _, v := range origins {
@@ -36,29 +36,29 @@ func NewCORSInfo(origins []string) (*corsInfo, error) {
 		m[v] = true
 	}
 
-	info := corsInfo(m)
+	info := CORSInfo(m)
 	return &info, nil
 }
 
 // OriginAllowed determines whether the server will allow a given CORS origin.
-func (c corsInfo) OriginAllowed(origin string) bool {
+func (c CORSInfo) OriginAllowed(origin string) bool {
 	return c["*"] || c[origin]
 }
 
-type CORSHTTPMiddleware struct {
+type CORSHandler struct {
 	Handler http.Handler
-	Info    *corsInfo
+	Info    *CORSInfo
 }
 
 // addHeader adds the correct cors headers given an origin
-func (h *CORSHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) {
+func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
 	w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
 	w.Header().Add("Access-Control-Allow-Origin", origin)
 }
 
 // ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
 // with a 200 OK if the method is OPTIONS.
-func (h *CORSHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	// Write CORS header.
 	if h.Info.OriginAllowed("*") {
 		h.addHeader(w, "*")

+ 19 - 19
server/peer_server.go

@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
-	"net"
 	"net/http"
 	"net/url"
 	"strconv"
@@ -35,11 +34,11 @@ type PeerServerConfig struct {
 	ElectionTimeout  time.Duration
 	MaxClusterSize   int
 	RetryTimes       int
-	CORS             *corsInfo
 }
 
 type PeerServer struct {
 	Config         PeerServerConfig
+	handler        http.Handler
 	raftServer     raft.Server
 	server         *Server
 	joinIndex      uint64
@@ -49,8 +48,6 @@ type PeerServer struct {
 	store          store.Store
 	snapConf       *snapshotConf
 
-	listener net.Listener
-
 	closeChan            chan bool
 	timeoutThresholdChan chan interface{}
 
@@ -82,6 +79,9 @@ func NewPeerServer(psConfig PeerServerConfig, registry *Registry, store store.St
 
 		metrics: mb,
 	}
+
+	s.handler = s.buildHTTPHandler()
+
 	return s
 }
 
@@ -107,7 +107,7 @@ func (s *PeerServer) SetRaftServer(raftServer raft.Server) {
 }
 
 // Start the raft server
-func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error {
+func (s *PeerServer) Start(snapshot bool, cluster []string) error {
 	// LoadSnapshot
 	if snapshot {
 		err := s.raftServer.LoadSnapshot()
@@ -157,8 +157,18 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
 		go s.monitorSnapshot()
 	}
 
+	return nil
+}
+
+func (s *PeerServer) Stop() {
+	if s.closeChan != nil {
+		close(s.closeChan)
+		s.closeChan = nil
+	}
+}
+
+func (s *PeerServer) buildHTTPHandler() http.Handler {
 	router := mux.NewRouter()
-	httpServer := &http.Server{Handler: router}
 
 	// internal commands
 	router.HandleFunc("/name", s.NameHttpHandler)
@@ -174,21 +184,11 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
 	router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
 	router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
 
-	s.listener = listener
-	log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL)
-	httpServer.Serve(listener)
-	return nil
+	return router
 }
 
-func (s *PeerServer) Close() {
-	if s.closeChan != nil {
-		close(s.closeChan)
-		s.closeChan = nil
-	}
-	if s.listener != nil {
-		s.listener.Close()
-		s.listener = nil
-	}
+func (s *PeerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	s.handler.ServeHTTP(w, r)
 }
 
 // Retrieves the underlying Raft server.

+ 0 - 7
server/peer_server_handlers.go

@@ -149,13 +149,6 @@ func (ps *PeerServer) EtcdURLHttpHandler(w http.ResponseWriter, req *http.Reques
 func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) {
 	command := &JoinCommand{}
 
-	// Write CORS header.
-	if ps.Config.CORS.OriginAllowed("*") {
-		w.Header().Add("Access-Control-Allow-Origin", "*")
-	} else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) {
-		w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin"))
-	}
-
 	err := decodeJsonRequest(req, command)
 	if err != nil {
 		w.WriteHeader(http.StatusInternalServerError)

+ 7 - 11
server/server.go

@@ -3,7 +3,6 @@ package server
 import (
 	"encoding/json"
 	"fmt"
-	"net"
 	"net/http"
 	"net/http/pprof"
 	"strings"
@@ -30,13 +29,12 @@ type ServerConfig struct {
 // This is the default implementation of the Server interface.
 type Server struct {
 	Config         ServerConfig
+	handler        http.Handler
 	peerServer     *PeerServer
 	registry       *Registry
 	store          store.Store
 	metrics        *metrics.Bucket
 
-	listener net.Listener
-
 	trace          bool
 }
 
@@ -50,6 +48,8 @@ func New(sConfig ServerConfig, peerServer *PeerServer, registry *Registry, store
 		metrics:     mb,
 	}
 
+	s.handler = s.buildHTTPHandler()
+
 	return s
 }
 
@@ -172,7 +172,7 @@ func (s *Server) handleFunc(r *mux.Router, path string, f func(http.ResponseWrit
 	})
 }
 
-func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func (s *Server) buildHTTPHandler() http.Handler {
 	router := mux.NewRouter()
 
 	// Install the routes.
@@ -185,15 +185,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		s.installDebug(router)
 	}
 
-	router.ServeHTTP(w, r)
+	return router
 }
 
-// Stops the server.
-func (s *Server) Close() {
-	if s.listener != nil {
-		s.listener.Close()
-		s.listener = nil
-	}
+func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	s.handler.ServeHTTP(w, r)
 }
 
 // Dispatch command to the current leader

+ 7 - 7
tests/server_utils.go

@@ -2,6 +2,7 @@ package tests
 
 import (
 	"io/ioutil"
+	"net/http"
 	"os"
 	"time"
 
@@ -27,7 +28,6 @@ func RunServer(f func(*server.Server)) {
 
 	store := store.New()
 	registry := server.NewRegistry(store)
-	corsInfo, _ := server.NewCORSInfo([]string{})
 
 	serverStats := server.NewRaftServerStats(testName)
 	followersStats := server.NewRaftFollowersStats(testName)
@@ -39,7 +39,6 @@ func RunServer(f func(*server.Server)) {
 		Scheme: "http",
 		SnapshotCount: testSnapshotCount,
 		MaxClusterSize: 9,
-		CORS: corsInfo,
 	}
 	ps := server.NewPeerServer(psConfig, registry, store, nil, followersStats, serverStats)
 	psListener, err := server.NewListener(testRaftURL)
@@ -63,7 +62,6 @@ func RunServer(f func(*server.Server)) {
 	sConfig := server.ServerConfig{
 		Name: testName,
 		URL: "http://"+testClientURL,
-		CORS: corsInfo,
 	}
 	s := server.New(sConfig, ps, registry, store, nil)
 	sListener, err := server.NewListener(testClientURL)
@@ -77,14 +75,15 @@ func RunServer(f func(*server.Server)) {
 	c := make(chan bool)
 	go func() {
 		c <- true
-		ps.Serve(psListener, false, []string{})
+		ps.Start(false, []string{})
+		http.Serve(psListener, ps)
 	}()
 	<-c
 
 	// Start up etcd server.
 	go func() {
 		c <- true
-		s.Serve(sListener)
+		http.Serve(sListener, s)
 	}()
 	<-c
 
@@ -95,6 +94,7 @@ func RunServer(f func(*server.Server)) {
 	f(s)
 
 	// Clean up servers.
-	ps.Close()
-	s.Close()
+	ps.Stop()
+	psListener.Close()
+	sListener.Close()
 }