瀏覽代碼

Merge pull request #44 from groovili/groovili/allow-domain-wildcard

Add domain wildcard
田欧 7 年之前
父節點
當前提交
488de3ec97
共有 3 個文件被更改,包括 108 次插入4 次删除
  1. 23 0
      config.go
  2. 37 2
      cors.go
  3. 48 2
      cors_test.go

+ 23 - 0
config.go

@@ -2,6 +2,7 @@ package cors
 
 import (
 	"net/http"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -14,6 +15,7 @@ type cors struct {
 	exposeHeaders    []string
 	normalHeaders    http.Header
 	preflightHeaders http.Header
+	wildcardOrigins  [][]string
 }
 
 var (
@@ -40,6 +42,7 @@ func newCors(config Config) *cors {
 	if err := config.Validate(); err != nil {
 		panic(err.Error())
 	}
+
 	return &cors{
 		allowOriginFunc:  config.AllowOriginFunc,
 		allowAllOrigins:  config.AllowAllOrigins,
@@ -47,6 +50,7 @@ func newCors(config Config) *cors {
 		allowOrigins:     normalize(config.AllowOrigins),
 		normalHeaders:    generateNormalHeaders(config),
 		preflightHeaders: generatePreflightHeaders(config),
+		wildcardOrigins:  config.parseWildcardRules(),
 	}
 }
 
@@ -81,6 +85,22 @@ func (cors *cors) applyCors(c *gin.Context) {
 	}
 }
 
+func (cors *cors) validateWildcardOrigin(origin string) bool {
+	for _, w := range cors.wildcardOrigins {
+		if w[0] == "*" && strings.HasSuffix(origin, w[1]) {
+			return true
+		}
+		if w[1] == "*" && strings.HasPrefix(origin, w[0]) {
+			return true
+		}
+		if strings.HasPrefix(origin, w[0]) && strings.HasSuffix(origin, w[1]) {
+			return true
+		}
+	}
+
+	return false
+}
+
 func (cors *cors) validateOrigin(origin string) bool {
 	if cors.allowAllOrigins {
 		return true
@@ -90,6 +110,9 @@ func (cors *cors) validateOrigin(origin string) bool {
 			return true
 		}
 	}
+	if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
+		return true
+	}
 	if cors.allowOriginFunc != nil {
 		return cors.allowOriginFunc(origin)
 	}

+ 37 - 2
cors.go

@@ -42,6 +42,9 @@ type Config struct {
 	// can be cached
 	MaxAge time.Duration
 
+	// Allows to add origins like http://some-domain/*, https://api.* or http://some.*.subdomain.com
+	AllowWildcard bool
+
 	// Allows usage of popular browser extensions schemas
 	AllowBrowserExtensions bool
 
@@ -100,13 +103,45 @@ func (c Config) Validate() error {
 		return errors.New("conflict settings: all origins disabled")
 	}
 	for _, origin := range c.AllowOrigins {
-		if origin != "*" && !c.validateAllowedSchemas(origin) {
-			return errors.New("bad origin: origins must either be '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
+		if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) {
+			return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
 		}
 	}
 	return nil
 }
 
+func (c Config) parseWildcardRules() [][]string {
+	var wRules [][]string
+
+	if !c.AllowWildcard {
+		return wRules
+	}
+
+	for _, o := range c.AllowOrigins {
+		if !strings.Contains(o, "*") {
+			continue
+		}
+
+		if c := strings.Count(o, "*"); c > 1 {
+			panic(errors.New("only one * is allowed").Error())
+		}
+
+		i := strings.Index(o, "*")
+		if i == 0 {
+			wRules = append(wRules, []string{"*", o[1:]})
+			continue
+		}
+		if i == (len(o) - 1) {
+			wRules = append(wRules, []string{o[:i-1], "*"})
+			continue
+		}
+
+		wRules = append(wRules, []string{o[:i], o[i+1:]})
+	}
+
+	return wRules
+}
+
 // DefaultConfig returns a generic default configuration mapped to localhost.
 func DefaultConfig() Config {
 	return Config{

+ 48 - 2
cors_test.go

@@ -234,12 +234,17 @@ func TestValidateOrigin(t *testing.T) {
 	assert.False(t, cors.validateOrigin("wss://socket-connection"))
 
 	cors = newCors(Config{
-		AllowOrigins:           []string{"chrome-extension://random-extension-id", "safari-extension://another-ext-id"},
+		AllowOrigins:           []string{"chrome-extension://*", "safari-extension://my-extension-*-app", "*.some-domain.com"},
 		AllowBrowserExtensions: true,
+		AllowWildcard:          true,
 	})
 	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
-	assert.True(t, cors.validateOrigin("safari-extension://another-ext-id"))
+	assert.True(t, cors.validateOrigin("chrome-extension://another-one"))
+	assert.True(t, cors.validateOrigin("safari-extension://my-extension-one-app"))
+	assert.True(t, cors.validateOrigin("safari-extension://my-extension-two-app"))
 	assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow"))
+	assert.True(t, cors.validateOrigin("http://api.some-domain.com"))
+	assert.False(t, cors.validateOrigin("http://api.another-domain.com"))
 
 	cors = newCors(Config{
 		AllowOrigins:    []string{"file://safe-file.js", "wss://some-session-layer-connection"},
@@ -352,4 +357,45 @@ func TestPassesAllowedAllOrigins(t *testing.T) {
 	assert.Equal(t, "Content-Type,Testheader", w.Header().Get("Access-Control-Allow-Headers"))
 	assert.Equal(t, "36000", w.Header().Get("Access-Control-Max-Age"))
 	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
+
+}
+
+func TestWildcard(t *testing.T) {
+	router := newTestRouter(Config{
+		AllowOrigins:  []string{"https://*.github.com", "https://api.*", "http://*", "https://facebook.com", "*.golang.org"},
+		AllowMethods:  []string{"GET"},
+		AllowWildcard: true,
+	})
+
+	w := performRequest(router, "GET", "https://gist.github.com")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://api.github.com/v1/users")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://giphy.com/")
+	assert.Equal(t, 403, w.Code)
+
+	w = performRequest(router, "GET", "http://hard-to-find-http-example.com")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://facebook.com")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://something.golang.org")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://something.go.org")
+	assert.Equal(t, 403, w.Code)
+
+	router = newTestRouter(Config{
+		AllowOrigins: []string{"https://github.com", "https://facebook.com"},
+		AllowMethods: []string{"GET"},
+	})
+
+	w = performRequest(router, "GET", "https://gist.github.com")
+	assert.Equal(t, 403, w.Code)
+
+	w = performRequest(router, "GET", "https://github.com")
+	assert.Equal(t, 200, w.Code)
 }