Browse Source

Fixes CORS

Manu Mtz.-Almeida 10 năm trước cách đây
mục cha
commit
b08ec4416c
4 tập tin đã thay đổi với 271 bổ sung213 xóa
  1. 22 118
      config.go
  2. 5 13
      cors.go
  3. 175 82
      cors_test.go
  4. 69 0
      utils.go

+ 22 - 118
config.go

@@ -2,10 +2,6 @@ package cors
 
 import (
 	"net/http"
-	"net/textproto"
-	"strconv"
-	"strings"
-	"time"
 
 	"github.com/gin-gonic/gin"
 )
@@ -14,25 +10,21 @@ type cors struct {
 	allowAllOrigins   bool
 	allowedOriginFunc func(string) bool
 	allowedOrigins    []string
-	allowedMethods    []string
-	allowedHeaders    []string
 	exposedHeaders    []string
 	normalHeaders     http.Header
 	preflightHeaders  http.Header
 }
 
-func newCors(c Config) *cors {
-	if err := c.Validate(); err != nil {
+func newCors(config Config) *cors {
+	if err := config.Validate(); err != nil {
 		panic(err.Error())
 	}
 	return &cors{
-		allowedOriginFunc: c.AllowOriginFunc,
-		allowAllOrigins:   c.AllowAllOrigins,
-		allowedOrigins:    normalize(c.AllowedOrigins),
-		allowedMethods:    normalize(c.AllowedMethods),
-		allowedHeaders:    normalize(c.AllowedHeaders),
-		normalHeaders:     generateNormalHeaders(c),
-		preflightHeaders:  generatePreflightHeaders(c),
+		allowedOriginFunc: config.AllowOriginFunc,
+		allowAllOrigins:   config.AllowAllOrigins,
+		allowedOrigins:    normalize(config.AllowedOrigins),
+		normalHeaders:     generateNormalHeaders(config),
+		preflightHeaders:  generatePreflightHeaders(config),
 	}
 }
 
@@ -43,135 +35,47 @@ func (cors *cors) applyCors(c *gin.Context) {
 		return
 	}
 	if !cors.validateOrigin(origin) {
-		goto failed
+		c.AbortWithStatus(http.StatusForbidden)
+		return
 	}
 
 	if c.Request.Method == "OPTIONS" {
-		if !cors.handlePreflight(c) {
-			goto failed
-		}
-	} else if !cors.handleNormal(c) {
-		goto failed
-	}
-	if cors.allowAllOrigins {
-		c.Header("Access-Control-Allow-Origin", "*")
+		cors.handlePreflight(c)
 	} else {
-		c.Header("Access-Control-Allow-Origin", origin)
+		cors.handleNormal(c)
 	}
-	return
 
-failed:
-	c.AbortWithStatus(http.StatusForbidden)
+	if !cors.allowAllOrigins {
+		c.Header("Access-Control-Allow-Origin", origin)
+	}
 }
 
 func (cors *cors) validateOrigin(origin string) bool {
 	if cors.allowAllOrigins {
 		return true
 	}
-	if cors.allowedOriginFunc != nil {
-		return cors.allowedOriginFunc(origin)
-	}
 	for _, value := range cors.allowedOrigins {
 		if value == origin {
 			return true
 		}
 	}
-	return false
-}
-
-func (cors *cors) validateMethod(method string) bool {
-	for _, value := range cors.allowedMethods {
-		if strings.EqualFold(value, method) {
-			return true
-		}
-	}
-	return false
-}
-
-func (cors *cors) validateHeader(header string) bool {
-	for _, value := range cors.allowedHeaders {
-		if strings.EqualFold(value, header) {
-			return true
-		}
+	if cors.allowedOriginFunc != nil {
+		return cors.allowedOriginFunc(origin)
 	}
 	return false
 }
 
-func (cors *cors) handlePreflight(c *gin.Context) bool {
+func (cors *cors) handlePreflight(c *gin.Context) {
 	c.AbortWithStatus(200)
-	if !cors.validateMethod(c.Request.Header.Get("Access-Control-Request-Method")) {
-		return false
-	}
-	if !cors.validateHeader(c.Request.Header.Get("Access-Control-Request-Header")) {
-		return false
-	}
+	header := c.Writer.Header()
 	for key, value := range cors.preflightHeaders {
-		c.Writer.Header()[key] = value
+		header[key] = value
 	}
-	return true
 }
 
-func (cors *cors) handleNormal(c *gin.Context) bool {
+func (cors *cors) handleNormal(c *gin.Context) {
+	header := c.Writer.Header()
 	for key, value := range cors.normalHeaders {
-		c.Writer.Header()[key] = value
-	}
-	return true
-}
-
-func generateNormalHeaders(c Config) http.Header {
-	headers := make(http.Header)
-	if c.AllowCredentials {
-		headers.Set("Access-Control-Allow-Credentials", "true")
-	}
-	if len(c.ExposedHeaders) > 0 {
-		headers.Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ", "))
-	}
-	if c.AllowAllOrigins {
-		headers.Set("Access-Control-Allow-Origin", "*")
-	} else {
-		headers.Set("Vary", "Origin")
-	}
-	return headers
-}
-
-func generatePreflightHeaders(c Config) http.Header {
-	headers := make(http.Header)
-	if c.AllowCredentials {
-		headers.Set("Access-Control-Allow-Credentials", "true")
-	}
-	if len(c.AllowedMethods) > 0 {
-		value := strings.Join(c.AllowedMethods, ", ")
-		headers.Set("Access-Control-Allow-Methods", value)
-	}
-	if len(c.AllowedHeaders) > 0 {
-		value := strings.Join(c.AllowedHeaders, ", ")
-		headers.Set("Access-Control-Allow-Headers", value)
-	}
-	if c.MaxAge > time.Duration(0) {
-		value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
-		headers.Set("Access-Control-Max-Age", value)
-	}
-	if c.AllowAllOrigins {
-		headers.Set("Access-Control-Allow-Origin", "*")
-	} else {
-		headers.Set("Vary", "Origin")
-	}
-	return headers
-}
-
-func normalize(values []string) []string {
-	if values == nil {
-		return nil
-	}
-	distinctMap := make(map[string]bool, len(values))
-	normalized := make([]string, 0, len(values))
-	for _, value := range values {
-		value = strings.TrimSpace(value)
-		value = textproto.CanonicalMIMEHeaderKey(value)
-		if _, seen := distinctMap[value]; !seen {
-			normalized = append(normalized, value)
-			distinctMap[value] = true
-		}
+		header[key] = value
 	}
-	return normalized
 }

+ 5 - 13
cors.go

@@ -9,7 +9,6 @@ import (
 )
 
 type Config struct {
-	AbortOnError    bool
 	AllowAllOrigins bool
 
 	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
@@ -64,9 +63,6 @@ func (c Config) Validate() error {
 	if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowedOrigins) == 0 {
 		return errors.New("conflict settings: all origins disabled")
 	}
-	if c.AllowOriginFunc != nil && len(c.AllowedOrigins) > 0 {
-		return errors.New("conflict settings: if a allow origin func is provided, AllowedOrigins is not needed")
-	}
 	for _, origin := range c.AllowedOrigins {
 		if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
 			return errors.New("bad origin: origins must include http:// or https://")
@@ -77,11 +73,9 @@ func (c Config) Validate() error {
 
 func DefaultConfig() Config {
 	return Config{
-		AbortOnError:    false,
-		AllowAllOrigins: true,
-		AllowedMethods:  []string{"GET", "POST", "PUT", "PATCH", "HEAD"},
-		AllowedHeaders:  []string{"Content-Type"},
-		//ExposedHeaders:   "",
+		AllowAllOrigins:  true,
+		AllowedMethods:   []string{"GET", "POST", "PUT", "HEAD"},
+		AllowedHeaders:   []string{"Content-Type"},
 		AllowCredentials: false,
 		MaxAge:           12 * time.Hour,
 	}
@@ -92,10 +86,8 @@ func Default() gin.HandlerFunc {
 }
 
 func New(config Config) gin.HandlerFunc {
-	s := newCors(config)
-
-	// Algorithm based in http://www.html5rocks.com/static/images/cors_server_flowchart.png
+	cors := newCors(config)
 	return func(c *gin.Context) {
-		s.applyCors(c)
+		cors.applyCors(c)
 	}
 }

+ 175 - 82
cors_test.go

@@ -4,6 +4,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"testing"
+	"time"
 
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
@@ -34,12 +35,6 @@ func TestBadConfig(t *testing.T) {
 			AllowOriginFunc: func(origin string) bool { return false },
 		})
 	})
-	assert.Panics(t, func() {
-		New(Config{
-			AllowedOrigins:  []string{"http://google.com"},
-			AllowOriginFunc: func(origin string) bool { return false },
-		})
-	})
 	assert.Panics(t, func() {
 		New(Config{
 			AllowedOrigins: []string{"google.com"},
@@ -49,10 +44,10 @@ func TestBadConfig(t *testing.T) {
 
 func TestNormalize(t *testing.T) {
 	values := normalize([]string{
-		"http-access ", "post", "POST", " poSt  ",
+		"http-Access ", "Post", "POST", " poSt  ",
 		"HTTP-Access", "",
 	})
-	assert.Equal(t, values, []string{"Http-Access", "Post", ""})
+	assert.Equal(t, values, []string{"http-access", "post", ""})
 
 	values = normalize(nil)
 	assert.Nil(t, values)
@@ -61,91 +56,189 @@ func TestNormalize(t *testing.T) {
 	assert.Equal(t, values, []string{})
 }
 
-func TestGenerateNormalHeaders(t *testing.T) {
+func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) {
 	header := generateNormalHeaders(Config{
 		AllowAllOrigins: false,
 	})
-	assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "")
-	assert.Contains(t, header.Get("Vary"), "Origin")
+	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 1)
 
 	header = generateNormalHeaders(Config{
 		AllowAllOrigins: true,
 	})
-	assert.Contains(t, header.Get("Access-Control-Allow-Origin"), "*")
-	assert.Contains(t, header.Get("Vary"), "")
+	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
+	assert.Equal(t, header.Get("Vary"), "")
+	assert.Len(t, header, 1)
+}
 
-	header = generateNormalHeaders(Config{
+func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) {
+	header := generateNormalHeaders(Config{
 		AllowCredentials: true,
 	})
-	assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "true")
+	assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 2)
+}
 
-	header = generateNormalHeaders(Config{
-		AllowCredentials: false,
+func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) {
+	header := generateNormalHeaders(Config{
+		ExposedHeaders: []string{"X-user", "xPassword"},
 	})
-	assert.Contains(t, header.Get("Access-Control-Allow-Credentials"), "")
+	assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "x-user, xpassword")
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 2)
+}
+
+func TestGeneratePreflightHeaders(t *testing.T) {
+	header := generatePreflightHeaders(Config{
+		AllowAllOrigins: false,
+	})
+	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 1)
 
 	header = generateNormalHeaders(Config{
-		ExposedHeaders: []string{"x-user", "xpassword"},
-	})
-	assert.Contains(t, header.Get("Access-Control-Expose-Headers"), "x-user, xpassword")
-}
-
-//
-// func TestDeny0(t *testing.T) {
-// 	called := false
-//
-// 	router := gin.New()
-// 	router.Use(New(Config{
-// 		AllowedOrigins: []string{"http://example.com"},
-// 	}))
-// 	router.GET("/", func(c *gin.Context) {
-// 		called = true
-// 	})
-// 	w := httptest.NewRecorder()
-// 	req, _ := http.NewRequest("GET", "/", nil)
-// 	req.Header.Set("Origin", "https://example.com")
-// 	router.ServeHTTP(w, req)
-//
-// 	assert.True(t, called)
-// 	assert.NotContains(t, w.Header(), "Access-Control")
-// }
-//
-// func TestDenyAbortOnError(t *testing.T) {
-// 	called := false
-//
-// 	router := gin.New()
-// 	router.Use(New(Config{
-// 		AbortOnError:   true,
-// 		AllowedOrigins: []string{"http://example.com"},
-// 	}))
-// 	router.GET("/", func(c *gin.Context) {
-// 		called = true
-// 	})
-//
-// 	w := httptest.NewRecorder()
-// 	req, _ := http.NewRequest("GET", "/", nil)
-// 	req.Header.Set("Origin", "https://example.com")
-// 	router.ServeHTTP(w, req)
-//
-// 	assert.False(t, called)
-// 	assert.NotContains(t, w.Header(), "Access-Control")
-// }
-//
-// func TestDeny2(t *testing.T) {
-//
-// }
-// func TestDeny3(t *testing.T) {
-//
-// }
-//
-// func TestPasses0(t *testing.T) {
-//
-// }
-//
-// func TestPasses1(t *testing.T) {
-//
-// }
-//
-// func TestPasses2(t *testing.T) {
-//
-// }
+		AllowAllOrigins: true,
+	})
+	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
+	assert.Equal(t, header.Get("Vary"), "")
+	assert.Len(t, header, 1)
+}
+
+func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) {
+	header := generatePreflightHeaders(Config{
+		AllowCredentials: true,
+	})
+	assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 2)
+}
+
+func TestGeneratePreflightHeaders_AllowedMethods(t *testing.T) {
+	header := generatePreflightHeaders(Config{
+		AllowedMethods: []string{"GET ", "post", "PUT", " put  "},
+	})
+	assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "get, post, put")
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 2)
+}
+
+func TestGeneratePreflightHeaders_AllowedHeaders(t *testing.T) {
+	header := generatePreflightHeaders(Config{
+		AllowedHeaders: []string{"X-user", "Content-Type"},
+	})
+	assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "x-user, content-type")
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 2)
+}
+
+func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
+	header := generatePreflightHeaders(Config{
+		MaxAge: 12 * time.Hour,
+	})
+	assert.Equal(t, header.Get("Access-Control-Max-Age"), "43200") // 12*60*60
+	assert.Equal(t, header.Get("Vary"), "Origin")
+	assert.Len(t, header, 2)
+}
+
+func TestValidateOrigin(t *testing.T) {
+	cors := newCors(Config{
+		AllowAllOrigins: true,
+	})
+	assert.True(t, cors.validateOrigin("http://google.com"))
+	assert.True(t, cors.validateOrigin("https://google.com"))
+	assert.True(t, cors.validateOrigin("example.com"))
+
+	cors = newCors(Config{
+		AllowedOrigins: []string{"https://google.com", "https://github.com"},
+		AllowOriginFunc: func(origin string) bool {
+			return (origin == "http://news.ycombinator.com")
+		},
+	})
+	assert.False(t, cors.validateOrigin("http://google.com"))
+	assert.True(t, cors.validateOrigin("https://google.com"))
+	assert.True(t, cors.validateOrigin("https://github.com"))
+	assert.True(t, cors.validateOrigin("http://news.ycombinator.com"))
+	assert.False(t, cors.validateOrigin("http://example.com"))
+	assert.False(t, cors.validateOrigin("google.com"))
+}
+
+func TestPasses0(t *testing.T) {
+	called := false
+	router := gin.New()
+	router.Use(New(Config{
+		AllowedOrigins:   []string{"http://google.com"},
+		AllowedMethods:   []string{" GeT ", "get", "post", "PUT  ", "Head", "POST"},
+		AllowedHeaders:   []string{"Content-type", "timeStamp "},
+		ExposedHeaders:   []string{"Data", "x-User"},
+		AllowCredentials: true,
+		MaxAge:           12 * time.Hour,
+		AllowOriginFunc: func(origin string) bool {
+			return origin == "http://github.com"
+		},
+	}))
+	router.GET("/", func(c *gin.Context) {
+		called = true
+	})
+
+	w := httptest.NewRecorder()
+	req, _ := http.NewRequest("GET", "/", nil)
+	router.ServeHTTP(w, req)
+	assert.True(t, called)
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin")
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials")
+	assert.NotContains(t, w.Header(), "Access-Control-Expose-Headers")
+
+	called = false
+	w = httptest.NewRecorder()
+	req, _ = http.NewRequest("GET", "/", nil)
+	req.Header.Set("Origin", "http://google.com")
+	router.ServeHTTP(w, req)
+	assert.True(t, called)
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://google.com")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
+	assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data, x-user")
+
+	called = false
+	w = httptest.NewRecorder()
+	req, _ = http.NewRequest("GET", "/", nil)
+	req.Header.Set("Origin", "https://google.com")
+	router.ServeHTTP(w, req)
+	assert.False(t, called)
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin")
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials")
+	assert.NotContains(t, w.Header(), "Access-Control-Expose-Headers")
+
+	called = false
+	w = httptest.NewRecorder()
+	req, _ = http.NewRequest("OPTIONS", "/", nil)
+	req.Header.Set("Origin", "http://github.com")
+	router.ServeHTTP(w, req)
+	assert.False(t, called)
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://github.com")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "get, post, put, head")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type, timestamp")
+	assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "43200")
+
+	called = false
+	w = httptest.NewRecorder()
+	req, _ = http.NewRequest("OPTIONS", "/", nil)
+	req.Header.Set("Origin", "http://example.com")
+	router.ServeHTTP(w, req)
+	assert.False(t, called)
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Origin")
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Credentials")
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Methods")
+	assert.NotContains(t, w.Header(), "Access-Control-Allow-Headers")
+	assert.NotContains(t, w.Header(), "Access-Control-Max-Age")
+}
+
+func TestPasses1(t *testing.T) {
+
+}
+
+func TestPasses2(t *testing.T) {
+
+}

+ 69 - 0
utils.go

@@ -0,0 +1,69 @@
+package cors
+
+import (
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+)
+
+func generateNormalHeaders(c Config) http.Header {
+	headers := make(http.Header)
+	if c.AllowCredentials {
+		headers.Set("Access-Control-Allow-Credentials", "true")
+	}
+	if len(c.ExposedHeaders) > 0 {
+		exposedHeaders := normalize(c.ExposedHeaders)
+		headers.Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", "))
+	}
+	if c.AllowAllOrigins {
+		headers.Set("Access-Control-Allow-Origin", "*")
+	} else {
+		headers.Set("Vary", "Origin")
+	}
+	return headers
+}
+
+func generatePreflightHeaders(c Config) http.Header {
+	headers := make(http.Header)
+	if c.AllowCredentials {
+		headers.Set("Access-Control-Allow-Credentials", "true")
+	}
+	if len(c.AllowedMethods) > 0 {
+		allowedMethods := normalize(c.AllowedMethods)
+		value := strings.Join(allowedMethods, ", ")
+		headers.Set("Access-Control-Allow-Methods", value)
+	}
+	if len(c.AllowedHeaders) > 0 {
+		allowedHeaders := normalize(c.AllowedHeaders)
+		value := strings.Join(allowedHeaders, ", ")
+		headers.Set("Access-Control-Allow-Headers", value)
+	}
+	if c.MaxAge > time.Duration(0) {
+		value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
+		headers.Set("Access-Control-Max-Age", value)
+	}
+	if c.AllowAllOrigins {
+		headers.Set("Access-Control-Allow-Origin", "*")
+	} else {
+		headers.Set("Vary", "Origin")
+	}
+	return headers
+}
+
+func normalize(values []string) []string {
+	if values == nil {
+		return nil
+	}
+	distinctMap := make(map[string]bool, len(values))
+	normalized := make([]string, 0, len(values))
+	for _, value := range values {
+		value = strings.TrimSpace(value)
+		value = strings.ToLower(value)
+		if _, seen := distinctMap[value]; !seen {
+			normalized = append(normalized, value)
+			distinctMap[value] = true
+		}
+	}
+	return normalized
+}