Kaynağa Gözat

Add support of all browser extensions schemas (#43)

At the moment only supported schemas are - **http://**, **https://** and ***** wildcard.
This PR adds **AllowedSchemas** configurations to add the **ability to extend schemas list**.
By this change, I want to add an option to **send requests from browser extensions**.
I saw that there was PR for only chrome-extension:// scheme, but this change is a more general way to handle that.
Maxim Ivanov 7 yıl önce
ebeveyn
işleme
a3af05a9b1
3 değiştirilmiş dosya ile 58 ekleme ve 2 silme
  1. 15 0
      config.go
  2. 26 2
      cors.go
  3. 17 0
      cors_test.go

+ 15 - 0
config.go

@@ -16,6 +16,19 @@ type cors struct {
 	preflightHeaders http.Header
 }
 
+var (
+	DefaultSchemas = []string{
+		"http://",
+		"https://",
+	}
+	ExtensionSchemas = []string{
+		"chrome-extension://",
+		"safari-extension://",
+		"moz-extension://",
+		"ms-browser-extension://",
+	}
+)
+
 func newCors(config Config) *cors {
 	if err := config.Validate(); err != nil {
 		panic(err.Error())
@@ -37,11 +50,13 @@ func (cors *cors) applyCors(c *gin.Context) {
 		return
 	}
 	host := c.Request.Header.Get("Host")
+
 	if origin == "http://"+host || origin == "https://"+host {
 		// request is not a CORS request but have origin header.
 		// for example, use fetch api
 		return
 	}
+
 	if !cors.validateOrigin(origin) {
 		c.AbortWithStatus(http.StatusForbidden)
 		return

+ 26 - 2
cors.go

@@ -41,6 +41,9 @@ type Config struct {
 	// MaxAge indicates how long (in seconds) the results of a preflight request
 	// can be cached
 	MaxAge time.Duration
+
+	// Allows usage of popular browser extensions schemas
+	AllowBrowserExtensions bool
 }
 
 // AddAllowMethods is allowed to add custom methods
@@ -58,6 +61,27 @@ func (c *Config) AddExposeHeaders(headers ...string) {
 	c.ExposeHeaders = append(c.ExposeHeaders, headers...)
 }
 
+func (c Config) getAllowedSchemas() []string {
+	allowedSchemas := DefaultSchemas
+	if c.AllowBrowserExtensions {
+		allowedSchemas = append(allowedSchemas, ExtensionSchemas...)
+	}
+
+	return allowedSchemas
+}
+
+func (c Config) validateAllowedSchemas(origin string) bool {
+	allowedSchemas := c.getAllowedSchemas()
+
+	for _, schema := range allowedSchemas {
+		if strings.HasPrefix(origin, schema) {
+			return true
+		}
+	}
+
+	return false
+}
+
 // Validate is check configuration of user defined.
 func (c Config) Validate() error {
 	if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
@@ -67,8 +91,8 @@ func (c Config) Validate() error {
 		return errors.New("conflict settings: all origins disabled")
 	}
 	for _, origin := range c.AllowOrigins {
-		if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
-			return errors.New("bad origin: origins must either be '*' or include http:// or https://")
+		if origin != "*" && !c.validateAllowedSchemas(origin) {
+			return errors.New("bad origin: origins must either be '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
 		}
 	}
 	return nil

+ 17 - 0
cors_test.go

@@ -209,12 +209,14 @@ func TestValidateOrigin(t *testing.T) {
 	assert.True(t, cors.validateOrigin("http://google.com"))
 	assert.True(t, cors.validateOrigin("https://google.com"))
 	assert.True(t, cors.validateOrigin("example.com"))
+	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
 
 	cors = newCors(Config{
 		AllowOrigins: []string{"https://google.com", "https://github.com"},
 		AllowOriginFunc: func(origin string) bool {
 			return (origin == "http://news.ycombinator.com")
 		},
+		AllowBrowserExtensions:true,
 	})
 	assert.False(t, cors.validateOrigin("http://google.com"))
 	assert.True(t, cors.validateOrigin("https://google.com"))
@@ -222,6 +224,21 @@ func TestValidateOrigin(t *testing.T) {
 	assert.True(t, cors.validateOrigin("http://news.ycombinator.com"))
 	assert.False(t, cors.validateOrigin("http://example.com"))
 	assert.False(t, cors.validateOrigin("google.com"))
+	assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
+
+	cors = newCors(Config{
+		AllowOrigins: []string{"https://google.com", "https://github.com"},
+	})
+	assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
+
+	cors = newCors(Config{
+		AllowOrigins:[]string{"chrome-extension://random-extension-id", "safari-extension://another-ext-id"},
+		AllowBrowserExtensions:true,
+
+	})
+	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
+	assert.True(t, cors.validateOrigin("safari-extension://another-ext-id"))
+	assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow"))
 }
 
 func TestPassesAllowedOrigins(t *testing.T) {