Browse Source

If "*" value is present in the list, all origins will be allowed. (#47)

According to godoc
```
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
```
This feature is missing.
So this pull request is about to fix it.
billy.Yuan 7 years ago
parent
commit
7c641a7a7d
2 changed files with 13 additions and 2 deletions
  1. 5 2
      cors.go
  2. 8 0
      cors_test.go

+ 5 - 2
cors.go

@@ -95,7 +95,7 @@ func (c Config) validateAllowedSchemas(origin string) bool {
 }
 
 // Validate is check configuration of user defined.
-func (c Config) Validate() error {
+func (c *Config) Validate() error {
 	if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
 		return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
 	}
@@ -103,7 +103,10 @@ func (c Config) Validate() error {
 		return errors.New("conflict settings: all origins disabled")
 	}
 	for _, origin := range c.AllowOrigins {
-		if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) {
+		if origin == "*" {
+			c.AllowAllOrigins = true
+			return nil
+		} else if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) {
 			return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
 		}
 	}

+ 8 - 0
cors_test.go

@@ -255,6 +255,14 @@ func TestValidateOrigin(t *testing.T) {
 	assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
 	assert.True(t, cors.validateOrigin("wss://some-session-layer-connection"))
 	assert.False(t, cors.validateOrigin("ws://not-what-we-expected"))
+
+	cors = newCors(Config{
+		AllowOrigins: []string{"*"},
+	})
+	assert.True(t, cors.validateOrigin("http://google.com"))
+	assert.True(t, cors.validateOrigin("https://google.com"))
+	assert.True(t, cors.validateOrigin("example.com"))
+	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
 }
 
 func TestPassesAllowedOrigins(t *testing.T) {