Ver código fonte

Better unit tests for BasicAuth middleware

Manu Mtz-Almeida 10 anos atrás
pai
commit
a28104fa21
2 arquivos alterados com 124 adições e 60 exclusões
  1. 17 12
      auth.go
  2. 107 48
      auth_test.go

+ 17 - 12
auth.go

@@ -29,6 +29,19 @@ func (a authPairs) Len() int           { return len(a) }
 func (a authPairs) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
 func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }
 
+func (a authPairs) searchCredential(auth string) (string, bool) {
+	if len(auth) == 0 {
+		return "", false
+	}
+	// Search user in the slice of allowed credentials
+	r := sort.Search(len(a), func(i int) bool { return a[i].Value >= auth })
+	if r < len(a) && secureCompare(a[r].Value, auth) {
+		return a[r].User, true
+	} else {
+		return "", false
+	}
+}
+
 // Implements a basic Basic HTTP Authorization. It takes as arguments a map[string]string where
 // the key is the user name and the value is the password, as well as the name of the Realm
 // (see http://tools.ietf.org/html/rfc2617#section-1.2)
@@ -40,7 +53,7 @@ func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc {
 	pairs := processAccounts(accounts)
 	return func(c *Context) {
 		// Search user in the slice of allowed credentials
-		user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization"))
+		user, ok := pairs.searchCredential(c.Request.Header.Get("Authorization"))
 		if !ok {
 			// Credentials doesn't match, we return 401 Unauthorized and abort request.
 			c.Writer.Header().Set("WWW-Authenticate", realm)
@@ -80,17 +93,9 @@ func processAccounts(accounts Accounts) authPairs {
 	return pairs
 }
 
-func searchCredential(pairs authPairs, auth string) (string, bool) {
-	if len(auth) == 0 {
-		return "", false
-	}
-	// Search user in the slice of allowed credentials
-	r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth })
-	if r < len(pairs) && secureCompare(pairs[r].Value, auth) {
-		return pairs[r].User, true
-	} else {
-		return "", false
-	}
+func authorizationHeader(user, password string) string {
+	base := user + ":" + password
+	return "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
 }
 
 func secureCompare(given, actual string) bool {

+ 107 - 48
auth_test.go

@@ -9,77 +9,136 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
-func TestBasicAuthSucceed(t *testing.T) {
-	req, _ := http.NewRequest("GET", "/login", nil)
-	w := httptest.NewRecorder()
+func TestBasicAuth(t *testing.T) {
+	accounts := Accounts{
+		"admin": "password",
+		"foo":   "bar",
+		"bar":   "foo",
+	}
+	expectedPairs := authPairs{
+		authPair{
+			User:  "admin",
+			Value: "Basic YWRtaW46cGFzc3dvcmQ=",
+		},
+		authPair{
+			User:  "bar",
+			Value: "Basic YmFyOmZvbw==",
+		},
+		authPair{
+			User:  "foo",
+			Value: "Basic Zm9vOmJhcg==",
+		},
+	}
+	pairs := processAccounts(accounts)
+	assert.Equal(t, pairs, expectedPairs)
+}
 
-	r := New()
-	accounts := Accounts{"admin": "password"}
-	r.Use(BasicAuth(accounts))
+func TestBasicAuthFails(t *testing.T) {
+	assert.Panics(t, func() { processAccounts(nil) })
+	assert.Panics(t, func() {
+		processAccounts(Accounts{
+			"":    "password",
+			"foo": "bar",
+		})
+	})
+}
 
-	r.GET("/login", func(c *Context) {
-		c.String(200, "autorized")
+func TestBasicAuthSearchCredential(t *testing.T) {
+	pairs := processAccounts(Accounts{
+		"admin": "password",
+		"foo":   "bar",
+		"bar":   "foo",
 	})
 
-	req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
-	r.ServeHTTP(w, req)
+	user, found := pairs.searchCredential(authorizationHeader("admin", "password"))
+	assert.Equal(t, user, "admin")
+	assert.True(t, found)
 
-	if w.Code != 200 {
-		t.Errorf("Response code should be Ok, was: %d", w.Code)
-	}
-	bodyAsString := w.Body.String()
+	user, found = pairs.searchCredential(authorizationHeader("foo", "bar"))
+	assert.Equal(t, user, "foo")
+	assert.True(t, found)
 
-	if bodyAsString != "autorized" {
-		t.Errorf("Response body should be `autorized`, was  %s", bodyAsString)
-	}
+	user, found = pairs.searchCredential(authorizationHeader("bar", "foo"))
+	assert.Equal(t, user, "bar")
+	assert.True(t, found)
+
+	user, found = pairs.searchCredential(authorizationHeader("admins", "password"))
+	assert.Empty(t, user)
+	assert.False(t, found)
+
+	user, found = pairs.searchCredential(authorizationHeader("foo", "bar "))
+	assert.Empty(t, user)
+	assert.False(t, found)
 }
 
-func TestBasicAuth401(t *testing.T) {
-	req, _ := http.NewRequest("GET", "/login", nil)
+func TestBasicAuthAuthorizationHeader(t *testing.T) {
+	assert.Equal(t, authorizationHeader("admin", "password"), "Basic YWRtaW46cGFzc3dvcmQ=")
+}
+
+func TestBasicAuthSecureCompare(t *testing.T) {
+	assert.True(t, secureCompare("1234567890", "1234567890"))
+	assert.False(t, secureCompare("123456789", "1234567890"))
+	assert.False(t, secureCompare("12345678900", "1234567890"))
+	assert.False(t, secureCompare("1234567891", "1234567890"))
+}
+
+func TestBasicAuthSucceed(t *testing.T) {
+	accounts := Accounts{"admin": "password"}
+	router := New()
+	router.Use(BasicAuth(accounts))
+	router.GET("/login", func(c *Context) {
+		c.String(200, c.MustGet(AuthUserKey).(string))
+	})
+
 	w := httptest.NewRecorder()
+	req, _ := http.NewRequest("GET", "/login", nil)
+	req.Header.Set("Authorization", authorizationHeader("admin", "password"))
+	router.ServeHTTP(w, req)
 
-	r := New()
-	accounts := Accounts{"foo": "bar"}
-	r.Use(BasicAuth(accounts))
+	assert.Equal(t, w.Code, 200)
+	assert.Equal(t, w.Body.String(), "admin")
+}
 
-	r.GET("/login", func(c *Context) {
-		c.String(200, "autorized")
+func TestBasicAuth401(t *testing.T) {
+	called := false
+	accounts := Accounts{"foo": "bar"}
+	router := New()
+	router.Use(BasicAuth(accounts))
+	router.GET("/login", func(c *Context) {
+		called = true
+		c.String(200, c.MustGet(AuthUserKey).(string))
 	})
 
+	w := httptest.NewRecorder()
+	req, _ := http.NewRequest("GET", "/login", nil)
 	req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
-	r.ServeHTTP(w, req)
-
-	if w.Code != 401 {
-		t.Errorf("Response code should be Not autorized, was: %d", w.Code)
-	}
+	router.ServeHTTP(w, req)
 
-	if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"Authorization Required\"" {
-		t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
-	}
+	assert.False(t, called)
+	assert.Equal(t, w.Code, 401)
+	assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"Authorization Required\"")
 }
 
 func TestBasicAuth401WithCustomRealm(t *testing.T) {
-	req, _ := http.NewRequest("GET", "/login", nil)
-	w := httptest.NewRecorder()
-
-	r := New()
+	called := false
 	accounts := Accounts{"foo": "bar"}
-	r.Use(BasicAuthForRealm(accounts, "My Custom Realm"))
-
-	r.GET("/login", func(c *Context) {
-		c.String(200, "autorized")
+	router := New()
+	router.Use(BasicAuthForRealm(accounts, "My Custom Realm"))
+	router.GET("/login", func(c *Context) {
+		called = true
+		c.String(200, c.MustGet(AuthUserKey).(string))
 	})
 
+	w := httptest.NewRecorder()
+	req, _ := http.NewRequest("GET", "/login", nil)
 	req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
-	r.ServeHTTP(w, req)
+	router.ServeHTTP(w, req)
 
-	if w.Code != 401 {
-		t.Errorf("Response code should be Not autorized, was: %d", w.Code)
-	}
-
-	if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"My Custom Realm\"" {
-		t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
-	}
+	assert.False(t, called)
+	assert.Equal(t, w.Code, 401)
+	assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"My Custom Realm\"")
 }