session.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package client
  2. import (
  3. "fmt"
  4. "strings"
  5. "sync"
  6. "time"
  7. "gopkg.in/jcmturner/gokrb5.v6/iana/nametype"
  8. "gopkg.in/jcmturner/gokrb5.v6/krberror"
  9. "gopkg.in/jcmturner/gokrb5.v6/messages"
  10. "gopkg.in/jcmturner/gokrb5.v6/types"
  11. )
  12. // sessions hold TGTs and are keyed on the realm name
  13. type sessions struct {
  14. Entries map[string]*session
  15. mux sync.RWMutex
  16. }
  17. // destroy erases all sessions
  18. func (s *sessions) destroy() {
  19. s.mux.Lock()
  20. defer s.mux.Unlock()
  21. for k, e := range s.Entries {
  22. e.destroy()
  23. delete(s.Entries, k)
  24. }
  25. }
  26. // update replaces a session with the one provided or adds it as a new one
  27. func (s *sessions) update(sess *session) {
  28. s.mux.Lock()
  29. defer s.mux.Unlock()
  30. // if a session already exists for this, cancel its auto renew.
  31. if i, ok := s.Entries[sess.realm]; ok {
  32. if i != sess {
  33. // Session in the sessions cache is not the same as one provided.
  34. // Cancel the one in the cache and add this one.
  35. i.mux.Lock()
  36. defer i.mux.Unlock()
  37. i.cancel <- true
  38. s.Entries[sess.realm] = sess
  39. return
  40. }
  41. }
  42. // No session for this realm was found so just add it
  43. s.Entries[sess.realm] = sess
  44. }
  45. // get returns the session for the realm specified
  46. func (s *sessions) get(realm string) (*session, bool) {
  47. s.mux.RLock()
  48. defer s.mux.RUnlock()
  49. sess, ok := s.Entries[realm]
  50. return sess, ok
  51. }
  52. // session holds the TGT details for a realm
  53. type session struct {
  54. realm string
  55. authTime time.Time
  56. endTime time.Time
  57. renewTill time.Time
  58. tgt messages.Ticket
  59. sessionKey types.EncryptionKey
  60. sessionKeyExpiration time.Time
  61. cancel chan bool
  62. mux sync.RWMutex
  63. }
  64. // AddSession adds a session for a realm with a TGT to the client's session cache.
  65. // A goroutine is started to automatically renew the TGT before expiry.
  66. func (cl *Client) AddSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  67. if strings.ToLower(tgt.SName.NameString[0]) != "krbtgt" {
  68. // Not a TGT
  69. return
  70. }
  71. s := &session{
  72. realm: tgt.SName.NameString[len(tgt.SName.NameString)-1],
  73. authTime: dep.AuthTime,
  74. endTime: dep.EndTime,
  75. renewTill: dep.RenewTill,
  76. tgt: tgt,
  77. sessionKey: dep.Key,
  78. sessionKeyExpiration: dep.KeyExpiration,
  79. }
  80. cl.sessions.update(s)
  81. cl.enableAutoSessionRenewal(s)
  82. }
  83. // update overwrites the session details with those from the TGT and decrypted encPart
  84. func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  85. s.mux.Lock()
  86. defer s.mux.Unlock()
  87. s.authTime = dep.AuthTime
  88. s.endTime = dep.EndTime
  89. s.renewTill = dep.RenewTill
  90. s.tgt = tgt
  91. s.sessionKey = dep.Key
  92. s.sessionKeyExpiration = dep.KeyExpiration
  93. }
  94. // destroy will cancel any auto renewal of the session and set the expiration times to the current time
  95. func (s *session) destroy() {
  96. s.mux.Lock()
  97. defer s.mux.Unlock()
  98. if s.cancel != nil {
  99. s.cancel <- true
  100. }
  101. s.endTime = time.Now().UTC()
  102. s.renewTill = s.endTime
  103. s.sessionKeyExpiration = s.endTime
  104. }
  105. // valid informs if the TGT is still within the valid time window
  106. func (s *session) valid() bool {
  107. s.mux.RLock()
  108. defer s.mux.RUnlock()
  109. t := time.Now().UTC()
  110. if t.Before(s.endTime) && s.authTime.Before(t) {
  111. return true
  112. }
  113. return false
  114. }
  115. // tgtDetails is a thread safe way to get the session's realm, TGT and session key values
  116. func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
  117. s.mux.RLock()
  118. defer s.mux.RUnlock()
  119. return s.realm, s.tgt, s.sessionKey
  120. }
  121. // timeDetails is a thread safe way to get the session's validity time values
  122. func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
  123. s.mux.RLock()
  124. defer s.mux.RUnlock()
  125. return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
  126. }
  127. // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
  128. func (cl *Client) enableAutoSessionRenewal(s *session) {
  129. var timer *time.Timer
  130. s.mux.Lock()
  131. s.cancel = make(chan bool, 1)
  132. s.mux.Unlock()
  133. go func(s *session) {
  134. for {
  135. s.mux.RLock()
  136. w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
  137. s.mux.RUnlock()
  138. if w < 0 {
  139. return
  140. }
  141. timer = time.NewTimer(w)
  142. select {
  143. case <-timer.C:
  144. renewal, err := cl.refreshSession(s)
  145. if !renewal && err == nil {
  146. // end this goroutine as there will have been a new login and new auto renewal goroutine created.
  147. return
  148. }
  149. case <-s.cancel:
  150. // cancel has been called. Stop the timer and exit.
  151. timer.Stop()
  152. return
  153. }
  154. }
  155. }(s)
  156. }
  157. // renewTGT renews the client's TGT session.
  158. func (cl *Client) renewTGT(s *session) error {
  159. realm, tgt, skey := s.tgtDetails()
  160. spn := types.PrincipalName{
  161. NameType: nametype.KRB_NT_SRV_INST,
  162. NameString: []string{"krbtgt", realm},
  163. }
  164. _, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, tgt, skey, true, 0)
  165. if err != nil {
  166. return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT")
  167. }
  168. s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
  169. cl.sessions.update(s)
  170. return nil
  171. }
  172. // refreshSession updates either through renewal or creating a new login.
  173. // The boolean indicates if the update was a renewal.
  174. func (cl *Client) refreshSession(s *session) (bool, error) {
  175. s.mux.RLock()
  176. realm := s.realm
  177. renewTill := s.renewTill
  178. s.mux.RUnlock()
  179. if time.Now().UTC().Before(renewTill) {
  180. err := cl.renewTGT(s)
  181. return true, err
  182. }
  183. err := cl.realmLogin(realm)
  184. return false, err
  185. }
  186. // ensureValidSession makes sure there is a valid session for the realm
  187. func (cl *Client) ensureValidSession(realm string) error {
  188. s, ok := cl.sessions.get(realm)
  189. if ok {
  190. s.mux.RLock()
  191. d := s.endTime.Sub(s.authTime) / 6
  192. if s.endTime.Sub(time.Now().UTC()) > d {
  193. s.mux.RUnlock()
  194. return nil
  195. }
  196. s.mux.RUnlock()
  197. _, err := cl.refreshSession(s)
  198. return err
  199. }
  200. return cl.realmLogin(realm)
  201. }
  202. // sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
  203. func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
  204. err = cl.ensureValidSession(realm)
  205. if err != nil {
  206. return
  207. }
  208. s, ok := cl.sessions.get(realm)
  209. if !ok {
  210. err = fmt.Errorf("could not find TGT session for %s", realm)
  211. return
  212. }
  213. _, tgt, sessionKey = s.tgtDetails()
  214. return
  215. }
  216. func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
  217. s, ok := cl.sessions.get(realm)
  218. if !ok {
  219. err = fmt.Errorf("could not find TGT session for %s", realm)
  220. return
  221. }
  222. _, authTime, endTime, renewTime, sessionExp = s.timeDetails()
  223. return
  224. }
  225. // spnRealm resolves the realm name of a service principal name
  226. func (cl *Client) spnRealm(spn types.PrincipalName) string {
  227. return cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
  228. }