Просмотр исходного кода

can only specify one origin in cors

kevin 3 лет назад
Родитель
Сommit
9c8f31cf83
2 измененных файлов с 36 добавлено и 23 удалено
  1. 6 8
      rest/handlers.go
  2. 30 15
      rest/handlers_test.go

+ 6 - 8
rest/handlers.go

@@ -1,26 +1,24 @@
 package rest
 
-import (
-	"net/http"
-	"strings"
-)
+import "net/http"
 
 const (
 	allowOrigin  = "Access-Control-Allow-Origin"
-	allOrigin    = "*"
+	allOrigins   = "*"
 	allowMethods = "Access-Control-Allow-Methods"
 	allowHeaders = "Access-Control-Allow-Headers"
 	headers      = "Content-Type, Content-Length, Origin"
 	methods      = "GET, HEAD, POST, PATCH, PUT, DELETE"
-	separator    = ", "
 )
 
+// CorsHandler handles cross domain OPTIONS requests.
+// At most one origin can be specified, other origins are ignored if given.
 func CorsHandler(origins ...string) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if len(origins) > 0 {
-			w.Header().Set(allowOrigin, strings.Join(origins, separator))
+			w.Header().Set(allowOrigin, origins[0])
 		} else {
-			w.Header().Set(allowOrigin, allOrigin)
+			w.Header().Set(allowOrigin, allOrigins)
 		}
 		w.Header().Set(allowMethods, methods)
 		w.Header().Set(allowHeaders, headers)

+ 30 - 15
rest/handlers_test.go

@@ -3,25 +3,40 @@ package rest
 import (
 	"net/http"
 	"net/http/httptest"
-	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 )
 
-func TestCorsHandler(t *testing.T) {
-	w := httptest.NewRecorder()
-	handler := CorsHandler()
-	handler.ServeHTTP(w, nil)
-	assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
-	assert.Equal(t, allOrigin, w.Header().Get(allowOrigin))
-}
-
 func TestCorsHandlerWithOrigins(t *testing.T) {
-	origins := []string{"local", "remote"}
-	w := httptest.NewRecorder()
-	handler := CorsHandler(origins...)
-	handler.ServeHTTP(w, nil)
-	assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
-	assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin))
+	tests := []struct {
+		name    string
+		origins []string
+		expect  string
+	}{
+		{
+			name:   "allow all origins",
+			expect: allOrigins,
+		},
+		{
+			name:    "allow one origin",
+			origins: []string{"local"},
+			expect:  "local",
+		},
+		{
+			name:    "allow many origins",
+			origins: []string{"local", "remote"},
+			expect:  "local",
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			w := httptest.NewRecorder()
+			handler := CorsHandler(test.origins...)
+			handler.ServeHTTP(w, nil)
+			assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
+			assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
+		})
+	}
 }