Jonathan Turner 7 лет назад
Родитель
Сommit
1929b03a14
4 измененных файлов с 49 добавлено и 70 удалено
  1. 1 6
      client/TGSExchange.go
  2. 24 2
      client/client.go
  3. 8 8
      client/client_integration_test.go
  4. 16 54
      client/session.go

+ 1 - 6
client/TGSExchange.go

@@ -67,12 +67,7 @@ func (cl *Client) GetServiceTicket(spn string) (messages.Ticket, types.Encryptio
 	princ := types.NewPrincipalName(nametype.KRB_NT_PRINCIPAL, spn)
 	realm := cl.Config.ResolveRealm(princ.NameString[len(princ.NameString)-1])
 
-	err := cl.ensureValidSession(realm)
-	if err != nil {
-		return tkt, skey, err
-	}
-
-	tgt, skey, err := cl.sessionTGTDetails(realm)
+	tgt, skey, err := cl.sessionTGT(realm)
 
 	_, tgsRep, err := cl.TGSExchange(princ, realm, tgt, skey, false, 0)
 	if err != nil {

+ 24 - 2
client/client.go

@@ -176,8 +176,8 @@ func (cl *Client) IsConfigured() (bool, error) {
 	}
 	// Client needs to have either a password, keytab or a session already (later when loading from CCache)
 	if !cl.Credentials.HasPassword() && !cl.Credentials.HasKeytab() {
-		sess, err := cl.realmSession(cl.Credentials.Realm)
-		if err != nil || sess.authTime.IsZero() {
+		authTime, _, _, _, err := cl.sessionTimes(cl.Credentials.Realm)
+		if err != nil || authTime.IsZero() {
 			return false, errors.New("client has neither a keytab nor a password set and no session")
 		}
 	}
@@ -215,6 +215,28 @@ func (cl *Client) Login() error {
 	return nil
 }
 
+// remoteRealmSession returns the session for a realm that the client is not a member of but for which there is a trust
+func (cl *Client) realmLogin(realm string) error {
+	err := cl.ensureValidSession(cl.Credentials.Realm)
+	if err != nil || realm == cl.Credentials.Realm {
+		return err
+	}
+	tgt, skey, err := cl.sessionTGT(cl.Credentials.Realm)
+
+	spn := types.PrincipalName{
+		NameType:   nametype.KRB_NT_SRV_INST,
+		NameString: []string{"krbtgt", realm},
+	}
+
+	_, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, tgt, skey, false, 0)
+	if err != nil {
+		return err
+	}
+	cl.AddSession(tgsRep.Ticket, tgsRep.DecryptedEncPart)
+
+	return nil
+}
+
 // Destroy stops the auto-renewal of all sessions and removes the sessions and cache entries from the client.
 func (cl *Client) Destroy() {
 	creds := credentials.NewCredentials("", "")

+ 8 - 8
client/client_integration_test.go

@@ -416,9 +416,9 @@ func TestMultiThreadedClientSession(t *testing.T) {
 		t.Fatalf("failed to log in: %v", err)
 	}
 
-	s, err := cl.realmSession("TEST.GOKRB5")
-	if err != nil {
-		t.Fatalf("error initially getting session: %v", err)
+	s, ok := cl.sessions.get("TEST.GOKRB5")
+	if !ok {
+		t.Fatal("error initially getting session")
 	}
 	go func() {
 		for {
@@ -435,11 +435,11 @@ func TestMultiThreadedClientSession(t *testing.T) {
 	for i := 0; i < 10; i++ {
 		go func() {
 			defer wg.Done()
-			s, err := cl.realmSession("TEST.GOKRB5")
-			if err != nil {
+			tgt, _, err := cl.sessionTGT("TEST.GOKRB5")
+			if err != nil || tgt.Realm != "TEST.GOKRB5" {
 				t.Logf("error getting session: %v", err)
 			}
-			_, _, _, r, _ := s.timeDetails()
+			_, _, _, r, _ := cl.sessionTimes("TEST.GOKRB5")
 			fmt.Fprintf(ioutil.Discard, "%v", r)
 		}()
 		time.Sleep(time.Second)
@@ -730,11 +730,11 @@ func TestClient_AutoRenew_Goroutine(t *testing.T) {
 	n := runtime.NumGoroutine()
 	for i := 0; i < 6; i++ {
 		time.Sleep(time.Second * 20)
-		sess, err := cl.realmSession("TEST.GOKRB5")
+		_, endTime, _, _, err := cl.sessionTimes("TEST.GOKRB5")
 		if err != nil {
 			t.Errorf("could not get client's session: %v", err)
 		}
-		if !sess.valid() {
+		if time.Now().UTC().After(endTime) {
 			t.Fatalf("session auto update failed")
 		}
 		if runtime.NumGoroutine() > n {

+ 16 - 54
client/session.go

@@ -179,7 +179,7 @@ func (cl *Client) renewTGT(s *session) error {
 	return nil
 }
 
-// updateSession updates either through renewal or creating a new login.
+// refreshSession updates either through renewal or creating a new login.
 // The boolean indicates if the update was a renewal.
 func (cl *Client) refreshSession(s *session) (bool, error) {
 	s.mux.RLock()
@@ -190,14 +190,7 @@ func (cl *Client) refreshSession(s *session) (bool, error) {
 		err := cl.renewTGT(s)
 		return true, err
 	}
-	if realm != cl.Credentials.Realm {
-		// session is not for the client's own realm
-		_, err := cl.remoteRealmSession(realm)
-		if err != nil {
-			return false, err
-		}
-	}
-	err := cl.Login()
+	err := cl.realmLogin(realm)
 	return false, err
 }
 
@@ -214,62 +207,31 @@ func (cl *Client) ensureValidSession(realm string) error {
 		_, err := cl.refreshSession(s)
 		return err
 	}
-	if realm != cl.Credentials.Realm {
-		// not for the client's own realm
-		_, err := cl.remoteRealmSession(realm)
-		return err
-	}
-	return cl.Login()
+	return cl.realmLogin(realm)
 }
 
-// remoteRealmSession returns the session for a realm that the client is not a member of but for which there is a trust
-func (cl *Client) remoteRealmSession(realm string) (*session, error) {
-	s, ok := cl.sessions.get(cl.Credentials.Realm)
-	if !ok || !s.valid() {
-		err := cl.Login()
-		if err != nil {
-			return nil, fmt.Errorf("client was unable to login: %v", err)
-		}
-	}
-
-	spn := types.PrincipalName{
-		NameType:   nametype.KRB_NT_SRV_INST,
-		NameString: []string{"krbtgt", realm},
-	}
-
-	_, tgsRep, err := cl.TGSExchange(spn, cl.Credentials.Realm, s.tgt, s.sessionKey, false, 0)
+// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
+func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
+	err = cl.ensureValidSession(realm)
 	if err != nil {
-		return nil, err
+		return
 	}
-	cl.AddSession(tgsRep.Ticket, tgsRep.DecryptedEncPart)
-
-	cl.sessions.mux.RLock()
-	defer cl.sessions.mux.RUnlock()
-	return cl.sessions.Entries[realm], nil
-}
-
-// realmSession returns the session for the realm provided.
-func (cl *Client) realmSession(realm string) (*session, error) {
 	s, ok := cl.sessions.get(realm)
-	var err error
 	if !ok {
-		// Try to request TGT from trusted remote Realm
-		s, err = cl.remoteRealmSession(realm)
-		if err != nil {
-			return s, err
-		}
+		err = fmt.Errorf("could not find TGT session for %s", realm)
+		return
 	}
-	return s, nil
+	_, tgt, sessionKey = s.tgtDetails()
+	return
 }
 
-// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
-func (cl *Client) sessionTGTDetails(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
-	var s *session
-	s, err = cl.realmSession(realm)
-	if err != nil {
+func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
+	s, ok := cl.sessions.get(realm)
+	if !ok {
+		err = fmt.Errorf("could not find TGT session for %s", realm)
 		return
 	}
-	realm, tgt, sessionKey = s.tgtDetails()
+	_, authTime, endTime, renewTime, sessionExp = s.timeDetails()
 	return
 }