Browse Source

embed: support "CORS" handler in v3 HTTP requests

Signed-off-by: Gyuho Lee <gyuhox@gmail.com>
Gyuho Lee 7 years ago
parent
commit
9ea8be0c2b
2 changed files with 66 additions and 13 deletions
  1. 9 4
      embed/etcd.go
  2. 57 9
      embed/serve.go

+ 9 - 4
embed/etcd.go

@@ -23,6 +23,7 @@ import (
 	"net"
 	"net/http"
 	"net/url"
+	"sort"
 	"strconv"
 	"sync"
 	"time"
@@ -33,7 +34,6 @@ import (
 	"github.com/coreos/etcd/etcdserver/api/v2v3"
 	"github.com/coreos/etcd/etcdserver/api/v3client"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc"
-	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/debugutil"
 	runtimeutil "github.com/coreos/etcd/pkg/runtime"
 	"github.com/coreos/etcd/pkg/transport"
@@ -168,6 +168,7 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
 		StrictReconfigCheck:     cfg.StrictReconfigCheck,
 		ClientCertAuthEnabled:   cfg.ClientTLSInfo.ClientCertAuth,
 		AuthToken:               cfg.AuthToken,
+		CORS:                    cfg.CORS,
 		HostWhitelist:           cfg.HostWhitelist,
 		InitialCorruptCheck:     cfg.ExperimentalInitialCorruptCheck,
 		CorruptCheckTime:        cfg.ExperimentalCorruptCheckTime,
@@ -473,8 +474,13 @@ func (e *Etcd) serveClients() (err error) {
 		plog.Infof("ClientTLS: %s", e.cfg.ClientTLSInfo)
 	}
 
-	if e.cfg.CorsInfo.String() != "" {
-		plog.Infof("cors = %s", e.cfg.CorsInfo)
+	if len(e.cfg.CORS) > 0 {
+		ss := make([]string, 0, len(e.cfg.CORS))
+		for v := range e.cfg.CORS {
+			ss = append(ss, v)
+		}
+		sort.Strings(ss)
+		plog.Infof("cors = %q", ss)
 	}
 
 	// Start a client server goroutine for each listen address
@@ -491,7 +497,6 @@ func (e *Etcd) serveClients() (err error) {
 		etcdhttp.HandleBasic(mux, e.Server)
 		h = mux
 	}
-	h = http.Handler(&cors.CORSHandler{Handler: h, Info: e.cfg.CorsInfo})
 
 	gopts := []grpc.ServerOption{}
 	if e.cfg.GRPCKeepAliveMinTime > time.Duration(0) {

+ 57 - 9
embed/serve.go

@@ -116,7 +116,7 @@ func (sctx *serveCtx) serve(
 		httpmux := sctx.createMux(gwmux, handler)
 
 		srvhttp := &http.Server{
-			Handler:  wrapMux(s, httpmux),
+			Handler:  createAccessController(s, httpmux),
 			ErrorLog: logger, // do not log user error
 		}
 		httpl := m.Match(cmux.HTTP1())
@@ -159,7 +159,7 @@ func (sctx *serveCtx) serve(
 		httpmux := sctx.createMux(gwmux, handler)
 
 		srv := &http.Server{
-			Handler:   wrapMux(s, httpmux),
+			Handler:   createAccessController(s, httpmux),
 			TLSConfig: tlscfg,
 			ErrorLog:  logger, // do not log user error
 		}
@@ -250,20 +250,20 @@ func (sctx *serveCtx) createMux(gwmux *gw.ServeMux, handler http.Handler) *http.
 	return httpmux
 }
 
-// wrapMux wraps HTTP multiplexer:
+// createAccessController wraps HTTP multiplexer:
 // - mutate gRPC gateway request paths
 // - check hostname whitelist
 // client HTTP requests goes here first
-func wrapMux(s *etcdserver.EtcdServer, mux *http.ServeMux) http.Handler {
-	return &httpWrapper{s: s, mux: mux}
+func createAccessController(s *etcdserver.EtcdServer, mux *http.ServeMux) http.Handler {
+	return &accessController{s: s, mux: mux}
 }
 
-type httpWrapper struct {
+type accessController struct {
 	s   *etcdserver.EtcdServer
 	mux *http.ServeMux
 }
 
-func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+func (ac *accessController) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 	// redirect for backward compatibilities
 	if req != nil && req.URL != nil && strings.HasPrefix(req.URL.Path, "/v3beta/") {
 		req.URL.Path = strings.Replace(req.URL.Path, "/v3beta/", "/v3/", 1)
@@ -271,7 +271,7 @@ func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 
 	if req.TLS == nil { // check origin if client connection is not secure
 		host := httputil.GetHostname(req)
-		if !m.s.IsHostWhitelisted(host) {
+		if !ac.s.AccessController.IsHostWhitelisted(host) {
 			plog.Warningf("rejecting HTTP request from %q to prevent DNS rebinding attacks", host)
 			// TODO: use Go's "http.StatusMisdirectedRequest" (421)
 			// https://github.com/golang/go/commit/4b8a7eafef039af1834ef9bfa879257c4a72b7b5
@@ -280,7 +280,26 @@ func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 		}
 	}
 
-	m.mux.ServeHTTP(rw, req)
+	// Write CORS header.
+	if ac.s.AccessController.OriginAllowed("*") {
+		addCORSHeader(rw, "*")
+	} else if origin := req.Header.Get("Origin"); ac.s.OriginAllowed(origin) {
+		addCORSHeader(rw, origin)
+	}
+
+	if req.Method == "OPTIONS" {
+		rw.WriteHeader(http.StatusOK)
+		return
+	}
+
+	ac.mux.ServeHTTP(rw, req)
+}
+
+// addCORSHeader adds the correct cors headers given an origin
+func addCORSHeader(w http.ResponseWriter, origin string) {
+	w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
+	w.Header().Add("Access-Control-Allow-Origin", origin)
+	w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization")
 }
 
 // https://github.com/transmission/transmission/pull/468
@@ -297,6 +316,35 @@ This requirement has been added to help prevent "DNS Rebinding" attacks (CVE-201
 `, host)
 }
 
+// WrapCORS wraps existing handler with CORS.
+// TODO: deprecate this after v2 proxy deprecate
+func WrapCORS(cors map[string]struct{}, h http.Handler) http.Handler {
+	return &corsHandler{
+		ac: &etcdserver.AccessController{CORS: cors},
+		h:  h,
+	}
+}
+
+type corsHandler struct {
+	ac *etcdserver.AccessController
+	h  http.Handler
+}
+
+func (ch *corsHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	if ch.ac.OriginAllowed("*") {
+		addCORSHeader(rw, "*")
+	} else if origin := req.Header.Get("Origin"); ch.ac.OriginAllowed(origin) {
+		addCORSHeader(rw, origin)
+	}
+
+	if req.Method == "OPTIONS" {
+		rw.WriteHeader(http.StatusOK)
+		return
+	}
+
+	ch.h.ServeHTTP(rw, req)
+}
+
 func (sctx *serveCtx) registerUserHandler(s string, h http.Handler) {
 	if sctx.userHandlers[s] != nil {
 		plog.Warningf("path %s already registered by user handler", s)