Преглед изворни кода

Add usage of websockets and files schemas (#45)

This change will allow usage of ```ws://```, ```wss://``` and ```file://``` schemas by adding additional boolean configs.
Max Ivanov пре 7 година
родитељ
комит
214f318e72
3 измењених фајлова са 34 додато и 7 уклоњено
  1. 7 0
      config.go
  2. 12 3
      cors.go
  3. 15 4
      cors_test.go

+ 7 - 0
config.go

@@ -27,6 +27,13 @@ var (
 		"moz-extension://",
 		"ms-browser-extension://",
 	}
+	FileSchemas = []string{
+		"file://",
+	}
+	WebSocketSchemas = []string{
+		"ws://",
+		"wss://",
+	}
 )
 
 func newCors(config Config) *cors {

+ 12 - 3
cors.go

@@ -44,6 +44,12 @@ type Config struct {
 
 	// Allows usage of popular browser extensions schemas
 	AllowBrowserExtensions bool
+
+	// Allows usage of WebSocket protocol
+	AllowWebSockets bool
+
+	// Allows usage of file:// schema (dangerous!) use it only when you 100% sure it's needed
+	AllowFiles bool
 }
 
 // AddAllowMethods is allowed to add custom methods
@@ -66,19 +72,22 @@ func (c Config) getAllowedSchemas() []string {
 	if c.AllowBrowserExtensions {
 		allowedSchemas = append(allowedSchemas, ExtensionSchemas...)
 	}
-
+	if c.AllowWebSockets {
+		allowedSchemas = append(allowedSchemas, WebSocketSchemas...)
+	}
+	if c.AllowFiles {
+		allowedSchemas = append(allowedSchemas, FileSchemas...)
+	}
 	return allowedSchemas
 }
 
 func (c Config) validateAllowedSchemas(origin string) bool {
 	allowedSchemas := c.getAllowedSchemas()
-
 	for _, schema := range allowedSchemas {
 		if strings.HasPrefix(origin, schema) {
 			return true
 		}
 	}
-
 	return false
 }
 

+ 15 - 4
cors_test.go

@@ -216,7 +216,7 @@ func TestValidateOrigin(t *testing.T) {
 		AllowOriginFunc: func(origin string) bool {
 			return (origin == "http://news.ycombinator.com")
 		},
-		AllowBrowserExtensions:true,
+		AllowBrowserExtensions: true,
 	})
 	assert.False(t, cors.validateOrigin("http://google.com"))
 	assert.True(t, cors.validateOrigin("https://google.com"))
@@ -230,15 +230,26 @@ func TestValidateOrigin(t *testing.T) {
 		AllowOrigins: []string{"https://google.com", "https://github.com"},
 	})
 	assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
+	assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
+	assert.False(t, cors.validateOrigin("wss://socket-connection"))
 
 	cors = newCors(Config{
-		AllowOrigins:[]string{"chrome-extension://random-extension-id", "safari-extension://another-ext-id"},
-		AllowBrowserExtensions:true,
-
+		AllowOrigins:           []string{"chrome-extension://random-extension-id", "safari-extension://another-ext-id"},
+		AllowBrowserExtensions: true,
 	})
 	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
 	assert.True(t, cors.validateOrigin("safari-extension://another-ext-id"))
 	assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow"))
+
+	cors = newCors(Config{
+		AllowOrigins:    []string{"file://safe-file.js", "wss://some-session-layer-connection"},
+		AllowFiles:      true,
+		AllowWebSockets: true,
+	})
+	assert.True(t, cors.validateOrigin("file://safe-file.js"))
+	assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
+	assert.True(t, cors.validateOrigin("wss://some-session-layer-connection"))
+	assert.False(t, cors.validateOrigin("ws://not-what-we-expected"))
 }
 
 func TestPassesAllowedOrigins(t *testing.T) {