Jonathan Turner 8 лет назад
Родитель
Сommit
d982df7044
3 измененных файлов с 73 добавлено и 9 удалено
  1. 47 0
      client/client_integration_test.go
  2. 25 8
      client/session.go
  3. 1 1
      messages/KDCReq.go

+ 47 - 0
client/client_integration_test.go

@@ -7,6 +7,7 @@ import (
 	"bytes"
 	"encoding/hex"
 	"io"
+	"io/ioutil"
 	"net/http"
 	"os"
 	"os/exec"
@@ -393,6 +394,52 @@ func TestMultiThreadedClientUse(t *testing.T) {
 	wg2.Wait()
 }
 
+func TestMultiThreadedClientSession(t *testing.T) {
+	b, _ := hex.DecodeString(testdata.TESTUSER1_KEYTAB)
+	kt, _ := keytab.Parse(b)
+	c, _ := config.NewConfigFromString(testdata.TEST_KRB5CONF)
+	addr := os.Getenv("TEST_KDC_ADDR")
+	if addr == "" {
+		addr = testdata.TEST_KDC_ADDR
+	}
+	c.Realms[0].KDC = []string{addr + ":" + testdata.TEST_KDC}
+	cl := NewClientWithKeytab("testuser1", "TEST.GOKRB5", kt)
+	cl.WithConfig(c)
+	err := cl.Login()
+	if err != nil {
+		t.Fatalf("failed to log in: %v", err)
+	}
+
+	s, err := cl.GetSessionFromRealm("TEST.GOKRB5")
+	if err != nil {
+		t.Fatalf("error initially getting session: %v", err)
+	}
+	go func() {
+		for {
+			err := cl.renewTGT(s)
+			if err != nil {
+				t.Logf("error renewing TGT: %v", err)
+			}
+			time.Sleep(time.Millisecond * 100)
+		}
+	}()
+
+	var wg sync.WaitGroup
+	wg.Add(10)
+	for i := 0; i < 10; i++ {
+		go func() {
+			defer wg.Done()
+			s, err := cl.GetSessionFromRealm("TEST.GOKRB5")
+			if err != nil {
+				t.Logf("error getting session: %v", err)
+			}
+			fmt.Fprintf(ioutil.Discard, "%v", s.RenewTill)
+		}()
+		time.Sleep(time.Second)
+	}
+	wg.Wait()
+}
+
 func spnegoGet(cl *Client) error {
 	url := os.Getenv("TEST_HTTP_URL")
 	if url == "" {

+ 25 - 8
client/session.go

@@ -27,6 +27,19 @@ type session struct {
 	SessionKey           types.EncryptionKey
 	SessionKeyExpiration time.Time
 	cancel               chan bool
+	mux                  sync.RWMutex
+}
+
+func (s *session) update(tkt messages.Ticket, dep messages.EncKDCRepPart) {
+	s.mux.Lock()
+	defer s.mux.Unlock()
+	s.AuthTime = dep.AuthTime
+	s.AuthTime = dep.AuthTime
+	s.EndTime = dep.EndTime
+	s.RenewTill = dep.RenewTill
+	s.TGT = tkt
+	s.SessionKey = dep.Key
+	s.SessionKeyExpiration = dep.KeyExpiration
 }
 
 // AddSession adds a session for a realm with a TGT to the client's session cache.
@@ -88,13 +101,7 @@ func (cl *Client) renewTGT(s *session) error {
 	if err != nil {
 		return krberror.Errorf(err, krberror.KRBMsgError, "Error renewing TGT")
 	}
-	s.AuthTime = tgsRep.DecryptedEncPart.AuthTime
-	s.AuthTime = tgsRep.DecryptedEncPart.AuthTime
-	s.EndTime = tgsRep.DecryptedEncPart.EndTime
-	s.RenewTill = tgsRep.DecryptedEncPart.RenewTill
-	s.TGT = tgsRep.Ticket
-	s.SessionKey = tgsRep.DecryptedEncPart.Key
-	s.SessionKeyExpiration = tgsRep.DecryptedEncPart.KeyExpiration
+	s.update(tgsRep.Ticket, tgsRep.DecryptedEncPart)
 	return nil
 }
 
@@ -136,7 +143,17 @@ func (cl *Client) getSessionFromRemoteRealm(realm string) (*session, error) {
 // GetSessionFromRealm returns the session for the realm provided.
 func (cl *Client) GetSessionFromRealm(realm string) (*session, error) {
 	cl.sessions.mux.RLock()
-	sess, ok := cl.sessions.Entries[realm]
+	s, ok := cl.sessions.Entries[realm]
+	// Create another session to return to prevent race condition.
+	sess := &session{
+		Realm:                s.Realm,
+		AuthTime:             s.AuthTime,
+		EndTime:              s.EndTime,
+		RenewTill:            s.RenewTill,
+		TGT:                  s.TGT,
+		SessionKey:           s.SessionKey,
+		SessionKeyExpiration: s.SessionKeyExpiration,
+	}
 	cl.sessions.mux.RUnlock()
 	if !ok {
 		// Try to request TGT from trusted remote Realm

+ 1 - 1
messages/KDCReq.go

@@ -109,7 +109,7 @@ func NewASReq(realm string, c *config.Config, cname, sname types.PrincipalName)
 	}
 	t := time.Now().UTC()
 	// Copy the default options to make this thread safe
-	var kopts asn1.BitString
+	kopts := types.NewKrbFlags()
 	copy(kopts.Bytes, c.LibDefaults.KDCDefaultOptions.Bytes)
 	kopts.BitLength = c.LibDefaults.KDCDefaultOptions.BitLength
 	a := ASReq{