Jonathan Turner 7 лет назад
Родитель
Сommit
05f9505869
1 измененных файлов с 76 добавлено и 34 удалено
  1. 76 34
      client/session.go

+ 76 - 34
client/session.go

@@ -11,12 +11,13 @@ import (
 	"gopkg.in/jcmturner/gokrb5.v6/types"
 )
 
-// Sessions keyed on the realm name
+// Sessions are for holding TGTs and are keyed on the realm name
 type sessions struct {
 	Entries map[string]*session
 	mux     sync.RWMutex
 }
 
+// destroy erases all sessions
 func (s *sessions) destroy() {
 	s.mux.Lock()
 	defer s.mux.Unlock()
@@ -26,7 +27,33 @@ func (s *sessions) destroy() {
 	}
 }
 
-// Client session struct.
+// update replaces a session with the one provided or adds it as a new one
+func (s *sessions) update(sess *session) {
+	s.mux.Lock()
+	defer s.mux.Unlock()
+	// if a session already exists for this, cancel its auto renew.
+	if i, ok := s.Entries[sess.realm]; ok {
+		if i != sess {
+			// Session in the sessions cache is not the same as one provided.
+			// Cancel the one in the cache and add this one.
+			i.cancel <- true
+			s.Entries[sess.realm] = sess
+			return
+		}
+	}
+	// No session for this realm was found so just add it
+	s.Entries[sess.realm] = sess
+}
+
+// get returns the session for the realm specified
+func (s *sessions) get(realm string) (*session, bool) {
+	s.mux.RLock()
+	defer s.mux.RUnlock()
+	sess, ok := s.Entries[realm]
+	return sess, ok
+}
+
+// session holds the TGT for a realm
 type session struct {
 	realm                string
 	authTime             time.Time
@@ -39,6 +66,7 @@ type session struct {
 	mux                  sync.RWMutex
 }
 
+// update overwrites the session details with those from the TGT and decrypted encPart
 func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
 	s.mux.Lock()
 	defer s.mux.Unlock()
@@ -50,15 +78,19 @@ func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
 	s.sessionKeyExpiration = dep.KeyExpiration
 }
 
+// destroy will cancel any auto renewal of the session and set the expiration times to the current time
 func (s *session) destroy() {
 	s.mux.Lock()
 	defer s.mux.Unlock()
-	s.cancel <- true
+	if s.cancel != nil {
+		s.cancel <- true
+	}
 	s.endTime = time.Now().UTC()
 	s.renewTill = s.endTime
 	s.sessionKeyExpiration = s.endTime
 }
 
+// valid informs if the TGT is still within valid time window
 func (s *session) valid() bool {
 	s.mux.RLock()
 	defer s.mux.RUnlock()
@@ -69,32 +101,43 @@ func (s *session) valid() bool {
 	return false
 }
 
+// copy returns a copy of the session
+func (s *session) copy() session {
+	s.mux.RLock()
+	defer s.mux.RUnlock()
+	sess := session{
+		realm:                s.realm,
+		authTime:             s.authTime,
+		endTime:              s.endTime,
+		renewTill:            s.renewTill,
+		tgt:                  s.tgt,
+		sessionKey:           s.sessionKey,
+		sessionKeyExpiration: s.sessionKeyExpiration,
+	}
+	return sess
+}
+
 // AddSession adds a session for a realm with a TGT to the client's session cache.
 // A goroutine is started to automatically renew the TGT before expiry.
 func (cl *Client) AddSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
-	cl.sessions.mux.Lock()
-	defer cl.sessions.mux.Unlock()
+	realm := cl.Config.ResolveRealm(tgt.SName.NameString[len(tgt.SName.NameString)-1])
 	s := &session{
-		realm:                tgt.SName.NameString[1],
+		realm:                realm,
 		authTime:             dep.AuthTime,
 		endTime:              dep.EndTime,
 		renewTill:            dep.RenewTill,
 		tgt:                  tgt,
 		sessionKey:           dep.Key,
 		sessionKeyExpiration: dep.KeyExpiration,
-		cancel:               make(chan bool, 1),
 	}
-	// if a session already exists for this, cancel its auto renew.
-	if i, ok := cl.sessions.Entries[tgt.SName.NameString[1]]; ok {
-		i.cancel <- true
-	}
-	cl.sessions.Entries[tgt.SName.NameString[1]] = s
+	cl.sessions.update(s)
 	cl.enableAutoSessionRenewal(s)
 }
 
 // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
 func (cl *Client) enableAutoSessionRenewal(s *session) {
 	var timer *time.Timer
+	s.cancel = make(chan bool, 1)
 	go func(s *session) {
 		for {
 			s.mux.RLock()
@@ -120,7 +163,7 @@ func (cl *Client) enableAutoSessionRenewal(s *session) {
 	}(s)
 }
 
-// RenewTGT renews the client's TGT session.
+// renewTGT renews the client's TGT session.
 func (cl *Client) renewTGT(s *session) error {
 	spn := types.PrincipalName{
 		NameType:   nametype.KRB_NT_SRV_INST,
@@ -131,6 +174,7 @@ func (cl *Client) renewTGT(s *session) error {
 		return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT")
 	}
 	s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
+	cl.sessions.update(s)
 	return nil
 }
 
@@ -145,10 +189,19 @@ func (cl *Client) updateSession(s *session) (bool, error) {
 	return false, err
 }
 
+func (cl *Client) ensureValidSession(realm string) error {
+	s, _ := cl.sessions.get(realm)
+	if s != nil && s.valid() {
+		return nil
+	}
+	_, err := cl.updateSession(s)
+	if err != nil {
+		return err
+	}
+}
+
 func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
-	cl.sessions.mux.RLock()
-	sess, ok := cl.sessions.Entries[cl.Credentials.Realm]
-	cl.sessions.mux.RUnlock()
+	s, ok := cl.sessions.get(realm)
 	if !ok {
 		return nil, fmt.Errorf("client does not have a session for realm %s, login first", cl.Credentials.Realm)
 	}
@@ -158,7 +211,7 @@ func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
 		NameString: []string{"krbtgt", realm},
 	}
 
-	_, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, sess.tgt, sess.sessionKey, false, 0)
+	_, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, s.tgt, s.sessionKey, false, 0)
 	if err != nil {
 		return nil, err
 	}
@@ -170,32 +223,21 @@ func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
 }
 
 // GetSessionFromRealm returns the session for the realm provided.
-func (cl *Client) sessionFromRealm(realm string) (sess *session, err error) {
-	cl.sessions.mux.RLock()
-	s, ok := cl.sessions.Entries[realm]
-	cl.sessions.mux.RUnlock()
+func (cl *Client) sessionFromRealm(realm string) (session, error) {
+	s, ok := cl.sessions.get(realm)
+	var err error
 	if !ok {
 		// Try to request TGT from trusted remote Realm
 		s, err = cl.sessionFromRemoteRealm(realm)
 		if err != nil {
-			return
+			return s.copy(), err
 		}
 	}
-	// Create another session to return to prevent race condition.
-	sess = &session{
-		realm:                s.realm,
-		authTime:             s.authTime,
-		endTime:              s.endTime,
-		renewTill:            s.renewTill,
-		tgt:                  s.tgt,
-		sessionKey:           s.sessionKey,
-		sessionKeyExpiration: s.sessionKeyExpiration,
-	}
-	return
+	return s.copy(), nil
 }
 
 // GetSessionFromPrincipalName returns the session for the realm of the principal provided.
-func (cl *Client) sessionFromPrincipalName(spn types.PrincipalName) (*session, error) {
+func (cl *Client) sessionFromPrincipalName(spn types.PrincipalName) (session, error) {
 	realm := cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
 	return cl.sessionFromRealm(realm)
 }