Quellcode durchsuchen

Merge pull request #7492 from heyitsanthony/simpletokenttl-deadlock

auth: get rid of deadlocking channel passing scheme in simpleTokenTTL
Hitoshi Mitake vor 8 Jahren
Ursprung
Commit
148c923c72
2 geänderte Dateien mit 81 neuen und 31 gelöschten Zeilen
  1. 32 31
      auth/simple_token.go
  2. 49 0
      auth/store_test.go

+ 32 - 31
auth/simple_token.go

@@ -32,27 +32,26 @@ import (
 const (
 	letters                  = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
 	defaultSimpleTokenLength = 16
+)
+
+// var for testing purposes
+var (
 	simpleTokenTTL           = 5 * time.Minute
 	simpleTokenTTLResolution = 1 * time.Second
 )
 
 type simpleTokenTTLKeeper struct {
-	tokens              map[string]time.Time
-	addSimpleTokenCh    chan string
-	resetSimpleTokenCh  chan string
-	deleteSimpleTokenCh chan string
-	stopCh              chan chan struct{}
-	deleteTokenFunc     func(string)
+	tokensMu        sync.Mutex
+	tokens          map[string]time.Time
+	stopCh          chan chan struct{}
+	deleteTokenFunc func(string)
 }
 
 func NewSimpleTokenTTLKeeper(deletefunc func(string)) *simpleTokenTTLKeeper {
 	stk := &simpleTokenTTLKeeper{
-		tokens:              make(map[string]time.Time),
-		addSimpleTokenCh:    make(chan string, 1),
-		resetSimpleTokenCh:  make(chan string, 1),
-		deleteSimpleTokenCh: make(chan string, 1),
-		stopCh:              make(chan chan struct{}),
-		deleteTokenFunc:     deletefunc,
+		tokens:          make(map[string]time.Time),
+		stopCh:          make(chan chan struct{}),
+		deleteTokenFunc: deletefunc,
 	}
 	go stk.run()
 	return stk
@@ -66,37 +65,34 @@ func (tm *simpleTokenTTLKeeper) stop() {
 }
 
 func (tm *simpleTokenTTLKeeper) addSimpleToken(token string) {
-	tm.addSimpleTokenCh <- token
+	tm.tokens[token] = time.Now().Add(simpleTokenTTL)
 }
 
 func (tm *simpleTokenTTLKeeper) resetSimpleToken(token string) {
-	tm.resetSimpleTokenCh <- token
+	if _, ok := tm.tokens[token]; ok {
+		tm.tokens[token] = time.Now().Add(simpleTokenTTL)
+	}
 }
 
 func (tm *simpleTokenTTLKeeper) deleteSimpleToken(token string) {
-	tm.deleteSimpleTokenCh <- token
+	delete(tm.tokens, token)
 }
+
 func (tm *simpleTokenTTLKeeper) run() {
 	tokenTicker := time.NewTicker(simpleTokenTTLResolution)
 	defer tokenTicker.Stop()
 	for {
 		select {
-		case t := <-tm.addSimpleTokenCh:
-			tm.tokens[t] = time.Now().Add(simpleTokenTTL)
-		case t := <-tm.resetSimpleTokenCh:
-			if _, ok := tm.tokens[t]; ok {
-				tm.tokens[t] = time.Now().Add(simpleTokenTTL)
-			}
-		case t := <-tm.deleteSimpleTokenCh:
-			delete(tm.tokens, t)
 		case <-tokenTicker.C:
 			nowtime := time.Now()
+			tm.tokensMu.Lock()
 			for t, tokenendtime := range tm.tokens {
 				if nowtime.After(tokenendtime) {
 					tm.deleteTokenFunc(t)
 					delete(tm.tokens, t)
 				}
 			}
+			tm.tokensMu.Unlock()
 		case waitCh := <-tm.stopCh:
 			tm.tokens = make(map[string]time.Time)
 			waitCh <- struct{}{}
@@ -108,7 +104,7 @@ func (tm *simpleTokenTTLKeeper) run() {
 type tokenSimple struct {
 	indexWaiter       func(uint64) <-chan struct{}
 	simpleTokenKeeper *simpleTokenTTLKeeper
-	simpleTokensMu    sync.RWMutex
+	simpleTokensMu    sync.Mutex
 	simpleTokens      map[string]string // token -> username
 }
 
@@ -128,6 +124,7 @@ func (t *tokenSimple) genTokenPrefix() (string, error) {
 }
 
 func (t *tokenSimple) assignSimpleTokenToUser(username, token string) {
+	t.simpleTokenKeeper.tokensMu.Lock()
 	t.simpleTokensMu.Lock()
 
 	_, ok := t.simpleTokens[token]
@@ -138,18 +135,23 @@ func (t *tokenSimple) assignSimpleTokenToUser(username, token string) {
 	t.simpleTokens[token] = username
 	t.simpleTokenKeeper.addSimpleToken(token)
 	t.simpleTokensMu.Unlock()
+	t.simpleTokenKeeper.tokensMu.Unlock()
 }
 
 func (t *tokenSimple) invalidateUser(username string) {
+	if t.simpleTokenKeeper == nil {
+		return
+	}
+	t.simpleTokenKeeper.tokensMu.Lock()
 	t.simpleTokensMu.Lock()
-	defer t.simpleTokensMu.Unlock()
-
 	for token, name := range t.simpleTokens {
 		if strings.Compare(name, username) == 0 {
 			delete(t.simpleTokens, token)
 			t.simpleTokenKeeper.deleteSimpleToken(token)
 		}
 	}
+	t.simpleTokensMu.Unlock()
+	t.simpleTokenKeeper.tokensMu.Unlock()
 }
 
 func newDeleterFunc(t *tokenSimple) func(string) {
@@ -172,7 +174,6 @@ func (t *tokenSimple) disable() {
 		t.simpleTokenKeeper.stop()
 		t.simpleTokenKeeper = nil
 	}
-
 	t.simpleTokensMu.Lock()
 	t.simpleTokens = make(map[string]string) // invalidate all tokens
 	t.simpleTokensMu.Unlock()
@@ -182,14 +183,14 @@ func (t *tokenSimple) info(ctx context.Context, token string, revision uint64) (
 	if !t.isValidSimpleToken(ctx, token) {
 		return nil, false
 	}
-
-	t.simpleTokensMu.RLock()
-	defer t.simpleTokensMu.RUnlock()
+	t.simpleTokenKeeper.tokensMu.Lock()
+	t.simpleTokensMu.Lock()
 	username, ok := t.simpleTokens[token]
 	if ok {
 		t.simpleTokenKeeper.resetSimpleToken(token)
 	}
-
+	t.simpleTokensMu.Unlock()
+	t.simpleTokenKeeper.tokensMu.Unlock()
 	return &AuthInfo{Username: username, Revision: revision}, ok
 }
 

+ 49 - 0
auth/store_test.go

@@ -15,9 +15,12 @@
 package auth
 
 import (
+	"fmt"
 	"os"
 	"reflect"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/coreos/etcd/auth/authpb"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
@@ -582,3 +585,49 @@ func contains(array []string, str string) bool {
 	}
 	return false
 }
+
+func TestHammerSimpleAuthenticate(t *testing.T) {
+	// set TTL values low to try to trigger races
+	oldTTL, oldTTLRes := simpleTokenTTL, simpleTokenTTLResolution
+	defer func() {
+		simpleTokenTTL = oldTTL
+		simpleTokenTTLResolution = oldTTLRes
+	}()
+	simpleTokenTTL = 10 * time.Millisecond
+	simpleTokenTTLResolution = simpleTokenTTL
+	users := make(map[string]struct{})
+
+	as, tearDown := setupAuthStore(t)
+	defer tearDown(t)
+
+	// create lots of users
+	for i := 0; i < 50; i++ {
+		u := fmt.Sprintf("user-%d", i)
+		ua := &pb.AuthUserAddRequest{Name: u, Password: "123"}
+		if _, err := as.UserAdd(ua); err != nil {
+			t.Fatal(err)
+		}
+		users[u] = struct{}{}
+	}
+
+	// hammer on authenticate with lots of users
+	for i := 0; i < 10; i++ {
+		var wg sync.WaitGroup
+		wg.Add(len(users))
+		for u := range users {
+			go func(user string) {
+				defer wg.Done()
+				token := fmt.Sprintf("%s(%d)", user, i)
+				ctx := context.WithValue(context.WithValue(context.TODO(), "index", uint64(1)), "simpleToken", token)
+				if _, err := as.Authenticate(ctx, user, "123"); err != nil {
+					t.Fatal(err)
+				}
+				if _, err := as.AuthInfoFromCtx(ctx); err != nil {
+					t.Fatal(err)
+				}
+			}(u)
+		}
+		time.Sleep(time.Millisecond)
+		wg.Wait()
+	}
+}