Browse Source

service side replay cache thread safety

Jonathan Turner 8 năm trước cách đây
mục cha
commit
21f62e24b1
2 tập tin đã thay đổi với 121 bổ sung27 xóa
  1. 42 15
      service/cache.go
  2. 79 12
      service/http_test.go

+ 42 - 15
service/cache.go

@@ -30,7 +30,10 @@ them following an event that caused the server to lose track of
 recently seen authenticators.*/
 
 // Cache for tickets received from clients keyed by fully qualified client name. Used to track replay of tickets.
-type Cache map[string]clientEntries
+type Cache struct {
+	Entries map[string]clientEntries
+	mux     sync.RWMutex
+}
 
 // clientEntries holds entries of client details sent to the service.
 type clientEntries struct {
@@ -46,6 +49,24 @@ type replayCacheEntry struct {
 	CTime         time.Time // This combines the ticket's CTime and Cusec
 }
 
+func (c *Cache) getClientEntries(cname types.PrincipalName) (clientEntries, bool) {
+	c.mux.RLock()
+	defer c.mux.RUnlock()
+	ce, ok := c.Entries[cname.GetPrincipalNameString()]
+	return ce, ok
+}
+
+func (c *Cache) getClientEntry(cname types.PrincipalName, t time.Time) (replayCacheEntry, bool) {
+	if ce, ok := c.getClientEntries(cname); ok {
+		c.mux.RLock()
+		defer c.mux.RUnlock()
+		if e, ok := ce.ReplayMap[t]; ok {
+			return e, true
+		}
+	}
+	return replayCacheEntry{}, false
+}
+
 // Instance of the ServiceCache. This needs to be a singleton.
 var replayCache Cache
 var once sync.Once
@@ -54,7 +75,9 @@ var once sync.Once
 func GetReplayCache(d time.Duration) *Cache {
 	// Create a singleton of the ReplayCache and start a background thread to regularly clean out old entries
 	once.Do(func() {
-		replayCache = make(Cache)
+		replayCache = Cache{
+			Entries: make(map[string]clientEntries),
+		}
 		go func() {
 			for {
 				// TODO consider using a context here.
@@ -69,7 +92,9 @@ func GetReplayCache(d time.Duration) *Cache {
 // AddEntry adds an entry to the Cache.
 func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) {
 	ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond)
-	if ce, ok := (*c)[a.CName.GetPrincipalNameString()]; ok {
+	if ce, ok := c.getClientEntries(a.CName); ok {
+		c.mux.Lock()
+		defer c.mux.Unlock()
 		ce.ReplayMap[ct] = replayCacheEntry{
 			PresentedTime: time.Now().UTC(),
 			SName:         sname,
@@ -78,7 +103,9 @@ func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) {
 		ce.SeqNumber = a.SeqNumber
 		ce.SubKey = a.SubKey
 	} else {
-		(*c)[a.CName.GetPrincipalNameString()] = clientEntries{
+		c.mux.Lock()
+		defer c.mux.Unlock()
+		c.Entries[a.CName.GetPrincipalNameString()] = clientEntries{
 			ReplayMap: map[time.Time]replayCacheEntry{
 				ct: {
 					PresentedTime: time.Now().UTC(),
@@ -94,26 +121,26 @@ func (c *Cache) AddEntry(sname types.PrincipalName, a types.Authenticator) {
 
 // ClearOldEntries clears entries from the Cache that are older than the duration provided.
 func (c *Cache) ClearOldEntries(d time.Duration) {
-	for ck := range *c {
-		for ct, e := range (*c)[ck].ReplayMap {
+	c.mux.Lock()
+	defer c.mux.Unlock()
+	for ke, ce := range c.Entries {
+		for k, e := range ce.ReplayMap {
 			if time.Now().UTC().Sub(e.PresentedTime) > d {
-				delete((*c)[ck].ReplayMap, ct)
+				delete(ce.ReplayMap, k)
 			}
 		}
-		if len((*c)[ck].ReplayMap) == 0 {
-			delete((*c), ck)
+		if len(ce.ReplayMap) == 0 {
+			delete(c.Entries, ke)
 		}
 	}
 }
 
 // IsReplay tests if the Authenticator provided is a replay within the duration defined. If this is not a replay add the entry to the cache for tracking.
 func (c *Cache) IsReplay(sname types.PrincipalName, a types.Authenticator) bool {
-	if ck, ok := (*c)[a.CName.GetPrincipalNameString()]; ok {
-		ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond)
-		if e, ok := ck.ReplayMap[ct]; ok {
-			if e.SName.Equal(sname) {
-				return true
-			}
+	ct := a.CTime.Add(time.Duration(a.Cusec) * time.Microsecond)
+	if e, ok := c.getClientEntry(a.CName, ct); ok {
+		if e.SName.Equal(sname) {
+			return true
 		}
 	}
 	c.AddEntry(sname, a)

+ 79 - 12
service/http_test.go

@@ -111,14 +111,12 @@ func TestService_SPNEGOKRB_Replay(t *testing.T) {
 	}
 	assert.Equal(t, http.StatusOK, httpResp.StatusCode, "Status code in response to client SPNEGO request not as expected")
 
-	// A number of concurrent requests with the same ticket should be rejected due to replay
-	var wg sync.WaitGroup
-	noReq := 10
-	wg.Add(noReq)
-	for i := 0; i < noReq; i++ {
-		go httpGetReplay(t, r1, &wg)
+	// Use ticket again should be rejected
+	httpResp, err = http.DefaultClient.Do(r1)
+	if err != nil {
+		t.Fatalf("Request error: %v\n", err)
 	}
-	wg.Wait()
+	assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.")
 
 	// Form a 2nd ticket
 	st = time.Now().UTC()
@@ -164,13 +162,82 @@ func TestService_SPNEGOKRB_Replay(t *testing.T) {
 	assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.")
 }
 
-func httpGetReplay(t *testing.T, r *http.Request, wg *sync.WaitGroup) {
-	defer wg.Done()
-	httpResp, err := http.DefaultClient.Do(r)
+func TestService_SPNEGOKRB_ReplayCache_Concurrency(t *testing.T) {
+	s := httpServer()
+	defer s.Close()
+
+	cl := getClient()
+	sname := types.PrincipalName{
+		NameType:   nametype.KRB_NT_PRINCIPAL,
+		NameString: []string{"HTTP", "host.test.gokrb5"},
+	}
+	b, _ := hex.DecodeString(testdata.HTTP_KEYTAB)
+	kt, _ := keytab.Parse(b)
+	st := time.Now().UTC()
+	tkt, sessionKey, err := messages.NewTicket(cl.Credentials.CName, cl.Credentials.Realm,
+		sname, "TEST.GOKRB5",
+		types.NewKrbFlags(),
+		kt,
+		18,
+		1,
+		st,
+		st,
+		st.Add(time.Duration(24)*time.Hour),
+		st.Add(time.Duration(48)*time.Hour),
+	)
 	if err != nil {
-		t.Fatalf("Request error: %v\n", err)
+		t.Fatalf("Error getting test ticket: %v", err)
 	}
-	assert.Equal(t, http.StatusUnauthorized, httpResp.StatusCode, "Status code in response to client with no SPNEGO not as expected. Expected a replay to be detected.")
+
+	r1, _ := http.NewRequest("GET", s.URL, nil)
+	err = client.SetSPNEGOHeader(*cl.Credentials, tkt, sessionKey, r1)
+	if err != nil {
+		t.Fatalf("Error setting client SPNEGO header: %v", err)
+	}
+
+	// Form a 2nd ticket
+	st = time.Now().UTC()
+	tkt2, sessionKey2, err := messages.NewTicket(cl.Credentials.CName, cl.Credentials.Realm,
+		sname, "TEST.GOKRB5",
+		types.NewKrbFlags(),
+		kt,
+		18,
+		1,
+		st,
+		st,
+		st.Add(time.Duration(24)*time.Hour),
+		st.Add(time.Duration(48)*time.Hour),
+	)
+	if err != nil {
+		t.Fatalf("Error getting test ticket: %v", err)
+	}
+	r2, _ := http.NewRequest("GET", s.URL, nil)
+	err = client.SetSPNEGOHeader(*cl.Credentials, tkt2, sessionKey2, r2)
+	if err != nil {
+		t.Fatalf("Error setting client SPNEGO header: %v", err)
+	}
+
+	// Concurrent 1st requests should be OK
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go httpGet(r1, &wg)
+	go httpGet(r2, &wg)
+	wg.Wait()
+
+	// A number of concurrent requests with the same ticket should be rejected due to replay
+	var wg2 sync.WaitGroup
+	noReq := 10
+	wg2.Add(noReq * 2)
+	for i := 0; i < noReq; i++ {
+		go httpGet(r1, &wg2)
+		go httpGet(r2, &wg2)
+	}
+	wg2.Wait()
+}
+
+func httpGet(r *http.Request, wg *sync.WaitGroup) {
+	defer wg.Done()
+	http.DefaultClient.Do(r)
 }
 
 func httpServer() *httptest.Server {