Explorar o código

Renames some options

Manu Mtz.-Almeida %!s(int64=10) %!d(string=hai) anos
pai
achega
ee70b8845a
Modificáronse 4 ficheiros con 66 adicións e 65 borrados
  1. 14 14
      config.go
  2. 21 20
      cors.go
  3. 22 22
      cors_test.go
  4. 9 9
      utils.go

+ 14 - 14
config.go

@@ -7,12 +7,12 @@ import (
 )
 
 type cors struct {
-	allowAllOrigins   bool
-	allowedOriginFunc func(string) bool
-	allowedOrigins    []string
-	exposedHeaders    []string
-	normalHeaders     http.Header
-	preflightHeaders  http.Header
+	allowAllOrigins  bool
+	allowOriginFunc  func(string) bool
+	allowOrigins     []string
+	exposeHeaders    []string
+	normalHeaders    http.Header
+	preflightHeaders http.Header
 }
 
 func newCors(config Config) *cors {
@@ -20,11 +20,11 @@ func newCors(config Config) *cors {
 		panic(err.Error())
 	}
 	return &cors{
-		allowedOriginFunc: config.AllowOriginFunc,
-		allowAllOrigins:   config.AllowAllOrigins,
-		allowedOrigins:    normalize(config.AllowedOrigins),
-		normalHeaders:     generateNormalHeaders(config),
-		preflightHeaders:  generatePreflightHeaders(config),
+		allowOriginFunc:  config.AllowOriginFunc,
+		allowAllOrigins:  config.AllowAllOrigins,
+		allowOrigins:     normalize(config.AllowOrigins),
+		normalHeaders:    generateNormalHeaders(config),
+		preflightHeaders: generatePreflightHeaders(config),
 	}
 }
 
@@ -54,13 +54,13 @@ func (cors *cors) validateOrigin(origin string) bool {
 	if cors.allowAllOrigins {
 		return true
 	}
-	for _, value := range cors.allowedOrigins {
+	for _, value := range cors.allowOrigins {
 		if value == origin {
 			return true
 		}
 	}
-	if cors.allowedOriginFunc != nil {
-		return cors.allowedOriginFunc(origin)
+	if cors.allowOriginFunc != nil {
+		return cors.allowOriginFunc(origin)
 	}
 	return false
 }

+ 21 - 20
cors.go

@@ -14,7 +14,7 @@ type Config struct {
 	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
 	// If the special "*" value is present in the list, all origins will be allowed.
 	// Default value is ["*"]
-	AllowedOrigins []string
+	AllowOrigins []string
 
 	// AllowOriginFunc is a custom function to validate the origin. It take the origin
 	// as argument and returns true if allowed or false otherwise. If this option is
@@ -23,47 +23,47 @@ type Config struct {
 
 	// AllowedMethods is a list of methods the client is allowed to use with
 	// cross-domain requests. Default value is simple methods (GET and POST)
-	AllowedMethods []string
+	AllowMethods []string
 
 	// AllowedHeaders is list of non simple headers the client is allowed to use with
 	// cross-domain requests.
 	// If the special "*" value is present in the list, all headers will be allowed.
 	// Default value is [] but "Origin" is always appended to the list.
-	AllowedHeaders []string
-
-	// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
-	// API specification
-	ExposedHeaders []string
+	AllowHeaders []string
 
 	// AllowCredentials indicates whether the request can include user credentials like
 	// cookies, HTTP authentication or client side SSL certificates.
 	AllowCredentials bool
 
+	// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
+	// API specification
+	ExposeHeaders []string
+
 	// MaxAge indicates how long (in seconds) the results of a preflight request
 	// can be cached
 	MaxAge time.Duration
 }
 
-func (c *Config) AddAllowedMethods(methods ...string) {
-	c.AllowedMethods = append(c.AllowedMethods, methods...)
+func (c *Config) AddAllowMethods(methods ...string) {
+	c.AllowMethods = append(c.AllowMethods, methods...)
 }
 
-func (c *Config) AddAllowedHeaders(headers ...string) {
-	c.AllowedHeaders = append(c.AllowedHeaders, headers...)
+func (c *Config) AddAllowHeaders(headers ...string) {
+	c.AllowHeaders = append(c.AllowHeaders, headers...)
 }
 
-func (c *Config) AddExposedHeaders(headers ...string) {
-	c.ExposedHeaders = append(c.ExposedHeaders, headers...)
+func (c *Config) AddExposeHeaders(headers ...string) {
+	c.ExposeHeaders = append(c.ExposeHeaders, headers...)
 }
 
 func (c Config) Validate() error {
-	if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowedOrigins) > 0) {
+	if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
 		return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
 	}
-	if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowedOrigins) == 0 {
+	if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
 		return errors.New("conflict settings: all origins disabled")
 	}
-	for _, origin := range c.AllowedOrigins {
+	for _, origin := range c.AllowOrigins {
 		if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
 			return errors.New("bad origin: origins must include http:// or https://")
 		}
@@ -73,16 +73,17 @@ func (c Config) Validate() error {
 
 func DefaultConfig() Config {
 	return Config{
-		AllowAllOrigins:  true,
-		AllowedMethods:   []string{"GET", "POST", "PUT", "HEAD"},
-		AllowedHeaders:   []string{"Content-Type"},
+		AllowMethods:     []string{"GET", "POST", "PUT", "HEAD"},
+		AllowHeaders:     []string{"Origin", "Content-Length", "Content-Type"},
 		AllowCredentials: false,
 		MaxAge:           12 * time.Hour,
 	}
 }
 
 func Default() gin.HandlerFunc {
-	return New(DefaultConfig())
+	config := DefaultConfig()
+	config.AllowAllOrigins = true
+	return New(config)
 }
 
 func New(config Config) gin.HandlerFunc {

+ 22 - 22
cors_test.go

@@ -44,7 +44,7 @@ func TestBadConfig(t *testing.T) {
 	assert.Panics(t, func() {
 		New(Config{
 			AllowAllOrigins: true,
-			AllowedOrigins:  []string{"http://google.com"},
+			AllowOrigins:    []string{"http://google.com"},
 		})
 	})
 	assert.Panics(t, func() {
@@ -55,7 +55,7 @@ func TestBadConfig(t *testing.T) {
 	})
 	assert.Panics(t, func() {
 		New(Config{
-			AllowedOrigins: []string{"google.com"},
+			AllowOrigins: []string{"google.com"},
 		})
 	})
 }
@@ -101,9 +101,9 @@ func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) {
 
 func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) {
 	header := generateNormalHeaders(Config{
-		ExposedHeaders: []string{"X-user", "xPassword"},
+		ExposeHeaders: []string{"X-user", "xPassword"},
 	})
-	assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "x-user, xpassword")
+	assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "x-user,xpassword")
 	assert.Equal(t, header.Get("Vary"), "Origin")
 	assert.Len(t, header, 2)
 }
@@ -135,18 +135,18 @@ func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) {
 
 func TestGeneratePreflightHeaders_AllowedMethods(t *testing.T) {
 	header := generatePreflightHeaders(Config{
-		AllowedMethods: []string{"GET ", "post", "PUT", " put  "},
+		AllowMethods: []string{"GET ", "post", "PUT", " put  "},
 	})
-	assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "get, post, put")
+	assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "get,post,put")
 	assert.Equal(t, header.Get("Vary"), "Origin")
 	assert.Len(t, header, 2)
 }
 
 func TestGeneratePreflightHeaders_AllowedHeaders(t *testing.T) {
 	header := generatePreflightHeaders(Config{
-		AllowedHeaders: []string{"X-user", "Content-Type"},
+		AllowHeaders: []string{"X-user", "Content-Type"},
 	})
-	assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "x-user, content-type")
+	assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "x-user,content-type")
 	assert.Equal(t, header.Get("Vary"), "Origin")
 	assert.Len(t, header, 2)
 }
@@ -169,7 +169,7 @@ func TestValidateOrigin(t *testing.T) {
 	assert.True(t, cors.validateOrigin("example.com"))
 
 	cors = newCors(Config{
-		AllowedOrigins: []string{"https://google.com", "https://github.com"},
+		AllowOrigins: []string{"https://google.com", "https://github.com"},
 		AllowOriginFunc: func(origin string) bool {
 			return (origin == "http://news.ycombinator.com")
 		},
@@ -184,10 +184,10 @@ func TestValidateOrigin(t *testing.T) {
 
 func TestPassesAllowedOrigins(t *testing.T) {
 	router := newTestRouter(Config{
-		AllowedOrigins:   []string{"http://google.com"},
-		AllowedMethods:   []string{" GeT ", "get", "post", "PUT  ", "Head", "POST"},
-		AllowedHeaders:   []string{"Content-type", "timeStamp "},
-		ExposedHeaders:   []string{"Data", "x-User"},
+		AllowOrigins:     []string{"http://google.com"},
+		AllowMethods:     []string{" GeT ", "get", "post", "PUT  ", "Head", "POST"},
+		AllowHeaders:     []string{"Content-type", "timeStamp "},
+		ExposeHeaders:    []string{"Data", "x-User"},
 		AllowCredentials: true,
 		MaxAge:           12 * time.Hour,
 		AllowOriginFunc: func(origin string) bool {
@@ -207,7 +207,7 @@ func TestPassesAllowedOrigins(t *testing.T) {
 	assert.Equal(t, w.Body.String(), "get")
 	assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://google.com")
 	assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
-	assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data, x-user")
+	assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data,x-user")
 
 	// deny CORS request
 	w = performRequest(router, "GET", "https://google.com")
@@ -221,8 +221,8 @@ func TestPassesAllowedOrigins(t *testing.T) {
 	assert.Equal(t, w.Code, 200)
 	assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://github.com")
 	assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true")
-	assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "get, post, put, head")
-	assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type, timestamp")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "get,post,put,head")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type,timestamp")
 	assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "43200")
 
 	// deny CORS prefligh request
@@ -238,9 +238,9 @@ func TestPassesAllowedOrigins(t *testing.T) {
 func TestPassesAllowedAllOrigins(t *testing.T) {
 	router := newTestRouter(Config{
 		AllowAllOrigins:  true,
-		AllowedMethods:   []string{" Patch ", "get", "post", "POST"},
-		AllowedHeaders:   []string{"Content-type", "  testheader "},
-		ExposedHeaders:   []string{"Data2", "x-User2"},
+		AllowMethods:     []string{" Patch ", "get", "post", "POST"},
+		AllowHeaders:     []string{"Content-type", "  testheader "},
+		ExposeHeaders:    []string{"Data2", "x-User2"},
 		AllowCredentials: false,
 		MaxAge:           10 * time.Hour,
 	})
@@ -256,15 +256,15 @@ func TestPassesAllowedAllOrigins(t *testing.T) {
 	w = performRequest(router, "POST", "example.com")
 	assert.Equal(t, w.Body.String(), "post")
 	assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*")
-	assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data2, x-user2")
+	assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data2,x-user2")
 	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
 
 	// allowed CORS prefligh request
 	w = performRequest(router, "OPTIONS", "https://facebook.com")
 	assert.Equal(t, w.Code, 200)
 	assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*")
-	assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "patch, get, post")
-	assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type, testheader")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "patch,get,post")
+	assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type,testheader")
 	assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "36000")
 	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
 }

+ 9 - 9
utils.go

@@ -12,9 +12,9 @@ func generateNormalHeaders(c Config) http.Header {
 	if c.AllowCredentials {
 		headers.Set("Access-Control-Allow-Credentials", "true")
 	}
-	if len(c.ExposedHeaders) > 0 {
-		exposedHeaders := normalize(c.ExposedHeaders)
-		headers.Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", "))
+	if len(c.ExposeHeaders) > 0 {
+		exposeHeaders := normalize(c.ExposeHeaders)
+		headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
 	}
 	if c.AllowAllOrigins {
 		headers.Set("Access-Control-Allow-Origin", "*")
@@ -29,14 +29,14 @@ func generatePreflightHeaders(c Config) http.Header {
 	if c.AllowCredentials {
 		headers.Set("Access-Control-Allow-Credentials", "true")
 	}
-	if len(c.AllowedMethods) > 0 {
-		allowedMethods := normalize(c.AllowedMethods)
-		value := strings.Join(allowedMethods, ", ")
+	if len(c.AllowMethods) > 0 {
+		allowMethods := normalize(c.AllowMethods)
+		value := strings.Join(allowMethods, ",")
 		headers.Set("Access-Control-Allow-Methods", value)
 	}
-	if len(c.AllowedHeaders) > 0 {
-		allowedHeaders := normalize(c.AllowedHeaders)
-		value := strings.Join(allowedHeaders, ", ")
+	if len(c.AllowHeaders) > 0 {
+		allowHeaders := normalize(c.AllowHeaders)
+		value := strings.Join(allowHeaders, ",")
 		headers.Set("Access-Control-Allow-Headers", value)
 	}
 	if c.MaxAge > time.Duration(0) {