Browse Source

v2http: refactor http basic auth

refactor http basic auth code to combine basic auth extraction and validation
rob boll 10 years ago
parent
commit
ab17165352
2 changed files with 99 additions and 26 deletions
  1. 28 26
      etcdserver/api/v2http/client_auth.go
  2. 71 0
      etcdserver/api/v2http/client_auth_test.go

+ 28 - 26
etcdserver/api/v2http/client_auth.go

@@ -37,6 +37,25 @@ func hasWriteRootAccess(sec auth.Store, r *http.Request) bool {
 	return hasRootAccess(sec, r)
 }
 
+func userFromBasicAuth(sec auth.Store, r *http.Request) *auth.User {
+	username, password, ok := r.BasicAuth()
+	if !ok {
+		plog.Warningf("auth: malformed basic auth encoding")
+		return nil
+	}
+	user, err := sec.GetUser(username)
+	if err != nil {
+		return nil
+	}
+
+	ok = sec.CheckPassword(user, password)
+	if !ok {
+		plog.Warningf("auth: incorrect password for user: %s", username)
+		return nil
+	}
+	return &user
+}
+
 func hasRootAccess(sec auth.Store, r *http.Request) bool {
 	if sec == nil {
 		// No store means no auth available, eg, tests.
@@ -45,26 +64,18 @@ func hasRootAccess(sec auth.Store, r *http.Request) bool {
 	if !sec.AuthEnabled() {
 		return true
 	}
-	username, password, ok := r.BasicAuth()
-	if !ok {
-		return false
-	}
-	rootUser, err := sec.GetUser(username)
-	if err != nil {
-		return false
-	}
 
-	ok = sec.CheckPassword(rootUser, password)
-	if !ok {
-		plog.Warningf("auth: wrong password for user %s", username)
+	rootUser := userFromBasicAuth(sec, r)
+	if rootUser == nil {
 		return false
 	}
+
 	for _, role := range rootUser.Roles {
 		if role == auth.RootRoleName {
 			return true
 		}
 	}
-	plog.Warningf("auth: user %s does not have the %s role for resource %s.", username, auth.RootRoleName, r.URL.Path)
+	plog.Warningf("auth: user %s does not have the %s role for resource %s.", rootUser.User, auth.RootRoleName, r.URL.Path)
 	return false
 }
 
@@ -80,21 +91,12 @@ func hasKeyPrefixAccess(sec auth.Store, r *http.Request, key string, recursive b
 		plog.Warningf("auth: no authorization provided, checking guest access")
 		return hasGuestAccess(sec, r, key)
 	}
-	username, password, ok := r.BasicAuth()
-	if !ok {
-		plog.Warningf("auth: malformed basic auth encoding")
-		return false
-	}
-	user, err := sec.GetUser(username)
-	if err != nil {
-		plog.Warningf("auth: no such user: %s.", username)
-		return false
-	}
-	authAsUser := sec.CheckPassword(user, password)
-	if !authAsUser {
-		plog.Warningf("auth: incorrect password for user: %s.", username)
+
+	user := userFromBasicAuth(sec, r)
+	if user == nil {
 		return false
 	}
+
 	writeAccess := r.Method != "GET" && r.Method != "HEAD"
 	for _, roleName := range user.Roles {
 		role, err := sec.GetRole(roleName)
@@ -109,7 +111,7 @@ func hasKeyPrefixAccess(sec auth.Store, r *http.Request, key string, recursive b
 			return true
 		}
 	}
-	plog.Warningf("auth: invalid access for user %s on key %s.", username, key)
+	plog.Warningf("auth: invalid access for user %s on key %s.", user.User, key)
 	return false
 }
 

+ 71 - 0
etcdserver/api/v2http/client_auth_test.go

@@ -440,6 +440,14 @@ func mustAuthRequest(method, username, password string) *http.Request {
 	return req
 }
 
+func unauthedRequest(method string) *http.Request {
+	req, err := http.NewRequest(method, "path", strings.NewReader(""))
+	if err != nil {
+		panic("Cannot make request: " + err.Error())
+	}
+	return req
+}
+
 func TestPrefixAccess(t *testing.T) {
 	var table = []struct {
 		key                string
@@ -701,3 +709,66 @@ func TestPrefixAccess(t *testing.T) {
 		}
 	}
 }
+
+func TestUserFromBasicAuth(t *testing.T) {
+	sec := &mockAuthStore{
+		users: map[string]*auth.User{
+			"user": {
+				User:     "user",
+				Roles:    []string{"root"},
+				Password: "password",
+			},
+		},
+		roles: map[string]*auth.Role{
+			"root": {
+				Role: "root",
+			},
+		},
+	}
+
+	var table = []struct {
+		username   string
+		req        *http.Request
+		userExists bool
+	}{
+		{
+			// valid user, valid pass
+			username:   "user",
+			req:        mustAuthRequest("GET", "user", "password"),
+			userExists: true,
+		},
+		{
+			// valid user, bad pass
+			username:   "user",
+			req:        mustAuthRequest("GET", "user", "badpass"),
+			userExists: false,
+		},
+		{
+			// valid user, no pass
+			username:   "user",
+			req:        mustAuthRequest("GET", "user", ""),
+			userExists: false,
+		},
+		{
+			// missing user
+			username:   "missing",
+			req:        mustAuthRequest("GET", "missing", "badpass"),
+			userExists: false,
+		},
+		{
+			// no basic auth
+			req:        unauthedRequest("GET"),
+			userExists: false,
+		},
+	}
+
+	for i, tt := range table {
+		user := userFromBasicAuth(sec, tt.req)
+		if tt.userExists == (user == nil) {
+			t.Errorf("#%d: userFromBasicAuth doesn't match (expected %v)", i, tt.userExists)
+		}
+		if user != nil && (tt.username != user.User) {
+			t.Errorf("#%d: userFromBasicAuth username doesn't match (expected %s, got %s)", i, tt.username, user.User)
+		}
+	}
+}