|
|
@@ -9,77 +9,136 @@ import (
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
"testing"
|
|
|
+
|
|
|
+ "github.com/stretchr/testify/assert"
|
|
|
)
|
|
|
|
|
|
-func TestBasicAuthSucceed(t *testing.T) {
|
|
|
- req, _ := http.NewRequest("GET", "/login", nil)
|
|
|
- w := httptest.NewRecorder()
|
|
|
+func TestBasicAuth(t *testing.T) {
|
|
|
+ accounts := Accounts{
|
|
|
+ "admin": "password",
|
|
|
+ "foo": "bar",
|
|
|
+ "bar": "foo",
|
|
|
+ }
|
|
|
+ expectedPairs := authPairs{
|
|
|
+ authPair{
|
|
|
+ User: "admin",
|
|
|
+ Value: "Basic YWRtaW46cGFzc3dvcmQ=",
|
|
|
+ },
|
|
|
+ authPair{
|
|
|
+ User: "bar",
|
|
|
+ Value: "Basic YmFyOmZvbw==",
|
|
|
+ },
|
|
|
+ authPair{
|
|
|
+ User: "foo",
|
|
|
+ Value: "Basic Zm9vOmJhcg==",
|
|
|
+ },
|
|
|
+ }
|
|
|
+ pairs := processAccounts(accounts)
|
|
|
+ assert.Equal(t, pairs, expectedPairs)
|
|
|
+}
|
|
|
|
|
|
- r := New()
|
|
|
- accounts := Accounts{"admin": "password"}
|
|
|
- r.Use(BasicAuth(accounts))
|
|
|
+func TestBasicAuthFails(t *testing.T) {
|
|
|
+ assert.Panics(t, func() { processAccounts(nil) })
|
|
|
+ assert.Panics(t, func() {
|
|
|
+ processAccounts(Accounts{
|
|
|
+ "": "password",
|
|
|
+ "foo": "bar",
|
|
|
+ })
|
|
|
+ })
|
|
|
+}
|
|
|
|
|
|
- r.GET("/login", func(c *Context) {
|
|
|
- c.String(200, "autorized")
|
|
|
+func TestBasicAuthSearchCredential(t *testing.T) {
|
|
|
+ pairs := processAccounts(Accounts{
|
|
|
+ "admin": "password",
|
|
|
+ "foo": "bar",
|
|
|
+ "bar": "foo",
|
|
|
})
|
|
|
|
|
|
- req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
|
|
- r.ServeHTTP(w, req)
|
|
|
+ user, found := pairs.searchCredential(authorizationHeader("admin", "password"))
|
|
|
+ assert.Equal(t, user, "admin")
|
|
|
+ assert.True(t, found)
|
|
|
|
|
|
- if w.Code != 200 {
|
|
|
- t.Errorf("Response code should be Ok, was: %d", w.Code)
|
|
|
- }
|
|
|
- bodyAsString := w.Body.String()
|
|
|
+ user, found = pairs.searchCredential(authorizationHeader("foo", "bar"))
|
|
|
+ assert.Equal(t, user, "foo")
|
|
|
+ assert.True(t, found)
|
|
|
|
|
|
- if bodyAsString != "autorized" {
|
|
|
- t.Errorf("Response body should be `autorized`, was %s", bodyAsString)
|
|
|
- }
|
|
|
+ user, found = pairs.searchCredential(authorizationHeader("bar", "foo"))
|
|
|
+ assert.Equal(t, user, "bar")
|
|
|
+ assert.True(t, found)
|
|
|
+
|
|
|
+ user, found = pairs.searchCredential(authorizationHeader("admins", "password"))
|
|
|
+ assert.Empty(t, user)
|
|
|
+ assert.False(t, found)
|
|
|
+
|
|
|
+ user, found = pairs.searchCredential(authorizationHeader("foo", "bar "))
|
|
|
+ assert.Empty(t, user)
|
|
|
+ assert.False(t, found)
|
|
|
}
|
|
|
|
|
|
-func TestBasicAuth401(t *testing.T) {
|
|
|
- req, _ := http.NewRequest("GET", "/login", nil)
|
|
|
+func TestBasicAuthAuthorizationHeader(t *testing.T) {
|
|
|
+ assert.Equal(t, authorizationHeader("admin", "password"), "Basic YWRtaW46cGFzc3dvcmQ=")
|
|
|
+}
|
|
|
+
|
|
|
+func TestBasicAuthSecureCompare(t *testing.T) {
|
|
|
+ assert.True(t, secureCompare("1234567890", "1234567890"))
|
|
|
+ assert.False(t, secureCompare("123456789", "1234567890"))
|
|
|
+ assert.False(t, secureCompare("12345678900", "1234567890"))
|
|
|
+ assert.False(t, secureCompare("1234567891", "1234567890"))
|
|
|
+}
|
|
|
+
|
|
|
+func TestBasicAuthSucceed(t *testing.T) {
|
|
|
+ accounts := Accounts{"admin": "password"}
|
|
|
+ router := New()
|
|
|
+ router.Use(BasicAuth(accounts))
|
|
|
+ router.GET("/login", func(c *Context) {
|
|
|
+ c.String(200, c.MustGet(AuthUserKey).(string))
|
|
|
+ })
|
|
|
+
|
|
|
w := httptest.NewRecorder()
|
|
|
+ req, _ := http.NewRequest("GET", "/login", nil)
|
|
|
+ req.Header.Set("Authorization", authorizationHeader("admin", "password"))
|
|
|
+ router.ServeHTTP(w, req)
|
|
|
|
|
|
- r := New()
|
|
|
- accounts := Accounts{"foo": "bar"}
|
|
|
- r.Use(BasicAuth(accounts))
|
|
|
+ assert.Equal(t, w.Code, 200)
|
|
|
+ assert.Equal(t, w.Body.String(), "admin")
|
|
|
+}
|
|
|
|
|
|
- r.GET("/login", func(c *Context) {
|
|
|
- c.String(200, "autorized")
|
|
|
+func TestBasicAuth401(t *testing.T) {
|
|
|
+ called := false
|
|
|
+ accounts := Accounts{"foo": "bar"}
|
|
|
+ router := New()
|
|
|
+ router.Use(BasicAuth(accounts))
|
|
|
+ router.GET("/login", func(c *Context) {
|
|
|
+ called = true
|
|
|
+ c.String(200, c.MustGet(AuthUserKey).(string))
|
|
|
})
|
|
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
+ req, _ := http.NewRequest("GET", "/login", nil)
|
|
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
|
|
- r.ServeHTTP(w, req)
|
|
|
-
|
|
|
- if w.Code != 401 {
|
|
|
- t.Errorf("Response code should be Not autorized, was: %d", w.Code)
|
|
|
- }
|
|
|
+ router.ServeHTTP(w, req)
|
|
|
|
|
|
- if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"Authorization Required\"" {
|
|
|
- t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
|
|
|
- }
|
|
|
+ assert.False(t, called)
|
|
|
+ assert.Equal(t, w.Code, 401)
|
|
|
+ assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"Authorization Required\"")
|
|
|
}
|
|
|
|
|
|
func TestBasicAuth401WithCustomRealm(t *testing.T) {
|
|
|
- req, _ := http.NewRequest("GET", "/login", nil)
|
|
|
- w := httptest.NewRecorder()
|
|
|
-
|
|
|
- r := New()
|
|
|
+ called := false
|
|
|
accounts := Accounts{"foo": "bar"}
|
|
|
- r.Use(BasicAuthForRealm(accounts, "My Custom Realm"))
|
|
|
-
|
|
|
- r.GET("/login", func(c *Context) {
|
|
|
- c.String(200, "autorized")
|
|
|
+ router := New()
|
|
|
+ router.Use(BasicAuthForRealm(accounts, "My Custom Realm"))
|
|
|
+ router.GET("/login", func(c *Context) {
|
|
|
+ called = true
|
|
|
+ c.String(200, c.MustGet(AuthUserKey).(string))
|
|
|
})
|
|
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
+ req, _ := http.NewRequest("GET", "/login", nil)
|
|
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
|
|
|
- r.ServeHTTP(w, req)
|
|
|
+ router.ServeHTTP(w, req)
|
|
|
|
|
|
- if w.Code != 401 {
|
|
|
- t.Errorf("Response code should be Not autorized, was: %d", w.Code)
|
|
|
- }
|
|
|
-
|
|
|
- if w.HeaderMap.Get("WWW-Authenticate") != "Basic realm=\"My Custom Realm\"" {
|
|
|
- t.Errorf("WWW-Authenticate header is incorrect: %s", w.HeaderMap.Get("Content-Type"))
|
|
|
- }
|
|
|
+ assert.False(t, called)
|
|
|
+ assert.Equal(t, w.Code, 401)
|
|
|
+ assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"My Custom Realm\"")
|
|
|
}
|