Jonathan Turner 7 éve
szülő
commit
d8699c4e49
4 módosított fájl, 97 hozzáadás és 63 törlés
  1. 7 16
      client/TGSExchange.go
  2. 1 1
      client/client.go
  3. 3 3
      client/client_integration_test.go
  4. 86 43
      client/session.go

+ 7 - 16
client/TGSExchange.go

@@ -1,8 +1,6 @@
 package client
 package client
 
 
 import (
 import (
-	"time"
-
 	"gopkg.in/jcmturner/gokrb5.v6/iana/nametype"
 	"gopkg.in/jcmturner/gokrb5.v6/iana/nametype"
 	"gopkg.in/jcmturner/gokrb5.v6/krberror"
 	"gopkg.in/jcmturner/gokrb5.v6/krberror"
 	"gopkg.in/jcmturner/gokrb5.v6/messages"
 	"gopkg.in/jcmturner/gokrb5.v6/messages"
@@ -67,23 +65,16 @@ func (cl *Client) GetServiceTicket(spn string) (messages.Ticket, types.Encryptio
 		return tkt, skey, nil
 		return tkt, skey, nil
 	}
 	}
 	princ := types.NewPrincipalName(nametype.KRB_NT_PRINCIPAL, spn)
 	princ := types.NewPrincipalName(nametype.KRB_NT_PRINCIPAL, spn)
-	sess, err := cl.sessionFromPrincipalName(princ)
+	realm := cl.Config.ResolveRealm(princ.NameString[len(princ.NameString)-1])
+
+	err := cl.ensureValidSession(realm)
 	if err != nil {
 	if err != nil {
 		return tkt, skey, err
 		return tkt, skey, err
 	}
 	}
-	// Ensure TGT still valid
-	if time.Now().UTC().After(sess.endTime) {
-		_, err := cl.updateSession(sess)
-		if err != nil {
-			return tkt, skey, err
-		}
-		// Get the session again as it could have been replaced by the update
-		sess, err = cl.sessionFromPrincipalName(princ)
-		if err != nil {
-			return tkt, skey, err
-		}
-	}
-	_, tgsRep, err := cl.TGSExchange(princ, sess.realm, sess.tgt, sess.sessionKey, false, 0)
+
+	tgt, skey, err := cl.sessionTGTDetails(realm)
+
+	_, tgsRep, err := cl.TGSExchange(princ, realm, tgt, skey, false, 0)
 	if err != nil {
 	if err != nil {
 		return tkt, skey, err
 		return tkt, skey, err
 	}
 	}

+ 1 - 1
client/client.go

@@ -176,7 +176,7 @@ func (cl *Client) IsConfigured() (bool, error) {
 	}
 	}
 	// Client needs to have either a password, keytab or a session already (later when loading from CCache)
 	// Client needs to have either a password, keytab or a session already (later when loading from CCache)
 	if !cl.Credentials.HasPassword() && !cl.Credentials.HasKeytab() {
 	if !cl.Credentials.HasPassword() && !cl.Credentials.HasKeytab() {
-		sess, err := cl.sessionFromRealm(cl.Credentials.Realm)
+		sess, err := cl.realmSession(cl.Credentials.Realm)
 		if err != nil || sess.authTime.IsZero() {
 		if err != nil || sess.authTime.IsZero() {
 			return false, errors.New("client has neither a keytab nor a password set and no session")
 			return false, errors.New("client has neither a keytab nor a password set and no session")
 		}
 		}

+ 3 - 3
client/client_integration_test.go

@@ -416,7 +416,7 @@ func TestMultiThreadedClientSession(t *testing.T) {
 		t.Fatalf("failed to log in: %v", err)
 		t.Fatalf("failed to log in: %v", err)
 	}
 	}
 
 
-	s, err := cl.sessionFromRealm("TEST.GOKRB5")
+	s, err := cl.realmSession("TEST.GOKRB5")
 	if err != nil {
 	if err != nil {
 		t.Fatalf("error initially getting session: %v", err)
 		t.Fatalf("error initially getting session: %v", err)
 	}
 	}
@@ -435,7 +435,7 @@ func TestMultiThreadedClientSession(t *testing.T) {
 	for i := 0; i < 10; i++ {
 	for i := 0; i < 10; i++ {
 		go func() {
 		go func() {
 			defer wg.Done()
 			defer wg.Done()
-			s, err := cl.sessionFromRealm("TEST.GOKRB5")
+			s, err := cl.realmSession("TEST.GOKRB5")
 			if err != nil {
 			if err != nil {
 				t.Logf("error getting session: %v", err)
 				t.Logf("error getting session: %v", err)
 			}
 			}
@@ -729,7 +729,7 @@ func TestClient_AutoRenew_Goroutine(t *testing.T) {
 	n := runtime.NumGoroutine()
 	n := runtime.NumGoroutine()
 	for i := 0; i < 6; i++ {
 	for i := 0; i < 6; i++ {
 		time.Sleep(time.Second * 20)
 		time.Sleep(time.Second * 20)
-		sess, err := cl.sessionFromRealm("TEST.GOKRB5")
+		sess, err := cl.realmSession("TEST.GOKRB5")
 		if err != nil {
 		if err != nil {
 			t.Errorf("could not get client's session: %v", err)
 			t.Errorf("could not get client's session: %v", err)
 		}
 		}

+ 86 - 43
client/session.go

@@ -11,7 +11,7 @@ import (
 	"gopkg.in/jcmturner/gokrb5.v6/types"
 	"gopkg.in/jcmturner/gokrb5.v6/types"
 )
 )
 
 
-// Sessions are for holding TGTs and are keyed on the realm name
+// sessions hold TGTs and are keyed on the realm name
 type sessions struct {
 type sessions struct {
 	Entries map[string]*session
 	Entries map[string]*session
 	mux     sync.RWMutex
 	mux     sync.RWMutex
@@ -53,7 +53,7 @@ func (s *sessions) get(realm string) (*session, bool) {
 	return sess, ok
 	return sess, ok
 }
 }
 
 
-// session holds the TGT for a realm
+// session holds the TGT details for a realm
 type session struct {
 type session struct {
 	realm                string
 	realm                string
 	authTime             time.Time
 	authTime             time.Time
@@ -66,6 +66,23 @@ type session struct {
 	mux                  sync.RWMutex
 	mux                  sync.RWMutex
 }
 }
 
 
+// AddSession adds a session for a realm with a TGT to the client's session cache.
+// A goroutine is started to automatically renew the TGT before expiry.
+func (cl *Client) AddSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
+	realm := cl.spnRealm(tgt.SName)
+	s := &session{
+		realm:                realm,
+		authTime:             dep.AuthTime,
+		endTime:              dep.EndTime,
+		renewTill:            dep.RenewTill,
+		tgt:                  tgt,
+		sessionKey:           dep.Key,
+		sessionKeyExpiration: dep.KeyExpiration,
+	}
+	cl.sessions.update(s)
+	cl.enableAutoSessionRenewal(s)
+}
+
 // update overwrites the session details with those from the TGT and decrypted encPart
 // update overwrites the session details with those from the TGT and decrypted encPart
 func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
 func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
 	s.mux.Lock()
 	s.mux.Lock()
@@ -90,7 +107,7 @@ func (s *session) destroy() {
 	s.sessionKeyExpiration = s.endTime
 	s.sessionKeyExpiration = s.endTime
 }
 }
 
 
-// valid informs if the TGT is still within valid time window
+// valid informs if the TGT is still within the valid time window
 func (s *session) valid() bool {
 func (s *session) valid() bool {
 	s.mux.RLock()
 	s.mux.RLock()
 	defer s.mux.RUnlock()
 	defer s.mux.RUnlock()
@@ -101,6 +118,13 @@ func (s *session) valid() bool {
 	return false
 	return false
 }
 }
 
 
+// tgtDetails is a thread safe way to get the session's realm, TGT and session key values
+func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
+	s.mux.RLock()
+	defer s.mux.RUnlock()
+	return s.realm, s.tgt, s.sessionKey
+}
+
 // copy returns a copy of the session
 // copy returns a copy of the session
 func (s *session) copy() session {
 func (s *session) copy() session {
 	s.mux.RLock()
 	s.mux.RLock()
@@ -117,23 +141,6 @@ func (s *session) copy() session {
 	return sess
 	return sess
 }
 }
 
 
-// AddSession adds a session for a realm with a TGT to the client's session cache.
-// A goroutine is started to automatically renew the TGT before expiry.
-func (cl *Client) AddSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
-	realm := cl.Config.ResolveRealm(tgt.SName.NameString[len(tgt.SName.NameString)-1])
-	s := &session{
-		realm:                realm,
-		authTime:             dep.AuthTime,
-		endTime:              dep.EndTime,
-		renewTill:            dep.RenewTill,
-		tgt:                  tgt,
-		sessionKey:           dep.Key,
-		sessionKeyExpiration: dep.KeyExpiration,
-	}
-	cl.sessions.update(s)
-	cl.enableAutoSessionRenewal(s)
-}
-
 // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
 // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
 func (cl *Client) enableAutoSessionRenewal(s *session) {
 func (cl *Client) enableAutoSessionRenewal(s *session) {
 	var timer *time.Timer
 	var timer *time.Timer
@@ -149,7 +156,7 @@ func (cl *Client) enableAutoSessionRenewal(s *session) {
 			timer = time.NewTimer(w)
 			timer = time.NewTimer(w)
 			select {
 			select {
 			case <-timer.C:
 			case <-timer.C:
-				renewal, err := cl.updateSession(s)
+				renewal, err := cl.refreshSession(s)
 				if !renewal && err == nil {
 				if !renewal && err == nil {
 					// end this goroutine as there will have been a new login and new auto renewal goroutine created.
 					// end this goroutine as there will have been a new login and new auto renewal goroutine created.
 					return
 					return
@@ -165,11 +172,12 @@ func (cl *Client) enableAutoSessionRenewal(s *session) {
 
 
 // renewTGT renews the client's TGT session.
 // renewTGT renews the client's TGT session.
 func (cl *Client) renewTGT(s *session) error {
 func (cl *Client) renewTGT(s *session) error {
+	realm, tgt, skey := s.tgtDetails()
 	spn := types.PrincipalName{
 	spn := types.PrincipalName{
 		NameType:   nametype.KRB_NT_SRV_INST,
 		NameType:   nametype.KRB_NT_SRV_INST,
-		NameString: []string{"krbtgt", s.realm},
+		NameString: []string{"krbtgt", realm},
 	}
 	}
-	_, tgsRep, err := cl.TGSExchange(spn, s.tgt.Realm, s.tgt, s.sessionKey, true, 0)
+	_, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, tgt, skey, true, 0)
 	if err != nil {
 	if err != nil {
 		return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT")
 		return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT")
 	}
 	}
@@ -180,30 +188,55 @@ func (cl *Client) renewTGT(s *session) error {
 
 
 // updateSession updates either through renewal or creating a new login.
 // updateSession updates either through renewal or creating a new login.
 // The boolean indicates if the update was a renewal.
 // The boolean indicates if the update was a renewal.
-func (cl *Client) updateSession(s *session) (bool, error) {
-	if time.Now().UTC().Before(s.renewTill) {
+func (cl *Client) refreshSession(s *session) (bool, error) {
+	s.mux.RLock()
+	realm := s.realm
+	renewTill := s.renewTill
+	s.mux.RUnlock()
+	if time.Now().UTC().Before(renewTill) {
 		err := cl.renewTGT(s)
 		err := cl.renewTGT(s)
 		return true, err
 		return true, err
 	}
 	}
+	if realm != cl.Credentials.Realm {
+		// session is not for the client's own realm
+		_, err := cl.remoteRealmSession(realm)
+		if err != nil {
+			return false, err
+		}
+	}
 	err := cl.Login()
 	err := cl.Login()
 	return false, err
 	return false, err
 }
 }
 
 
+// ensureValidSession makes sure there is a valid session for the realm
 func (cl *Client) ensureValidSession(realm string) error {
 func (cl *Client) ensureValidSession(realm string) error {
-	s, _ := cl.sessions.get(realm)
-	if s != nil && s.valid() {
-		return nil
+	s, ok := cl.sessions.get(realm)
+	if ok {
+		s.mux.RLock()
+		defer s.mux.RUnlock()
+		d := s.endTime.Sub(s.authTime) / 6
+		if s.endTime.Sub(time.Now().UTC()) > d {
+			return nil
+		}
+		_, err := cl.refreshSession(s)
+		return err
 	}
 	}
-	_, err := cl.updateSession(s)
-	if err != nil {
+	if realm != cl.Credentials.Realm {
+		// not for the client's own realm
+		_, err := cl.remoteRealmSession(realm)
 		return err
 		return err
 	}
 	}
+	return cl.Login()
 }
 }
 
 
-func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
-	s, ok := cl.sessions.get(realm)
-	if !ok {
-		return nil, fmt.Errorf("client does not have a session for realm %s, login first", cl.Credentials.Realm)
+// remoteRealmSession returns the session for a realm that the client is not a member of but for which there is a trust
+func (cl *Client) remoteRealmSession(realm string) (*session, error) {
+	s, ok := cl.sessions.get(cl.Credentials.Realm)
+	if !ok || !s.valid() {
+		err := cl.Login()
+		if err != nil {
+			return nil, fmt.Errorf("client was unable to login: %v", err)
+		}
 	}
 	}
 
 
 	spn := types.PrincipalName{
 	spn := types.PrincipalName{
@@ -222,22 +255,32 @@ func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
 	return cl.sessions.Entries[realm], nil
 	return cl.sessions.Entries[realm], nil
 }
 }
 
 
-// GetSessionFromRealm returns the session for the realm provided.
-func (cl *Client) sessionFromRealm(realm string) (session, error) {
+// realmSession returns the session for the realm provided.
+func (cl *Client) realmSession(realm string) (*session, error) {
 	s, ok := cl.sessions.get(realm)
 	s, ok := cl.sessions.get(realm)
 	var err error
 	var err error
 	if !ok {
 	if !ok {
 		// Try to request TGT from trusted remote Realm
 		// Try to request TGT from trusted remote Realm
-		s, err = cl.sessionFromRemoteRealm(realm)
+		s, err = cl.remoteRealmSession(realm)
 		if err != nil {
 		if err != nil {
-			return s.copy(), err
+			return s, err
 		}
 		}
 	}
 	}
-	return s.copy(), nil
+	return s, nil
+}
+
+// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
+func (cl *Client) sessionTGTDetails(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
+	var s *session
+	s, err = cl.realmSession(realm)
+	if err != nil {
+		return
+	}
+	realm, tgt, sessionKey = s.tgtDetails()
+	return
 }
 }
 
 
-// GetSessionFromPrincipalName returns the session for the realm of the principal provided.
-func (cl *Client) sessionFromPrincipalName(spn types.PrincipalName) (session, error) {
-	realm := cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
-	return cl.sessionFromRealm(realm)
+// spnRealm resolves the realm name of a service principal name
+func (cl *Client) spnRealm(spn types.PrincipalName) string {
+	return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
 }
 }