session.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package client
  2. import (
  3. "fmt"
  4. "strings"
  5. "sync"
  6. "time"
  7. "github.com/jcmturner/gokrb5/v8/iana/nametype"
  8. "github.com/jcmturner/gokrb5/v8/krberror"
  9. "github.com/jcmturner/gokrb5/v8/messages"
  10. "github.com/jcmturner/gokrb5/v8/types"
  11. )
  12. // sessions hold TGTs and are keyed on the realm name
  13. type sessions struct {
  14. Entries map[string]*session
  15. mux sync.RWMutex
  16. }
  17. // destroy erases all sessions
  18. func (s *sessions) destroy() {
  19. s.mux.Lock()
  20. defer s.mux.Unlock()
  21. for k, e := range s.Entries {
  22. e.destroy()
  23. delete(s.Entries, k)
  24. }
  25. }
  26. // update replaces a session with the one provided or adds it as a new one
  27. func (s *sessions) update(sess *session) {
  28. s.mux.Lock()
  29. defer s.mux.Unlock()
  30. // if a session already exists for this, cancel its auto renew.
  31. if i, ok := s.Entries[sess.realm]; ok {
  32. if i != sess {
  33. // Session in the sessions cache is not the same as one provided.
  34. // Cancel the one in the cache and add this one.
  35. i.mux.Lock()
  36. defer i.mux.Unlock()
  37. i.cancel <- true
  38. s.Entries[sess.realm] = sess
  39. return
  40. }
  41. }
  42. // No session for this realm was found so just add it
  43. s.Entries[sess.realm] = sess
  44. }
  45. // get returns the session for the realm specified
  46. func (s *sessions) get(realm string) (*session, bool) {
  47. s.mux.RLock()
  48. defer s.mux.RUnlock()
  49. sess, ok := s.Entries[realm]
  50. return sess, ok
  51. }
  52. // session holds the TGT details for a realm
  53. type session struct {
  54. realm string
  55. authTime time.Time
  56. endTime time.Time
  57. renewTill time.Time
  58. tgt messages.Ticket
  59. sessionKey types.EncryptionKey
  60. sessionKeyExpiration time.Time
  61. cancel chan bool
  62. mux sync.RWMutex
  63. }
  64. // AddSession adds a session for a realm with a TGT to the client's session cache.
  65. // A goroutine is started to automatically renew the TGT before expiry.
  66. func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  67. if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
  68. // Not a TGT
  69. return
  70. }
  71. realm := tgt.SName.NameString[len(tgt.SName.NameString)-1]
  72. s := &session{
  73. realm: realm,
  74. authTime: dep.AuthTime,
  75. endTime: dep.EndTime,
  76. renewTill: dep.RenewTill,
  77. tgt: tgt,
  78. sessionKey: dep.Key,
  79. sessionKeyExpiration: dep.KeyExpiration,
  80. }
  81. cl.sessions.update(s)
  82. cl.enableAutoSessionRenewal(s)
  83. cl.Log("TGT session added for %s (EndTime: %v)", realm, dep.EndTime)
  84. }
  85. // update overwrites the session details with those from the TGT and decrypted encPart
  86. func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  87. s.mux.Lock()
  88. defer s.mux.Unlock()
  89. s.authTime = dep.AuthTime
  90. s.endTime = dep.EndTime
  91. s.renewTill = dep.RenewTill
  92. s.tgt = tgt
  93. s.sessionKey = dep.Key
  94. s.sessionKeyExpiration = dep.KeyExpiration
  95. }
  96. // destroy will cancel any auto renewal of the session and set the expiration times to the current time
  97. func (s *session) destroy() {
  98. s.mux.Lock()
  99. defer s.mux.Unlock()
  100. if s.cancel != nil {
  101. s.cancel <- true
  102. }
  103. s.endTime = time.Now().UTC()
  104. s.renewTill = s.endTime
  105. s.sessionKeyExpiration = s.endTime
  106. }
  107. // valid informs if the TGT is still within the valid time window
  108. func (s *session) valid() bool {
  109. s.mux.RLock()
  110. defer s.mux.RUnlock()
  111. t := time.Now().UTC()
  112. if t.Before(s.endTime) && s.authTime.Before(t) {
  113. return true
  114. }
  115. return false
  116. }
  117. // tgtDetails is a thread safe way to get the session's realm, TGT and session key values
  118. func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
  119. s.mux.RLock()
  120. defer s.mux.RUnlock()
  121. return s.realm, s.tgt, s.sessionKey
  122. }
  123. // timeDetails is a thread safe way to get the session's validity time values
  124. func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
  125. s.mux.RLock()
  126. defer s.mux.RUnlock()
  127. return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
  128. }
  129. // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
  130. func (cl *Client) enableAutoSessionRenewal(s *session) {
  131. var timer *time.Timer
  132. s.mux.Lock()
  133. s.cancel = make(chan bool, 1)
  134. s.mux.Unlock()
  135. go func(s *session) {
  136. for {
  137. s.mux.RLock()
  138. w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
  139. s.mux.RUnlock()
  140. if w < 0 {
  141. return
  142. }
  143. timer = time.NewTimer(w)
  144. select {
  145. case <-timer.C:
  146. renewal, err := cl.refreshSession(s)
  147. if err != nil {
  148. cl.Log("error refreshing session: %v", err)
  149. }
  150. if !renewal && err == nil {
  151. // end this goroutine as there will have been a new login and new auto renewal goroutine created.
  152. return
  153. }
  154. case <-s.cancel:
  155. // cancel has been called. Stop the timer and exit.
  156. timer.Stop()
  157. return
  158. }
  159. }
  160. }(s)
  161. }
  162. // renewTGT renews the client's TGT session.
  163. func (cl *Client) renewTGT(s *session) error {
  164. realm, tgt, skey := s.tgtDetails()
  165. spn := types.PrincipalName{
  166. NameType: nametype.KRB_NT_SRV_INST,
  167. NameString: []string{"krbtgt", realm},
  168. }
  169. _, tgsRep, err := cl.TGSREQGenerateAndExchange(spn, cl.Credentials.Domain(), tgt, skey, true)
  170. if err != nil {
  171. return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT for %s", realm)
  172. }
  173. s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
  174. cl.sessions.update(s)
  175. cl.Log("TGT session renewed for %s (EndTime: %v)", realm, tgsRep.DecryptedEncPart.EndTime)
  176. return nil
  177. }
  178. // refreshSession updates either through renewal or creating a new login.
  179. // The boolean indicates if the update was a renewal.
  180. func (cl *Client) refreshSession(s *session) (bool, error) {
  181. s.mux.RLock()
  182. realm := s.realm
  183. renewTill := s.renewTill
  184. s.mux.RUnlock()
  185. cl.Log("refreshing TGT session for %s", realm)
  186. if time.Now().UTC().Before(renewTill) {
  187. err := cl.renewTGT(s)
  188. return true, err
  189. }
  190. err := cl.realmLogin(realm)
  191. return false, err
  192. }
  193. // ensureValidSession makes sure there is a valid session for the realm
  194. func (cl *Client) ensureValidSession(realm string) error {
  195. s, ok := cl.sessions.get(realm)
  196. if ok {
  197. s.mux.RLock()
  198. d := s.endTime.Sub(s.authTime) / 6
  199. if s.endTime.Sub(time.Now().UTC()) > d {
  200. s.mux.RUnlock()
  201. return nil
  202. }
  203. s.mux.RUnlock()
  204. _, err := cl.refreshSession(s)
  205. return err
  206. }
  207. return cl.realmLogin(realm)
  208. }
  209. // sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
  210. func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
  211. err = cl.ensureValidSession(realm)
  212. if err != nil {
  213. return
  214. }
  215. s, ok := cl.sessions.get(realm)
  216. if !ok {
  217. err = fmt.Errorf("could not find TGT session for %s", realm)
  218. return
  219. }
  220. _, tgt, sessionKey = s.tgtDetails()
  221. return
  222. }
  223. func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
  224. s, ok := cl.sessions.get(realm)
  225. if !ok {
  226. err = fmt.Errorf("could not find TGT session for %s", realm)
  227. return
  228. }
  229. _, authTime, endTime, renewTime, sessionExp = s.timeDetails()
  230. return
  231. }
  232. // spnRealm resolves the realm name of a service principal name
  233. func (cl *Client) spnRealm(spn types.PrincipalName) string {
  234. return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
  235. }