auth_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
  2. // Use of this source code is governed by a MIT style
  3. // license that can be found in the LICENSE file.
  4. package gin
  5. import (
  6. "encoding/base64"
  7. "net/http"
  8. "net/http/httptest"
  9. "testing"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func TestBasicAuth(t *testing.T) {
  13. accounts := Accounts{
  14. "admin": "password",
  15. "foo": "bar",
  16. "bar": "foo",
  17. }
  18. expectedPairs := authPairs{
  19. authPair{
  20. User: "admin",
  21. Value: "Basic YWRtaW46cGFzc3dvcmQ=",
  22. },
  23. authPair{
  24. User: "foo",
  25. Value: "Basic Zm9vOmJhcg==",
  26. },
  27. authPair{
  28. User: "bar",
  29. Value: "Basic YmFyOmZvbw==",
  30. },
  31. }
  32. pairs := processAccounts(accounts)
  33. assert.Equal(t, pairs, expectedPairs)
  34. }
  35. func TestBasicAuthFails(t *testing.T) {
  36. assert.Panics(t, func() { processAccounts(nil) })
  37. assert.Panics(t, func() {
  38. processAccounts(Accounts{
  39. "": "password",
  40. "foo": "bar",
  41. })
  42. })
  43. }
  44. func TestBasicAuthSearchCredential(t *testing.T) {
  45. pairs := processAccounts(Accounts{
  46. "admin": "password",
  47. "foo": "bar",
  48. "bar": "foo",
  49. })
  50. user, found := pairs.searchCredential(authorizationHeader("admin", "password"))
  51. assert.Equal(t, user, "admin")
  52. assert.True(t, found)
  53. user, found = pairs.searchCredential(authorizationHeader("foo", "bar"))
  54. assert.Equal(t, user, "foo")
  55. assert.True(t, found)
  56. user, found = pairs.searchCredential(authorizationHeader("bar", "foo"))
  57. assert.Equal(t, user, "bar")
  58. assert.True(t, found)
  59. user, found = pairs.searchCredential(authorizationHeader("admins", "password"))
  60. assert.Empty(t, user)
  61. assert.False(t, found)
  62. user, found = pairs.searchCredential(authorizationHeader("foo", "bar "))
  63. assert.Empty(t, user)
  64. assert.False(t, found)
  65. user, found = pairs.searchCredential("")
  66. assert.Empty(t, user)
  67. assert.False(t, found)
  68. }
  69. func TestBasicAuthAuthorizationHeader(t *testing.T) {
  70. assert.Equal(t, authorizationHeader("admin", "password"), "Basic YWRtaW46cGFzc3dvcmQ=")
  71. }
  72. func TestBasicAuthSecureCompare(t *testing.T) {
  73. assert.True(t, secureCompare("1234567890", "1234567890"))
  74. assert.False(t, secureCompare("123456789", "1234567890"))
  75. assert.False(t, secureCompare("12345678900", "1234567890"))
  76. assert.False(t, secureCompare("1234567891", "1234567890"))
  77. }
  78. func TestBasicAuthSucceed(t *testing.T) {
  79. accounts := Accounts{"admin": "password"}
  80. router := New()
  81. router.Use(BasicAuth(accounts))
  82. router.GET("/login", func(c *Context) {
  83. c.String(200, c.MustGet(AuthUserKey).(string))
  84. })
  85. w := httptest.NewRecorder()
  86. req, _ := http.NewRequest("GET", "/login", nil)
  87. req.Header.Set("Authorization", authorizationHeader("admin", "password"))
  88. router.ServeHTTP(w, req)
  89. assert.Equal(t, w.Code, 200)
  90. assert.Equal(t, w.Body.String(), "admin")
  91. }
  92. func TestBasicAuth401(t *testing.T) {
  93. called := false
  94. accounts := Accounts{"foo": "bar"}
  95. router := New()
  96. router.Use(BasicAuth(accounts))
  97. router.GET("/login", func(c *Context) {
  98. called = true
  99. c.String(200, c.MustGet(AuthUserKey).(string))
  100. })
  101. w := httptest.NewRecorder()
  102. req, _ := http.NewRequest("GET", "/login", nil)
  103. req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
  104. router.ServeHTTP(w, req)
  105. assert.False(t, called)
  106. assert.Equal(t, w.Code, 401)
  107. assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"Authorization Required\"")
  108. }
  109. func TestBasicAuth401WithCustomRealm(t *testing.T) {
  110. called := false
  111. accounts := Accounts{"foo": "bar"}
  112. router := New()
  113. router.Use(BasicAuthForRealm(accounts, "My Custom \"Realm\""))
  114. router.GET("/login", func(c *Context) {
  115. called = true
  116. c.String(200, c.MustGet(AuthUserKey).(string))
  117. })
  118. w := httptest.NewRecorder()
  119. req, _ := http.NewRequest("GET", "/login", nil)
  120. req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password")))
  121. router.ServeHTTP(w, req)
  122. assert.False(t, called)
  123. assert.Equal(t, w.Code, 401)
  124. assert.Equal(t, w.HeaderMap.Get("WWW-Authenticate"), "Basic realm=\"My Custom \\\"Realm\\\"\"")
  125. }