auth_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
  2. // Use of this source code is governed by a MIT style
  3. // license that can be found in the LICENSE file.
  4. package gin
  5. import (
  6. "encoding/base64"
  7. "net/http"
  8. "net/http/httptest"
  9. "testing"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func TestBasicAuth(t *testing.T) {
  13. pairs := processAccounts(Accounts{
  14. "admin": "password",
  15. "foo": "bar",
  16. "bar": "foo",
  17. })
  18. assert.Len(t, pairs, 3)
  19. assert.Contains(t, pairs, authPair{
  20. user: "bar",
  21. value: "Basic YmFyOmZvbw==",
  22. })
  23. assert.Contains(t, pairs, authPair{
  24. user: "foo",
  25. value: "Basic Zm9vOmJhcg==",
  26. })
  27. assert.Contains(t, pairs, authPair{
  28. user: "admin",
  29. value: "Basic YWRtaW46cGFzc3dvcmQ=",
  30. })
  31. }
  32. func TestBasicAuthFails(t *testing.T) {
  33. assert.Panics(t, func() { processAccounts(nil) })
  34. assert.Panics(t, func() {
  35. processAccounts(Accounts{
  36. "": "password",
  37. "foo": "bar",
  38. })
  39. })
  40. }
  41. func TestBasicAuthSearchCredential(t *testing.T) {
  42. pairs := processAccounts(Accounts{
  43. "admin": "password",
  44. "foo": "bar",
  45. "bar": "foo",
  46. })
  47. user, found := pairs.searchCredential(authorizationHeader("admin", "password"))
  48. assert.Equal(t, "admin", user)
  49. assert.True(t, found)
  50. user, found = pairs.searchCredential(authorizationHeader("foo", "bar"))
  51. assert.Equal(t, "foo", user)
  52. assert.True(t, found)
  53. user, found = pairs.searchCredential(authorizationHeader("bar", "foo"))
  54. assert.Equal(t, "bar", user)
  55. assert.True(t, found)
  56. user, found = pairs.searchCredential(authorizationHeader("admins", "password"))
  57. assert.Empty(t, user)
  58. assert.False(t, found)
  59. user, found = pairs.searchCredential(authorizationHeader("foo", "bar "))
  60. assert.Empty(t, user)
  61. assert.False(t, found)
  62. user, found = pairs.searchCredential("")
  63. assert.Empty(t, user)
  64. assert.False(t, found)
  65. }
  66. func TestBasicAuthAuthorizationHeader(t *testing.T) {
  67. assert.Equal(t, "Basic YWRtaW46cGFzc3dvcmQ=", authorizationHeader("admin", "password"))
  68. }
  69. func TestBasicAuthSucceed(t *testing.T) {
  70. accounts := Accounts{"admin": "password"}
  71. router := New()
  72. router.Use(BasicAuth(accounts))
  73. router.GET("/login", func(c *Context) {
  74. c.String(http.StatusOK, c.MustGet(AuthUserKey).(string))
  75. })
  76. w := httptest.NewRecorder()
  77. req, _ := http.NewRequest("GET", "/login", nil)
  78. req.Header.Set("Authorization", authorizationHeader("admin", "password"))
  79. router.ServeHTTP(w, req)
  80. assert.Equal(t, http.StatusOK, w.Code)
  81. assert.Equal(t, "admin", w.Body.String())
  82. }
  83. func TestBasicAuth401(t *testing.T) {
  84. called := false
  85. accounts := Accounts{"foo": "bar"}
  86. router := New()
  87. router.Use(BasicAuth(accounts))
  88. router.GET("/login", func(c *Context) {
  89. called = true
  90. c.String(http.StatusOK, c.MustGet(AuthUserKey).(string))
  91. })
  92. w := httptest.NewRecorder()
  93. req, _ := http.NewRequest("GET", "/login", nil)
  94. req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
  95. router.ServeHTTP(w, req)
  96. assert.False(t, called)
  97. assert.Equal(t, http.StatusUnauthorized, w.Code)
  98. assert.Equal(t, "Basic realm=\"Authorization Required\"", w.Header().Get("WWW-Authenticate"))
  99. }
  100. func TestBasicAuth401WithCustomRealm(t *testing.T) {
  101. called := false
  102. accounts := Accounts{"foo": "bar"}
  103. router := New()
  104. router.Use(BasicAuthForRealm(accounts, "My Custom \"Realm\""))
  105. router.GET("/login", func(c *Context) {
  106. called = true
  107. c.String(http.StatusOK, c.MustGet(AuthUserKey).(string))
  108. })
  109. w := httptest.NewRecorder()
  110. req, _ := http.NewRequest("GET", "/login", nil)
  111. req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
  112. router.ServeHTTP(w, req)
  113. assert.False(t, called)
  114. assert.Equal(t, http.StatusUnauthorized, w.Code)
  115. assert.Equal(t, "Basic realm=\"My Custom \\\"Realm\\\"\"", w.Header().Get("WWW-Authenticate"))
  116. }