Browse Source

Wildcard rules parsing

groovili 7 years ago
parent
commit
45ed25afae
2 changed files with 40 additions and 1 deletions
  1. 4 1
      config.go
  2. 36 0
      cors.go

+ 4 - 1
config.go

@@ -4,7 +4,7 @@ import (
 	"net/http"
 
 	"github.com/gin-gonic/gin"
-)
+		)
 
 type cors struct {
 	allowAllOrigins  bool
@@ -14,6 +14,7 @@ type cors struct {
 	exposeHeaders    []string
 	normalHeaders    http.Header
 	preflightHeaders http.Header
+	wildcardOrigins  [][]string
 }
 
 var (
@@ -33,6 +34,7 @@ func newCors(config Config) *cors {
 	if err := config.Validate(); err != nil {
 		panic(err.Error())
 	}
+
 	return &cors{
 		allowOriginFunc:  config.AllowOriginFunc,
 		allowAllOrigins:  config.AllowAllOrigins,
@@ -40,6 +42,7 @@ func newCors(config Config) *cors {
 		allowOrigins:     normalize(config.AllowOrigins),
 		normalHeaders:    generateNormalHeaders(config),
 		preflightHeaders: generatePreflightHeaders(config),
+		wildcardOrigins:  config.parseWildcardRules(),
 	}
 }
 

+ 36 - 0
cors.go

@@ -42,6 +42,9 @@ type Config struct {
 	// can be cached
 	MaxAge time.Duration
 
+	// Allows to add origins like http://some-domain/*, https://api.* or http://some.*.subdomain.com
+	AllowWildcard bool
+
 	// Allows usage of popular browser extensions schemas
 	AllowBrowserExtensions bool
 }
@@ -98,6 +101,39 @@ func (c Config) Validate() error {
 	return nil
 }
 
+func (c Config) parseWildcardRules() [][]string {
+	var wRules [][]string
+
+	if !c.AllowWildcard {
+		return wRules
+	}
+
+	for _, o := range c.AllowOrigins {
+		if !strings.Contains(o, "*") {
+			continue
+		}
+
+		if c := strings.Count(o, "*"); c > 1 {
+			panic(errors.New("only one * allowed").Error())
+		}
+
+		i := strings.Index(o, "*")
+		if i == 0 {
+			wRules = append(wRules, []string{"*", o[1:]})
+			continue
+		}
+		if i == len(o) {
+			wRules = append(wRules, []string{o[:i-1], "*"})
+			continue
+		}
+		if i != 0 && i != len(o) {
+			wRules = append(wRules, []string{o[:i], o[i:]})
+		}
+	}
+
+	return wRules
+}
+
 // DefaultConfig returns a generic default configuration mapped to localhost.
 func DefaultConfig() Config {
 	return Config{