Просмотр исходного кода

credential context change to goidentity v3

Jonathan Turner 7 лет назад
Родитель
Сommit
b68e2282a3

+ 21 - 5
credentials/credentials.go

@@ -12,7 +12,7 @@ import (
 
 const (
 	// AttributeKeyADCredentials assigned number for AD credentials.
-	AttributeKeyADCredentials = 1
+	AttributeKeyADCredentials = "gokrb5AttributeKeyADCredentials"
 )
 
 // Credentials struct for a user.
@@ -25,7 +25,7 @@ type Credentials struct {
 	CName       types.PrincipalName
 	Keytab      keytab.Keytab
 	Password    string
-	Attributes  map[int]interface{}
+	attributes  map[string]interface{}
 	ValidUntil  time.Time
 
 	authenticated   bool
@@ -62,7 +62,7 @@ func NewCredentials(username string, realm string) Credentials {
 		Realm:       realm,
 		CName:       types.NewPrincipalName(nametype.KRB_NT_PRINCIPAL, username),
 		Keytab:      keytab.NewKeytab(),
-		Attributes:  make(map[int]interface{}),
+		attributes:  make(map[string]interface{}),
 		sessionID:   uid,
 	}
 }
@@ -79,7 +79,7 @@ func NewCredentialsFromPrincipal(cname types.PrincipalName, realm string) Creden
 		Realm:           realm,
 		CName:           cname,
 		Keytab:          keytab.NewKeytab(),
-		Attributes:      make(map[int]interface{}),
+		attributes:      make(map[string]interface{}),
 		groupMembership: make(map[string]bool),
 		sessionID:       uid,
 	}
@@ -120,7 +120,7 @@ func (c *Credentials) HasPassword() bool {
 
 // SetADCredentials adds ADCredentials attributes to the credentials
 func (c *Credentials) SetADCredentials(a ADCredentials) {
-	c.Attributes[AttributeKeyADCredentials] = a
+	c.SetAttribute(AttributeKeyADCredentials, a)
 	if a.FullName != "" {
 		c.SetDisplayName(a.FullName)
 	}
@@ -255,3 +255,19 @@ func (c *Credentials) Expired() bool {
 	}
 	return false
 }
+
+func (c *Credentials) Attributes() map[string]interface{} {
+	return c.attributes
+}
+
+func (c *Credentials) SetAttribute(k string, v interface{}) {
+	c.attributes[k] = v
+}
+
+func (c *Credentials) SetAttributes(a map[string]interface{}) {
+	c.attributes = a
+}
+
+func (c *Credentials) RemoveAttribute(k string) {
+	delete(c.attributes, k)
+}

+ 1 - 1
credentials/credentials_test.go

@@ -2,7 +2,7 @@ package credentials
 
 import (
 	"github.com/stretchr/testify/assert"
-	goidentity "gopkg.in/jcmturner/goidentity.v2"
+	goidentity "gopkg.in/jcmturner/goidentity.v3"
 	"testing"
 )
 

+ 3 - 2
examples/example-AD.go

@@ -5,6 +5,7 @@ package main
 import (
 	"encoding/hex"
 	"fmt"
+	"gopkg.in/jcmturner/goidentity.v3"
 	"gopkg.in/jcmturner/gokrb5.v5/client"
 	"gopkg.in/jcmturner/gokrb5.v5/config"
 	"gopkg.in/jcmturner/gokrb5.v5/credentials"
@@ -75,7 +76,7 @@ func testAppHandler(w http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
 	fmt.Fprint(w, "<html>\n<p><h1>TEST.GOKRB5 Handler</h1></p>\n")
 	if validuser, ok := ctx.Value(service.CTXKeyAuthenticated).(bool); ok && validuser {
-		if creds, ok := ctx.Value(service.CTXKeyCredentials).(credentials.Credentials); ok {
+		if creds, ok := ctx.Value(service.CTXKeyCredentials).(goidentity.Identity); ok {
 			fmt.Fprintf(w, "<ul><li>Authenticed user: %s</li>\n", creds.UserName())
 			fmt.Fprintf(w, "<li>User's realm: %s</li>\n", creds.Domain())
 			fmt.Fprint(w, "<li>Authz Attributes (Group Memberships):</li><ul>\n")
@@ -83,7 +84,7 @@ func testAppHandler(w http.ResponseWriter, r *http.Request) {
 				fmt.Fprintf(w, "<li>%v</li>\n", s)
 			}
 			fmt.Fprint(w, "</ul>\n")
-			if ADCreds, ok := creds.Attributes[credentials.AttributeKeyADCredentials].(credentials.ADCredentials); ok {
+			if ADCreds, ok := creds.Attributes()[credentials.AttributeKeyADCredentials].(credentials.ADCredentials); ok {
 				// Now access the fields of the ADCredentials struct. For example:
 				fmt.Fprintf(w, "<li>EffectiveName: %v</li>\n", ADCreds.EffectiveName)
 				fmt.Fprintf(w, "<li>FullName: %v</li>\n", ADCreds.FullName)

+ 8 - 7
examples/example.go

@@ -5,17 +5,18 @@ package main
 import (
 	"encoding/hex"
 	"fmt"
-	"gopkg.in/jcmturner/gokrb5.v5/client"
-	"gopkg.in/jcmturner/gokrb5.v5/config"
-	"gopkg.in/jcmturner/gokrb5.v5/credentials"
-	"gopkg.in/jcmturner/gokrb5.v5/keytab"
-	"gopkg.in/jcmturner/gokrb5.v5/service"
-	"gopkg.in/jcmturner/gokrb5.v5/testdata"
 	"io/ioutil"
 	"log"
 	"net/http"
 	"net/http/httptest"
 	"os"
+
+	"gopkg.in/jcmturner/goidentity.v3"
+	"gopkg.in/jcmturner/gokrb5.v5/client"
+	"gopkg.in/jcmturner/gokrb5.v5/config"
+	"gopkg.in/jcmturner/gokrb5.v5/keytab"
+	"gopkg.in/jcmturner/gokrb5.v5/service"
+	"gopkg.in/jcmturner/gokrb5.v5/testdata"
 )
 
 func main() {
@@ -74,7 +75,7 @@ func testAppHandler(w http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
 	fmt.Fprint(w, "<html>\n<p><h1>TEST.GOKRB5 Handler</h1></p>\n")
 	if validuser, ok := ctx.Value(service.CTXKeyAuthenticated).(bool); ok && validuser {
-		if creds, ok := ctx.Value(service.CTXKeyCredentials).(credentials.Credentials); ok {
+		if creds, ok := ctx.Value(service.CTXKeyCredentials).(goidentity.Identity); ok {
 			fmt.Fprintf(w, "<ul><li>Authenticed user: %s</li>\n", creds.Username)
 			fmt.Fprintf(w, "<li>User's realm: %s</li></ul>\n", creds.Realm)
 		}

+ 23 - 19
service/APExchange.go

@@ -80,25 +80,29 @@ func ValidateAPREQ(APReq messages.APReq, c *SPNEGOAuthenticator) (bool, credenti
 	creds.SetAuthTime(t)
 	creds.SetAuthenticated(true)
 	creds.SetValidUntil(APReq.Ticket.DecryptedEncPart.EndTime)
-	isPAC, pac, err := APReq.Ticket.GetPACType(*c.Keytab, c.ServicePrincipal)
-	if isPAC && err != nil {
-		return false, creds, err
-	}
-	if isPAC {
-		// There is a valid PAC. Adding attributes to creds
-		creds.SetADCredentials(credentials.ADCredentials{
-			GroupMembershipSIDs: pac.KerbValidationInfo.GetGroupMembershipSIDs(),
-			LogOnTime:           pac.KerbValidationInfo.LogOnTime.Time(),
-			LogOffTime:          pac.KerbValidationInfo.LogOffTime.Time(),
-			PasswordLastSet:     pac.KerbValidationInfo.PasswordLastSet.Time(),
-			EffectiveName:       pac.KerbValidationInfo.EffectiveName.Value,
-			FullName:            pac.KerbValidationInfo.FullName.Value,
-			UserID:              int(pac.KerbValidationInfo.UserID),
-			PrimaryGroupID:      int(pac.KerbValidationInfo.PrimaryGroupID),
-			LogonServer:         pac.KerbValidationInfo.LogonServer.Value,
-			LogonDomainName:     pac.KerbValidationInfo.LogonDomainName.Value,
-			LogonDomainID:       pac.KerbValidationInfo.LogonDomainID.String(),
-		})
+
+	//PAC decoding
+	if !c.DisablePACDecoding {
+		isPAC, pac, err := APReq.Ticket.GetPACType(*c.Keytab, c.ServicePrincipal)
+		if isPAC && err != nil {
+			return false, creds, err
+		}
+		if isPAC {
+			// There is a valid PAC. Adding attributes to creds
+			creds.SetADCredentials(credentials.ADCredentials{
+				GroupMembershipSIDs: pac.KerbValidationInfo.GetGroupMembershipSIDs(),
+				LogOnTime:           pac.KerbValidationInfo.LogOnTime.Time(),
+				LogOffTime:          pac.KerbValidationInfo.LogOffTime.Time(),
+				PasswordLastSet:     pac.KerbValidationInfo.PasswordLastSet.Time(),
+				EffectiveName:       pac.KerbValidationInfo.EffectiveName.Value,
+				FullName:            pac.KerbValidationInfo.FullName.Value,
+				UserID:              int(pac.KerbValidationInfo.UserID),
+				PrimaryGroupID:      int(pac.KerbValidationInfo.PrimaryGroupID),
+				LogonServer:         pac.KerbValidationInfo.LogonServer.Value,
+				LogonDomainName:     pac.KerbValidationInfo.LogonDomainName.Value,
+				LogonDomainID:       pac.KerbValidationInfo.LogonDomainID.String(),
+			})
+		}
 	}
 	return true, creds, nil
 }

+ 22 - 9
service/APExchange_test.go

@@ -51,7 +51,9 @@ func TestValidateAPREQ(t *testing.T) {
 		t.Fatalf("Error getting test AP_REQ: %v", err)
 	}
 
-	ok, _, err := ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	c := NewSPNEGOAuthenticator(kt)
+	c.ClientAddr = "127.0.0.1"
+	ok, _, err := ValidateAPREQ(APReq, c)
 	if !ok || err != nil {
 		t.Fatalf("Validation of AP_REQ failed when it should not have: %v", err)
 	}
@@ -94,8 +96,9 @@ func TestValidateAPREQ_KRB_AP_ERR_BADMATCH(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Error getting test AP_REQ: %v", err)
 	}
-
-	ok, _, err := ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	c := NewSPNEGOAuthenticator(kt)
+	c.ClientAddr = "127.0.0.1"
+	ok, _, err := ValidateAPREQ(APReq, c)
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
@@ -141,7 +144,9 @@ func TestValidateAPREQ_LargeClockSkew(t *testing.T) {
 		t.Fatalf("Error getting test AP_REQ: %v", err)
 	}
 
-	ok, _, err := ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	c := NewSPNEGOAuthenticator(kt)
+	c.ClientAddr = "127.0.0.1"
+	ok, _, err := ValidateAPREQ(APReq, c)
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
@@ -185,12 +190,14 @@ func TestValidateAPREQ_Replay(t *testing.T) {
 		t.Fatalf("Error getting test AP_REQ: %v", err)
 	}
 
-	ok, _, err := ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	c := NewSPNEGOAuthenticator(kt)
+	c.ClientAddr = "127.0.0.1"
+	ok, _, err := ValidateAPREQ(APReq, c)
 	if !ok || err != nil {
 		t.Fatalf("Validation of AP_REQ failed when it should not have: %v", err)
 	}
 	// Replay
-	ok, _, err = ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	ok, _, err = ValidateAPREQ(APReq, c)
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
@@ -232,7 +239,9 @@ func TestValidateAPREQ_FutureTicket(t *testing.T) {
 		t.Fatalf("Error getting test AP_REQ: %v", err)
 	}
 
-	ok, _, err := ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	c := NewSPNEGOAuthenticator(kt)
+	c.ClientAddr = "127.0.0.1"
+	ok, _, err := ValidateAPREQ(APReq, c)
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
@@ -278,7 +287,9 @@ func TestValidateAPREQ_InvalidTicket(t *testing.T) {
 		t.Fatalf("Error getting test AP_REQ: %v", err)
 	}
 
-	ok, _, err := ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	c := NewSPNEGOAuthenticator(kt)
+	c.ClientAddr = "127.0.0.1"
+	ok, _, err := ValidateAPREQ(APReq, c)
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}
@@ -323,7 +334,9 @@ func TestValidateAPREQ_ExpiredTicket(t *testing.T) {
 		t.Fatalf("Error getting test AP_REQ: %v", err)
 	}
 
-	ok, _, err := ValidateAPREQ(APReq, kt, "", "127.0.0.1", false)
+	c := NewSPNEGOAuthenticator(kt)
+	c.ClientAddr = "127.0.0.1"
+	ok, _, err := ValidateAPREQ(APReq, c)
 	if ok || err == nil {
 		t.Fatal("Validation of AP_REQ passed when it should not have")
 	}

+ 5 - 5
service/authenticator.go

@@ -7,7 +7,7 @@ import (
 	"strings"
 	"time"
 
-	goidentity "gopkg.in/jcmturner/goidentity.v2"
+	goidentity "gopkg.in/jcmturner/goidentity.v3"
 	"gopkg.in/jcmturner/gokrb5.v5/client"
 	"gopkg.in/jcmturner/gokrb5.v5/config"
 	"gopkg.in/jcmturner/gokrb5.v5/credentials"
@@ -15,7 +15,7 @@ import (
 	"gopkg.in/jcmturner/gokrb5.v5/keytab"
 )
 
-// SPNEGOAuthenticator implements gopkg.in/jcmturner/goidentity.v2.Authenticator interface
+// SPNEGOAuthenticator implements gopkg.in/jcmturner/goidentity.v3.Authenticator interface
 type SPNEGOAuthenticator struct {
 	SPNEGOHeaderValue  string
 	Keytab             *keytab.Keytab
@@ -30,7 +30,7 @@ func NewSPNEGOAuthenticator(kt keytab.Keytab) *SPNEGOAuthenticator {
 }
 
 // Authenticate and retrieve a goidentity.Identity. In this case it is a pointer to a credentials.Credentials
-func (a *SPNEGOAuthenticator) Authenticate() (i goidentity.Identity, ok bool, err error) {
+func (a SPNEGOAuthenticator) Authenticate() (i goidentity.Identity, ok bool, err error) {
 	b, err := base64.StdEncoding.DecodeString(a.SPNEGOHeaderValue)
 	if err != nil {
 		err = fmt.Errorf("SPNEGO error in base64 decoding negotiation header: %v", err)
@@ -57,7 +57,7 @@ func (a *SPNEGOAuthenticator) Authenticate() (i goidentity.Identity, ok bool, er
 		return
 	}
 
-	ok, creds, err := ValidateAPREQ(mt.APReq, a)
+	ok, creds, err := ValidateAPREQ(mt.APReq, &a)
 	if err != nil {
 		err = fmt.Errorf("SPNEGO validation error: %v", err)
 		return
@@ -71,7 +71,7 @@ func (a SPNEGOAuthenticator) Mechanism() string {
 	return "SPNEGO Kerberos"
 }
 
-// KRB5BasicAuthenticator implements gopkg.in/jcmturner/goidentity.v2.Authenticator interface.
+// KRB5BasicAuthenticator implements gopkg.in/jcmturner/goidentity.v3.Authenticator interface.
 // It takes username and password so can be used for basic authentication.
 type KRB5BasicAuthenticator struct {
 	BasicHeaderValue string

+ 1 - 1
service/authenticator_test.go

@@ -4,7 +4,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-	"gopkg.in/jcmturner/goidentity.v2"
+	"gopkg.in/jcmturner/goidentity.v3"
 )
 
 func TestImplementsInterface(t *testing.T) {

+ 6 - 3
service/http_test.go

@@ -12,8 +12,8 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"gopkg.in/jcmturner/goidentity.v3"
 	"gopkg.in/jcmturner/gokrb5.v5/client"
-	"gopkg.in/jcmturner/gokrb5.v5/credentials"
 	"gopkg.in/jcmturner/gokrb5.v5/iana/nametype"
 	"gopkg.in/jcmturner/gokrb5.v5/keytab"
 	"gopkg.in/jcmturner/gokrb5.v5/messages"
@@ -246,13 +246,16 @@ func httpServer() *httptest.Server {
 	b, _ := hex.DecodeString(testdata.HTTP_KEYTAB)
 	kt, _ := keytab.Parse(b)
 	th := http.HandlerFunc(testAppHandler)
-	s := httptest.NewServer(SPNEGOKRB5Authenticate(th, kt, "", false, l))
+	c := NewSPNEGOAuthenticator(kt)
+	s := httptest.NewServer(SPNEGOKRB5Authenticate(th, c, l))
 	return s
 }
 
 func testAppHandler(w http.ResponseWriter, r *http.Request) {
 	w.WriteHeader(http.StatusOK)
 	ctx := r.Context()
-	fmt.Fprintf(w, "<html>\nTEST.GOKRB5 Handler\nAuthenticed user: %s\nUser's realm: %s\n</html>", ctx.Value(CTXKeyCredentials).(credentials.Credentials).Username, ctx.Value(CTXKeyCredentials).(credentials.Credentials).Realm)
+	fmt.Fprintf(w, "<html>\nTEST.GOKRB5 Handler\nAuthenticed user: %s\nUser's realm: %s\n</html>",
+		ctx.Value(CTXKeyCredentials).(goidentity.Identity).UserName(),
+		ctx.Value(CTXKeyCredentials).(goidentity.Identity).Domain())
 	return
 }