|
|
@@ -11,7 +11,7 @@ import (
|
|
|
"gopkg.in/jcmturner/gokrb5.v6/types"
|
|
|
)
|
|
|
|
|
|
-// Sessions are for holding TGTs and are keyed on the realm name
|
|
|
+// sessions hold TGTs and are keyed on the realm name
|
|
|
type sessions struct {
|
|
|
Entries map[string]*session
|
|
|
mux sync.RWMutex
|
|
|
@@ -53,7 +53,7 @@ func (s *sessions) get(realm string) (*session, bool) {
|
|
|
return sess, ok
|
|
|
}
|
|
|
|
|
|
-// session holds the TGT for a realm
|
|
|
+// session holds the TGT details for a realm
|
|
|
type session struct {
|
|
|
realm string
|
|
|
authTime time.Time
|
|
|
@@ -66,6 +66,23 @@ type session struct {
|
|
|
mux sync.RWMutex
|
|
|
}
|
|
|
|
|
|
+// AddSession adds a session for a realm with a TGT to the client's session cache.
|
|
|
+// A goroutine is started to automatically renew the TGT before expiry.
|
|
|
+func (cl *Client) AddSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
|
|
|
+ realm := cl.spnRealm(tgt.SName)
|
|
|
+ s := &session{
|
|
|
+ realm: realm,
|
|
|
+ authTime: dep.AuthTime,
|
|
|
+ endTime: dep.EndTime,
|
|
|
+ renewTill: dep.RenewTill,
|
|
|
+ tgt: tgt,
|
|
|
+ sessionKey: dep.Key,
|
|
|
+ sessionKeyExpiration: dep.KeyExpiration,
|
|
|
+ }
|
|
|
+ cl.sessions.update(s)
|
|
|
+ cl.enableAutoSessionRenewal(s)
|
|
|
+}
|
|
|
+
|
|
|
// update overwrites the session details with those from the TGT and decrypted encPart
|
|
|
func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
|
|
|
s.mux.Lock()
|
|
|
@@ -90,7 +107,7 @@ func (s *session) destroy() {
|
|
|
s.sessionKeyExpiration = s.endTime
|
|
|
}
|
|
|
|
|
|
-// valid informs if the TGT is still within valid time window
|
|
|
+// valid informs if the TGT is still within the valid time window
|
|
|
func (s *session) valid() bool {
|
|
|
s.mux.RLock()
|
|
|
defer s.mux.RUnlock()
|
|
|
@@ -101,6 +118,13 @@ func (s *session) valid() bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
+// tgtDetails is a thread safe way to get the session's realm, TGT and session key values
|
|
|
+func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
|
|
|
+ s.mux.RLock()
|
|
|
+ defer s.mux.RUnlock()
|
|
|
+ return s.realm, s.tgt, s.sessionKey
|
|
|
+}
|
|
|
+
|
|
|
// copy returns a copy of the session
|
|
|
func (s *session) copy() session {
|
|
|
s.mux.RLock()
|
|
|
@@ -117,23 +141,6 @@ func (s *session) copy() session {
|
|
|
return sess
|
|
|
}
|
|
|
|
|
|
-// AddSession adds a session for a realm with a TGT to the client's session cache.
|
|
|
-// A goroutine is started to automatically renew the TGT before expiry.
|
|
|
-func (cl *Client) AddSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
|
|
|
- realm := cl.Config.ResolveRealm(tgt.SName.NameString[len(tgt.SName.NameString)-1])
|
|
|
- s := &session{
|
|
|
- realm: realm,
|
|
|
- authTime: dep.AuthTime,
|
|
|
- endTime: dep.EndTime,
|
|
|
- renewTill: dep.RenewTill,
|
|
|
- tgt: tgt,
|
|
|
- sessionKey: dep.Key,
|
|
|
- sessionKeyExpiration: dep.KeyExpiration,
|
|
|
- }
|
|
|
- cl.sessions.update(s)
|
|
|
- cl.enableAutoSessionRenewal(s)
|
|
|
-}
|
|
|
-
|
|
|
// enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
|
|
|
func (cl *Client) enableAutoSessionRenewal(s *session) {
|
|
|
var timer *time.Timer
|
|
|
@@ -149,7 +156,7 @@ func (cl *Client) enableAutoSessionRenewal(s *session) {
|
|
|
timer = time.NewTimer(w)
|
|
|
select {
|
|
|
case <-timer.C:
|
|
|
- renewal, err := cl.updateSession(s)
|
|
|
+ renewal, err := cl.refreshSession(s)
|
|
|
if !renewal && err == nil {
|
|
|
// end this goroutine as there will have been a new login and new auto renewal goroutine created.
|
|
|
return
|
|
|
@@ -165,11 +172,12 @@ func (cl *Client) enableAutoSessionRenewal(s *session) {
|
|
|
|
|
|
// renewTGT renews the client's TGT session.
|
|
|
func (cl *Client) renewTGT(s *session) error {
|
|
|
+ realm, tgt, skey := s.tgtDetails()
|
|
|
spn := types.PrincipalName{
|
|
|
NameType: nametype.KRB_NT_SRV_INST,
|
|
|
- NameString: []string{"krbtgt", s.realm},
|
|
|
+ NameString: []string{"krbtgt", realm},
|
|
|
}
|
|
|
- _, tgsRep, err := cl.TGSExchange(spn, s.tgt.Realm, s.tgt, s.sessionKey, true, 0)
|
|
|
+ _, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, tgt, skey, true, 0)
|
|
|
if err != nil {
|
|
|
return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT")
|
|
|
}
|
|
|
@@ -180,30 +188,55 @@ func (cl *Client) renewTGT(s *session) error {
|
|
|
|
|
|
// updateSession updates either through renewal or creating a new login.
|
|
|
// The boolean indicates if the update was a renewal.
|
|
|
-func (cl *Client) updateSession(s *session) (bool, error) {
|
|
|
- if time.Now().UTC().Before(s.renewTill) {
|
|
|
+func (cl *Client) refreshSession(s *session) (bool, error) {
|
|
|
+ s.mux.RLock()
|
|
|
+ realm := s.realm
|
|
|
+ renewTill := s.renewTill
|
|
|
+ s.mux.RUnlock()
|
|
|
+ if time.Now().UTC().Before(renewTill) {
|
|
|
err := cl.renewTGT(s)
|
|
|
return true, err
|
|
|
}
|
|
|
+ if realm != cl.Credentials.Realm {
|
|
|
+ // session is not for the client's own realm
|
|
|
+ _, err := cl.remoteRealmSession(realm)
|
|
|
+ if err != nil {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+ }
|
|
|
err := cl.Login()
|
|
|
return false, err
|
|
|
}
|
|
|
|
|
|
+// ensureValidSession makes sure there is a valid session for the realm
|
|
|
func (cl *Client) ensureValidSession(realm string) error {
|
|
|
- s, _ := cl.sessions.get(realm)
|
|
|
- if s != nil && s.valid() {
|
|
|
- return nil
|
|
|
+ s, ok := cl.sessions.get(realm)
|
|
|
+ if ok {
|
|
|
+ s.mux.RLock()
|
|
|
+ defer s.mux.RUnlock()
|
|
|
+ d := s.endTime.Sub(s.authTime) / 6
|
|
|
+ if s.endTime.Sub(time.Now().UTC()) > d {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ _, err := cl.refreshSession(s)
|
|
|
+ return err
|
|
|
}
|
|
|
- _, err := cl.updateSession(s)
|
|
|
- if err != nil {
|
|
|
+ if realm != cl.Credentials.Realm {
|
|
|
+ // not for the client's own realm
|
|
|
+ _, err := cl.remoteRealmSession(realm)
|
|
|
return err
|
|
|
}
|
|
|
+ return cl.Login()
|
|
|
}
|
|
|
|
|
|
-func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
|
|
|
- s, ok := cl.sessions.get(realm)
|
|
|
- if !ok {
|
|
|
- return nil, fmt.Errorf("client does not have a session for realm %s, login first", cl.Credentials.Realm)
|
|
|
+// remoteRealmSession returns the session for a realm that the client is not a member of but for which there is a trust
|
|
|
+func (cl *Client) remoteRealmSession(realm string) (*session, error) {
|
|
|
+ s, ok := cl.sessions.get(cl.Credentials.Realm)
|
|
|
+ if !ok || !s.valid() {
|
|
|
+ err := cl.Login()
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("client was unable to login: %v", err)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
spn := types.PrincipalName{
|
|
|
@@ -222,22 +255,32 @@ func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
|
|
|
return cl.sessions.Entries[realm], nil
|
|
|
}
|
|
|
|
|
|
-// GetSessionFromRealm returns the session for the realm provided.
|
|
|
-func (cl *Client) sessionFromRealm(realm string) (session, error) {
|
|
|
+// realmSession returns the session for the realm provided.
|
|
|
+func (cl *Client) realmSession(realm string) (*session, error) {
|
|
|
s, ok := cl.sessions.get(realm)
|
|
|
var err error
|
|
|
if !ok {
|
|
|
// Try to request TGT from trusted remote Realm
|
|
|
- s, err = cl.sessionFromRemoteRealm(realm)
|
|
|
+ s, err = cl.remoteRealmSession(realm)
|
|
|
if err != nil {
|
|
|
- return s.copy(), err
|
|
|
+ return s, err
|
|
|
}
|
|
|
}
|
|
|
- return s.copy(), nil
|
|
|
+ return s, nil
|
|
|
+}
|
|
|
+
|
|
|
+// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
|
|
|
+func (cl *Client) sessionTGTDetails(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
|
|
|
+ var s *session
|
|
|
+ s, err = cl.realmSession(realm)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ realm, tgt, sessionKey = s.tgtDetails()
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
-// GetSessionFromPrincipalName returns the session for the realm of the principal provided.
|
|
|
-func (cl *Client) sessionFromPrincipalName(spn types.PrincipalName) (session, error) {
|
|
|
- realm := cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
|
|
|
- return cl.sessionFromRealm(realm)
|
|
|
+// spnRealm resolves the realm name of a service principal name
|
|
|
+func (cl *Client) spnRealm(spn types.PrincipalName) string {
|
|
|
+ return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
|
|
|
}
|