Jonathan Turner 9 лет назад
Родитель
Сommit
cf666c5ebf
7 измененных файлов с 114 добавлено и 56 удалено
  1. 1 0
      client/ASExchange.go
  2. 30 12
      client/TGSExchange.go
  3. 22 26
      client/cache.go
  4. 41 0
      client/session.go
  5. 3 1
      debug.go
  6. 1 1
      messages/KDCRep.go
  7. 16 16
      messages/KDCReq.go

+ 1 - 0
client/ASExchange.go

@@ -8,6 +8,7 @@ import (
 	"github.com/jcmturner/gokrb5/iana/patype"
 	"github.com/jcmturner/gokrb5/messages"
 	"github.com/jcmturner/gokrb5/types"
+	"os"
 	"sort"
 )
 

+ 30 - 12
client/TGSExchange.go

@@ -3,38 +3,56 @@ package client
 import (
 	"errors"
 	"fmt"
+	"github.com/jcmturner/gokrb5/iana/nametype"
 	"github.com/jcmturner/gokrb5/messages"
+	"github.com/jcmturner/gokrb5/types"
+	"strings"
 )
 
 // Perform a TGS exchange to retrieve a ticket to the specified SPN.
 // The ticket retrieved is added to the client's cache.
-func (cl *Client) TGSExchange(spn string) error {
+func (cl *Client) TGSExchange(spn types.PrincipalName, renewal bool) (tgsReq messages.TGSReq, tgsRep messages.TGSRep, err error) {
 	if cl.Session == nil {
-		return errors.New("Error client does not have a session. Client needs to login first")
+		return tgsReq, tgsRep, errors.New("Error client does not have a session. Client needs to login first")
 	}
-	tgs, err := messages.NewTGSReq(cl.Credentials.Username, cl.Config, cl.Session.TGT, cl.Session.SessionKey, "HTTP/host.test.gokrb5")
+	tgsReq, err = messages.NewTGSReq(cl.Credentials.Username, cl.Config, cl.Session.TGT, cl.Session.SessionKey, spn, renewal)
 	if err != nil {
-		return fmt.Errorf("Error generating New TGS_REQ: %v", err)
+		return tgsReq, tgsRep, fmt.Errorf("Error generating New TGS_REQ: %v", err)
 	}
-	b, err := tgs.Marshal()
+	b, err := tgsReq.Marshal()
 	if err != nil {
-		return fmt.Errorf("Error marshalling TGS_REQ: %v", err)
+		return tgsReq, tgsRep, fmt.Errorf("Error marshalling TGS_REQ: %v", err)
 	}
 	r, err := cl.SendToKDC(b)
 	if err != nil {
-		return fmt.Errorf("Error sending TGS_REQ to KDC: %v", err)
+		return tgsReq, tgsRep, fmt.Errorf("Error sending TGS_REQ to KDC: %v", err)
 	}
-	var tgsRep messages.TGSRep
 	err = tgsRep.Unmarshal(r)
 	if err != nil {
-		return fmt.Errorf("Error unmarshalling TGS_REP: %v", err)
+		return tgsReq, tgsRep, fmt.Errorf("Error unmarshalling TGS_REP: %v", err)
 	}
 	err = tgsRep.DecryptEncPart(cl.Session.SessionKey)
 	if err != nil {
-		return fmt.Errorf("Error decrypting EncPart of TGS_REP: %v", err)
+		return tgsReq, tgsRep, fmt.Errorf("Error decrypting EncPart of TGS_REP: %v", err)
 	}
-	if ok, err := tgsRep.IsValid(cl.Config, tgs); !ok {
-		return fmt.Errorf("TGS_REP is not valid: %v", err)
+	if ok, err := tgsRep.IsValid(cl.Config, tgsReq); !ok {
+		return tgsReq, tgsRep, fmt.Errorf("TGS_REP is not valid: %v", err)
+	}
+	return tgsReq, tgsRep, nil
+}
+
+// Make a request to get a service ticket for the SPN specified
+// SPN format: <SERVICE>/<FQDN> Eg. HTTP/www.example.com
+// The ticket will be added to the client's ticket cache
+func (cl *Client) GetServiceTicket(spn string) error {
+	s := strings.Split(spn, "/")
+	princ := types.PrincipalName{
+		NameType:   nametype.KRB_NT_PRINCIPAL,
+		NameString: s,
+	}
+	_, tgsRep, err := cl.TGSExchange(princ, false)
+	if err != nil {
+		return err
 	}
 	cl.Cache.AddEntry(tgsRep.Ticket, tgsRep.DecryptedEncPart.AuthTime, tgsRep.DecryptedEncPart.EndTime, tgsRep.DecryptedEncPart.RenewTill)
 	return nil

+ 22 - 26
client/cache.go

@@ -1,8 +1,6 @@
 package client
 
 import (
-	"errors"
-	"fmt"
 	"github.com/jcmturner/gokrb5/types"
 	"strings"
 	"time"
@@ -19,7 +17,6 @@ type CacheEntry struct {
 	AuthTime  time.Time
 	EndTime   time.Time
 	RenewTill time.Time
-	AutoRenew bool
 }
 
 // Create a new client ticket cache.
@@ -48,14 +45,6 @@ func (c *Cache) GetTicket(spn string) (types.Ticket, bool) {
 	return tkt, false
 }
 
-// Renew a ticket in the cache for the specified SPN.
-func (c *Cache) RenewEntry(spn string) error {
-	if e, ok := c.GetEntry(spn); ok {
-		return e.Renew()
-	}
-	return fmt.Errorf("No entry for this SPN: %s", spn)
-}
-
 // Add a ticket to the cache.
 func (c *Cache) AddEntry(tkt types.Ticket, authTime, endTime, renewTill time.Time) {
 	(*c).Entries[strings.Join(tkt.SName.NameString, "/")] = CacheEntry{
@@ -71,21 +60,28 @@ func (c *Cache) RemoveEntry(spn string) {
 	delete(c.Entries, spn)
 }
 
-// Enable background auto renew of the ticket for the specified SPN.
-func (c *Cache) EnableAutoRenew(spn string) error {
-	return nil
-}
+// Renew a ticket in the cache for the specified SPN.
+//func (c *Cache) RenewEntry(spn string) error {
+//	if e, ok := c.GetEntry(spn); ok {
+//		return e.Renew()
+//	}
+//	return fmt.Errorf("No entry for this SPN: %s", spn)
+//}
 
-// Disable background auto renew of the ticket for the specified SPN.
-func (c *Cache) DisableAutoRenew(spn string) error {
-	return nil
-}
+// Enable background auto renew of the ticket for the specified SPN.
+//func (cl *Client) EnableAutoRenew(spn string) {
+//	go func() {
+//		for {
+//
+//		}
+//	}()
+//}
 
 // Renew the cache entry.
-func (e *CacheEntry) Renew() error {
-	if time.Now().After(e.RenewTill) {
-		return errors.New("Past renew till time. Cannot renew.")
-	}
-	//TODO put renew action here
-	return nil
-}
+//func (e *CacheEntry) Renew() error {
+//	if time.Now().After(e.RenewTill) {
+//		return errors.New("Past renew till time. Cannot renew.")
+//	}
+//	//TODO put renew action here
+//	return nil
+//}

+ 41 - 0
client/session.go

@@ -1,7 +1,10 @@
 package client
 
 import (
+	"fmt"
+	"github.com/jcmturner/gokrb5/iana/nametype"
 	"github.com/jcmturner/gokrb5/types"
+	"os"
 	"time"
 )
 
@@ -14,3 +17,41 @@ type Session struct {
 	SessionKey           types.EncryptionKey
 	SessionKeyExpiration time.Time
 }
+
+func (cl *Client) RenewTGT() error {
+	spn := types.PrincipalName{
+		NameType:   nametype.KRB_NT_SRV_INST,
+		NameString: []string{"krbtgt", cl.Session.TGT.Realm},
+	}
+	_, tgsRep, err := cl.TGSExchange(spn, true)
+	if err != nil {
+		return err
+	}
+	cl.Session = &Session{
+		AuthTime:             tgsRep.DecryptedEncPart.AuthTime,
+		EndTime:              tgsRep.DecryptedEncPart.EndTime,
+		RenewTill:            tgsRep.DecryptedEncPart.RenewTill,
+		TGT:                  tgsRep.Ticket,
+		SessionKey:           tgsRep.DecryptedEncPart.Key,
+		SessionKeyExpiration: tgsRep.DecryptedEncPart.KeyExpiration,
+	}
+	return nil
+}
+
+func (cl *Client) EnableAutoSessionRenewal() {
+	go func() {
+		for {
+			//Wait until one minute before endtime
+			w := (time.Until(cl.Session.EndTime) * 5) / 6
+			if w < 0 {
+				return
+			}
+			time.Sleep(w)
+			if time.Now().Before(cl.Session.RenewTill) {
+				cl.RenewTGT()
+			} else {
+				cl.Login()
+			}
+		}
+	}()
+}

+ 3 - 1
debug.go

@@ -43,6 +43,7 @@ const pa149rep = "6b8202f3308202efa003020105a10302010ba22e302c302aa103020113a223
 func main() {
 
 	TestTGSReq()
+	time.Sleep(time.Duration(3) * time.Hour)
 }
 
 func NoPA() {
@@ -143,8 +144,9 @@ func TestTGSReq() {
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "Error on AS_REQ: %v\n", err)
 	}
-	err = cl.TGSExchange("HTTP/host.test.gokrb5")
+	err = cl.GetServiceTicket("HTTP/host.test.gokrb5")
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "Error on TGS_REQ: %v\n", err)
 	}
+	cl.EnableAutoSessionRenewal()
 }

+ 1 - 1
messages/KDCRep.go

@@ -286,7 +286,7 @@ func (k *TGSRep) IsValid(cfg *config.Config, tgsReq TGSReq) (bool, error) {
 	if len(tgsReq.ReqBody.Addresses) > 0 {
 		//TODO compare if address list is the same
 	}
-	if time.Since(k.DecryptedEncPart.AuthTime) > cfg.LibDefaults.Clockskew || time.Until(k.DecryptedEncPart.AuthTime) > cfg.LibDefaults.Clockskew {
+	if !tgsReq.Renewal && (time.Since(k.DecryptedEncPart.AuthTime) > cfg.LibDefaults.Clockskew || time.Until(k.DecryptedEncPart.AuthTime) > cfg.LibDefaults.Clockskew) {
 		return false, fmt.Errorf("Clock skew with KDC too large. Greater than %v seconds", cfg.LibDefaults.Clockskew.Seconds())
 	}
 	return true, nil

+ 16 - 16
messages/KDCReq.go

@@ -17,7 +17,6 @@ import (
 	"github.com/jcmturner/gokrb5/iana/patype"
 	"github.com/jcmturner/gokrb5/types"
 	"math/rand"
-	"strings"
 	"time"
 )
 
@@ -33,6 +32,7 @@ type KDCReqFields struct {
 	MsgType int
 	PAData  types.PADataSequence
 	ReqBody KDCReqBody
+	Renewal bool
 }
 
 type ASReq struct {
@@ -115,36 +115,31 @@ func NewASReq(c *config.Config, username string) ASReq {
 		types.SetFlag(&a.ReqBody.KDCOptions, types.Proxiable)
 	}
 	if c.LibDefaults.Renew_lifetime != 0 {
+		types.SetFlag(&a.ReqBody.KDCOptions, types.Renewable)
 		a.ReqBody.RTime = t.Add(c.LibDefaults.Renew_lifetime)
 	}
 	return a
 }
 
-func NewTGSReq(username string, c *config.Config, TGT types.Ticket, sessionKey types.EncryptionKey, spn string) (TGSReq, error) {
+func NewTGSReq(username string, c *config.Config, TGT types.Ticket, sessionKey types.EncryptionKey, spn types.PrincipalName, renewal bool) (TGSReq, error) {
 	nonce := int(rand.Int31())
 	t := time.Now()
-	s := strings.Split(spn, "/")
 	a := TGSReq{
 		KDCReqFields{
 			PVNO:    iana.PVNO,
 			MsgType: msgtype.KRB_TGS_REQ,
 			ReqBody: KDCReqBody{
 				KDCOptions: types.NewKrbFlags(),
-				Realm:      c.ResolveRealm(s[len(s) - 1]),
-				SName: types.PrincipalName{
-					NameType:   nametype.KRB_NT_PRINCIPAL,
-					NameString: s,
-				},
-				Till:  t.Add(c.LibDefaults.Ticket_lifetime),
-				Nonce: nonce,
-				EType: c.LibDefaults.Default_tgs_enctype_ids,
+				Realm:      c.ResolveRealm(spn.NameString[len(spn.NameString)-1]),
+				SName:      spn,
+				Till:       t.Add(c.LibDefaults.Ticket_lifetime),
+				Nonce:      nonce,
+				EType:      c.LibDefaults.Default_tgs_enctype_ids,
 			},
+			Renewal: renewal,
 		},
 	}
-	types.SetFlag(&a.ReqBody.KDCOptions, types.Forwardable)
-	types.SetFlag(&a.ReqBody.KDCOptions, types.Renewable)
-	types.SetFlag(&a.ReqBody.KDCOptions, types.Canonicalize)
-	/*if c.LibDefaults.Forwardable {
+	if c.LibDefaults.Forwardable {
 		types.SetFlag(&a.ReqBody.KDCOptions, types.Forwardable)
 	}
 	if c.LibDefaults.Canonicalize {
@@ -154,8 +149,13 @@ func NewTGSReq(username string, c *config.Config, TGT types.Ticket, sessionKey t
 		types.SetFlag(&a.ReqBody.KDCOptions, types.Proxiable)
 	}
 	if c.LibDefaults.Renew_lifetime != 0 {
+		types.SetFlag(&a.ReqBody.KDCOptions, types.Renewable)
 		a.ReqBody.RTime = t.Add(c.LibDefaults.Renew_lifetime)
-	}*/
+	}
+	if renewal {
+		types.SetFlag(&a.ReqBody.KDCOptions, types.Renew)
+		types.SetFlag(&a.ReqBody.KDCOptions, types.Renewable)
+	}
 	auth := types.NewAuthenticator(c.LibDefaults.Default_realm, username)
 	// Add the CName to make validation of the reply easier
 	a.ReqBody.CName = auth.CName