Browse Source

Handle domain wildcard + allow wilcard without schema

groovili 7 years ago
parent
commit
926da6846c
3 changed files with 70 additions and 10 deletions
  1. 21 1
      config.go
  2. 5 5
      cors.go
  3. 44 4
      cors_test.go

+ 21 - 1
config.go

@@ -4,7 +4,8 @@ import (
 	"net/http"
 
 	"github.com/gin-gonic/gin"
-		)
+	"strings"
+	)
 
 type cors struct {
 	allowAllOrigins  bool
@@ -77,6 +78,22 @@ func (cors *cors) applyCors(c *gin.Context) {
 	}
 }
 
+func (cors *cors) validateWildcardOrigin(origin string) bool {
+	for _, w := range cors.wildcardOrigins {
+		if w[0] == "*" && strings.HasSuffix(origin, w[1]) {
+			return true
+		}
+		if w[1] == "*" && strings.HasPrefix(origin, w[0]) {
+			return true
+		}
+		if strings.HasPrefix(origin, w[0]) && strings.HasSuffix(origin, w[1]) {
+			return true
+		}
+	}
+
+	return false
+}
+
 func (cors *cors) validateOrigin(origin string) bool {
 	if cors.allowAllOrigins {
 		return true
@@ -86,6 +103,9 @@ func (cors *cors) validateOrigin(origin string) bool {
 			return true
 		}
 	}
+	if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
+		return true
+	}
 	if cors.allowOriginFunc != nil {
 		return cors.allowOriginFunc(origin)
 	}

+ 5 - 5
cors.go

@@ -6,7 +6,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
-)
+	)
 
 // Config represents all available options for the middleware.
 type Config struct {
@@ -94,8 +94,8 @@ func (c Config) Validate() error {
 		return errors.New("conflict settings: all origins disabled")
 	}
 	for _, origin := range c.AllowOrigins {
-		if origin != "*" && !c.validateAllowedSchemas(origin) {
-			return errors.New("bad origin: origins must either be '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
+		if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) {
+			return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
 		}
 	}
 	return nil
@@ -114,7 +114,7 @@ func (c Config) parseWildcardRules() [][]string {
 		}
 
 		if c := strings.Count(o, "*"); c > 1 {
-			panic(errors.New("only one * allowed").Error())
+			panic(errors.New("only one * is allowed").Error())
 		}
 
 		i := strings.Index(o, "*")
@@ -127,7 +127,7 @@ func (c Config) parseWildcardRules() [][]string {
 			continue
 		}
 		if i != 0 && i != len(o) {
-			wRules = append(wRules, []string{o[:i], o[i:]})
+			wRules = append(wRules, []string{o[:i], o[i+1:]})
 		}
 	}
 

+ 44 - 4
cors_test.go

@@ -216,7 +216,7 @@ func TestValidateOrigin(t *testing.T) {
 		AllowOriginFunc: func(origin string) bool {
 			return (origin == "http://news.ycombinator.com")
 		},
-		AllowBrowserExtensions:true,
+		AllowBrowserExtensions: true,
 	})
 	assert.False(t, cors.validateOrigin("http://google.com"))
 	assert.True(t, cors.validateOrigin("https://google.com"))
@@ -232,9 +232,8 @@ func TestValidateOrigin(t *testing.T) {
 	assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
 
 	cors = newCors(Config{
-		AllowOrigins:[]string{"chrome-extension://random-extension-id", "safari-extension://another-ext-id"},
-		AllowBrowserExtensions:true,
-
+		AllowOrigins:           []string{"chrome-extension://random-extension-id", "safari-extension://another-ext-id"},
+		AllowBrowserExtensions: true,
 	})
 	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
 	assert.True(t, cors.validateOrigin("safari-extension://another-ext-id"))
@@ -341,4 +340,45 @@ func TestPassesAllowedAllOrigins(t *testing.T) {
 	assert.Equal(t, "Content-Type,Testheader", w.Header().Get("Access-Control-Allow-Headers"))
 	assert.Equal(t, "36000", w.Header().Get("Access-Control-Max-Age"))
 	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
+
+}
+
+func TestWildcard(t *testing.T) {
+	router := newTestRouter(Config{
+		AllowOrigins:  []string{"https://*.github.com", "https://api.*", "http://*", "https://facebook.com", "*.golang.org"},
+		AllowMethods:  []string{"GET"},
+		AllowWildcard: true,
+	})
+
+	w := performRequest(router, "GET", "https://gist.github.com")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://api.github.com/v1/users")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://giphy.com/")
+	assert.Equal(t, 403, w.Code)
+
+	w = performRequest(router, "GET", "http://hard-to-find-http-example.com")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://facebook.com")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://something.golang.org")
+	assert.Equal(t, 200, w.Code)
+
+	w = performRequest(router, "GET", "https://something.go.org")
+	assert.Equal(t, 403, w.Code)
+
+	router = newTestRouter(Config{
+		AllowOrigins: []string{"https://github.com", "https://facebook.com"},
+		AllowMethods: []string{"GET"},
+	})
+
+	w = performRequest(router, "GET", "https://gist.github.com")
+	assert.Equal(t, 403, w.Code)
+
+	w = performRequest(router, "GET", "https://github.com")
+	assert.Equal(t, 200, w.Code)
 }