فهرست منبع

v2http: refactor http basic auth

refactor http basic auth code to combine basic auth extraction and validation
rob boll 10 سال پیش
والد
کامیت
ab17165352
2فایلهای تغییر یافته به همراه99 افزوده شده و 26 حذف شده
  1. 28 26
      etcdserver/api/v2http/client_auth.go
  2. 71 0
      etcdserver/api/v2http/client_auth_test.go

+ 28 - 26
etcdserver/api/v2http/client_auth.go

@@ -37,6 +37,25 @@ func hasWriteRootAccess(sec auth.Store, r *http.Request) bool {
 	return hasRootAccess(sec, r)
 }
 
+func userFromBasicAuth(sec auth.Store, r *http.Request) *auth.User {
+	username, password, ok := r.BasicAuth()
+	if !ok {
+		plog.Warningf("auth: malformed basic auth encoding")
+		return nil
+	}
+	user, err := sec.GetUser(username)
+	if err != nil {
+		return nil
+	}
+
+	ok = sec.CheckPassword(user, password)
+	if !ok {
+		plog.Warningf("auth: incorrect password for user: %s", username)
+		return nil
+	}
+	return &user
+}
+
 func hasRootAccess(sec auth.Store, r *http.Request) bool {
 	if sec == nil {
 		// No store means no auth available, eg, tests.
@@ -45,26 +64,18 @@ func hasRootAccess(sec auth.Store, r *http.Request) bool {
 	if !sec.AuthEnabled() {
 		return true
 	}
-	username, password, ok := r.BasicAuth()
-	if !ok {
-		return false
-	}
-	rootUser, err := sec.GetUser(username)
-	if err != nil {
-		return false
-	}
 
-	ok = sec.CheckPassword(rootUser, password)
-	if !ok {
-		plog.Warningf("auth: wrong password for user %s", username)
+	rootUser := userFromBasicAuth(sec, r)
+	if rootUser == nil {
 		return false
 	}
+
 	for _, role := range rootUser.Roles {
 		if role == auth.RootRoleName {
 			return true
 		}
 	}
-	plog.Warningf("auth: user %s does not have the %s role for resource %s.", username, auth.RootRoleName, r.URL.Path)
+	plog.Warningf("auth: user %s does not have the %s role for resource %s.", rootUser.User, auth.RootRoleName, r.URL.Path)
 	return false
 }
 
@@ -80,21 +91,12 @@ func hasKeyPrefixAccess(sec auth.Store, r *http.Request, key string, recursive b
 		plog.Warningf("auth: no authorization provided, checking guest access")
 		return hasGuestAccess(sec, r, key)
 	}
-	username, password, ok := r.BasicAuth()
-	if !ok {
-		plog.Warningf("auth: malformed basic auth encoding")
-		return false
-	}
-	user, err := sec.GetUser(username)
-	if err != nil {
-		plog.Warningf("auth: no such user: %s.", username)
-		return false
-	}
-	authAsUser := sec.CheckPassword(user, password)
-	if !authAsUser {
-		plog.Warningf("auth: incorrect password for user: %s.", username)
+
+	user := userFromBasicAuth(sec, r)
+	if user == nil {
 		return false
 	}
+
 	writeAccess := r.Method != "GET" && r.Method != "HEAD"
 	for _, roleName := range user.Roles {
 		role, err := sec.GetRole(roleName)
@@ -109,7 +111,7 @@ func hasKeyPrefixAccess(sec auth.Store, r *http.Request, key string, recursive b
 			return true
 		}
 	}
-	plog.Warningf("auth: invalid access for user %s on key %s.", username, key)
+	plog.Warningf("auth: invalid access for user %s on key %s.", user.User, key)
 	return false
 }
 

+ 71 - 0
etcdserver/api/v2http/client_auth_test.go

@@ -440,6 +440,14 @@ func mustAuthRequest(method, username, password string) *http.Request {
 	return req
 }
 
+func unauthedRequest(method string) *http.Request {
+	req, err := http.NewRequest(method, "path", strings.NewReader(""))
+	if err != nil {
+		panic("Cannot make request: " + err.Error())
+	}
+	return req
+}
+
 func TestPrefixAccess(t *testing.T) {
 	var table = []struct {
 		key                string
@@ -701,3 +709,66 @@ func TestPrefixAccess(t *testing.T) {
 		}
 	}
 }
+
+func TestUserFromBasicAuth(t *testing.T) {
+	sec := &mockAuthStore{
+		users: map[string]*auth.User{
+			"user": {
+				User:     "user",
+				Roles:    []string{"root"},
+				Password: "password",
+			},
+		},
+		roles: map[string]*auth.Role{
+			"root": {
+				Role: "root",
+			},
+		},
+	}
+
+	var table = []struct {
+		username   string
+		req        *http.Request
+		userExists bool
+	}{
+		{
+			// valid user, valid pass
+			username:   "user",
+			req:        mustAuthRequest("GET", "user", "password"),
+			userExists: true,
+		},
+		{
+			// valid user, bad pass
+			username:   "user",
+			req:        mustAuthRequest("GET", "user", "badpass"),
+			userExists: false,
+		},
+		{
+			// valid user, no pass
+			username:   "user",
+			req:        mustAuthRequest("GET", "user", ""),
+			userExists: false,
+		},
+		{
+			// missing user
+			username:   "missing",
+			req:        mustAuthRequest("GET", "missing", "badpass"),
+			userExists: false,
+		},
+		{
+			// no basic auth
+			req:        unauthedRequest("GET"),
+			userExists: false,
+		},
+	}
+
+	for i, tt := range table {
+		user := userFromBasicAuth(sec, tt.req)
+		if tt.userExists == (user == nil) {
+			t.Errorf("#%d: userFromBasicAuth doesn't match (expected %v)", i, tt.userExists)
+		}
+		if user != nil && (tt.username != user.User) {
+			t.Errorf("#%d: userFromBasicAuth username doesn't match (expected %s, got %s)", i, tt.username, user.User)
+		}
+	}
+}