session.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package client
  2. import (
  3. "fmt"
  4. "sync"
  5. "time"
  6. "gopkg.in/jcmturner/gokrb5.v6/iana/nametype"
  7. "gopkg.in/jcmturner/gokrb5.v6/krberror"
  8. "gopkg.in/jcmturner/gokrb5.v6/messages"
  9. "gopkg.in/jcmturner/gokrb5.v6/types"
  10. )
  11. // Sessions keyed on the realm name
  12. type sessions struct {
  13. Entries map[string]*session
  14. mux sync.RWMutex
  15. }
  16. func (s *sessions) destroy() {
  17. s.mux.Lock()
  18. defer s.mux.Unlock()
  19. for k, e := range s.Entries {
  20. e.destroy()
  21. delete(s.Entries, k)
  22. }
  23. }
  24. // Client session struct.
  25. type session struct {
  26. realm string
  27. authTime time.Time
  28. endTime time.Time
  29. renewTill time.Time
  30. tgt messages.Ticket
  31. sessionKey types.EncryptionKey
  32. sessionKeyExpiration time.Time
  33. cancel chan bool
  34. mux sync.RWMutex
  35. }
  36. func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  37. s.mux.Lock()
  38. defer s.mux.Unlock()
  39. s.authTime = dep.AuthTime
  40. s.endTime = dep.EndTime
  41. s.renewTill = dep.RenewTill
  42. s.tgt = tgt
  43. s.sessionKey = dep.Key
  44. s.sessionKeyExpiration = dep.KeyExpiration
  45. }
  // destroy stops any auto-renewal goroutine for the session and marks it
// expired by setting its end, renew-till and key-expiration times to now.
//
// The cancel signal is sent without blocking. The channel may be nil (copies
// handed out by sessionFromRealm carry no channel), or its one-slot buffer may
// already hold a pending signal. In either case a plain send would block
// forever while s.mux is held, deadlocking the renewal goroutine and any
// caller holding the sessions cache lock.
func (s *session) destroy() {
	s.mux.Lock()
	defer s.mux.Unlock()
	select {
	case s.cancel <- true:
	default:
		// No channel, or a cancel is already pending: nothing more to signal.
	}
	s.endTime = time.Now().UTC()
	s.renewTill = s.endTime
	s.sessionKeyExpiration = s.endTime
}
  // valid reports whether the current time lies strictly after the session's
// auth time and strictly before its end time.
func (s *session) valid() bool {
	s.mux.RLock()
	defer s.mux.RUnlock()
	now := time.Now().UTC()
	return now.After(s.authTime) && now.Before(s.endTime)
}
  63. // AddSession adds a session for a realm with a TGT to the client's session cache.
  64. // A goroutine is started to automatically renew the TGT before expiry.
  65. func (cl *Client) AddSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
  66. cl.sessions.mux.Lock()
  67. defer cl.sessions.mux.Unlock()
  68. s := &session{
  69. realm: tgt.SName.NameString[1],
  70. authTime: dep.AuthTime,
  71. endTime: dep.EndTime,
  72. renewTill: dep.RenewTill,
  73. tgt: tgt,
  74. sessionKey: dep.Key,
  75. sessionKeyExpiration: dep.KeyExpiration,
  76. cancel: make(chan bool, 1),
  77. }
  78. // if a session already exists for this, cancel its auto renew.
  79. if i, ok := cl.sessions.Entries[tgt.SName.NameString[1]]; ok {
  80. i.cancel <- true
  81. }
  82. cl.sessions.Entries[tgt.SName.NameString[1]] = s
  83. cl.enableAutoSessionRenewal(s)
  84. }
  85. // enableAutoSessionRenewal turns on the automatic renewal for the client's TGT session.
  86. func (cl *Client) enableAutoSessionRenewal(s *session) {
  87. var timer *time.Timer
  88. go func(s *session) {
  89. for {
  90. s.mux.RLock()
  91. w := (s.endTime.Sub(time.Now().UTC()) * 5) / 6
  92. s.mux.RUnlock()
  93. if w < 0 {
  94. return
  95. }
  96. timer = time.NewTimer(w)
  97. select {
  98. case <-timer.C:
  99. renewal, err := cl.updateSession(s)
  100. if !renewal && err == nil {
  101. // end this goroutine as there will have been a new login and new auto renewal goroutine created.
  102. return
  103. }
  104. case <-s.cancel:
  105. // cancel has been called. Stop the timer and exit.
  106. timer.Stop()
  107. return
  108. }
  109. }
  110. }(s)
  111. }
  112. // RenewTGT renews the client's TGT session.
  113. func (cl *Client) renewTGT(s *session) error {
  114. spn := types.PrincipalName{
  115. NameType: nametype.KRB_NT_SRV_INST,
  116. NameString: []string{"krbtgt", s.realm},
  117. }
  118. _, tgsRep, err := cl.TGSExchange(spn, s.tgt.Realm, s.tgt, s.sessionKey, true, 0)
  119. if err != nil {
  120. return krberror.Errorf(err, krberror.KRBMsgError, "error renewing TGT")
  121. }
  122. s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
  123. return nil
  124. }
  // updateSession updates either through renewal or creating a new login.
// While the current time is still before the session's renew-till time the
// TGT is renewed; otherwise the client logs in again. The boolean result
// reports whether the update was a renewal.
func (cl *Client) updateSession(s *session) (bool, error) {
	// renewTill is written under s.mux by update and by destroy, and destroy
	// can run concurrently on another goroutine, so read it under the read lock.
	s.mux.RLock()
	renewable := time.Now().UTC().Before(s.renewTill)
	s.mux.RUnlock()
	if renewable {
		return true, cl.renewTGT(s)
	}
	return false, cl.Login()
}
  // sessionFromRemoteRealm obtains a TGT for realm. It presents the TGT of the
// client's own realm (cl.Credentials.Realm) in a TGS exchange for
// krbtgt/<realm>, caches the result via AddSession, and returns the cached
// session.
//
// It returns an error if the client has no session for its own realm, if the
// exchange fails, or if no session ends up cached under realm. AddSession keys
// the cache on the returned ticket's service name, which may not equal realm;
// in that case the function errors rather than returning a nil session with a
// nil error.
func (cl *Client) sessionFromRemoteRealm(realm string) (*session, error) {
	cl.sessions.mux.RLock()
	sess, ok := cl.sessions.Entries[cl.Credentials.Realm]
	cl.sessions.mux.RUnlock()
	if !ok {
		return nil, fmt.Errorf("client does not have a session for realm %s, login first", cl.Credentials.Realm)
	}

	// The auto-renewal goroutine may replace these fields concurrently, so
	// snapshot them under the session's read lock.
	sess.mux.RLock()
	tgt, key := sess.tgt, sess.sessionKey
	sess.mux.RUnlock()

	spn := types.PrincipalName{
		NameType:   nametype.KRB_NT_SRV_INST,
		NameString: []string{"krbtgt", realm},
	}
	_, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, tgt, key, false, 0)
	if err != nil {
		return nil, err
	}
	cl.AddSession(tgsRep.Ticket, tgsRep.DecryptedEncPart)

	cl.sessions.mux.RLock()
	defer cl.sessions.mux.RUnlock()
	s, ok := cl.sessions.Entries[realm]
	if !ok {
		return nil, fmt.Errorf("cross-realm TGS exchange did not produce a session for realm %s", realm)
	}
	return s, nil
}
  // sessionFromRealm returns the session for the realm provided. If the client
// has no cached session for that realm, it tries to obtain one with a
// cross-realm TGS exchange through its own realm.
//
// The result is a copy of the cached session, taken under the session's read
// lock, so the caller can use it without racing the auto-renewal goroutine
// that updates the cached session in place. The copy has no cancel channel.
func (cl *Client) sessionFromRealm(realm string) (sess *session, err error) {
	cl.sessions.mux.RLock()
	s, ok := cl.sessions.Entries[realm]
	cl.sessions.mux.RUnlock()
	if !ok {
		// Try to request TGT from trusted remote Realm
		s, err = cl.sessionFromRemoteRealm(realm)
		if err != nil {
			return nil, err
		}
		if s == nil {
			return nil, fmt.Errorf("client does not have a session for realm %s", realm)
		}
	}

	s.mux.RLock()
	defer s.mux.RUnlock()
	return &session{
		realm:                s.realm,
		authTime:             s.authTime,
		endTime:              s.endTime,
		renewTill:            s.renewTill,
		tgt:                  s.tgt,
		sessionKey:           s.sessionKey,
		sessionKeyExpiration: s.sessionKeyExpiration,
	}, nil
}
  // sessionFromPrincipalName returns the session for the realm of the principal
// provided. The realm is resolved by the client's config from the principal's
// last name component. A principal with no name components yields an error
// instead of an index-out-of-range panic.
func (cl *Client) sessionFromPrincipalName(spn types.PrincipalName) (*session, error) {
	if len(spn.NameString) == 0 {
		return nil, fmt.Errorf("cannot determine realm: principal name has no components")
	}
	realm := cl.Config.ResolveRealm(spn.NameString[len(spn.NameString)-1])
	return cl.sessionFromRealm(realm)
}