Browse Source

auth: get rid of deadlocking channel passing scheme in simpleTokenTTL

Just use the mutex instead.

Fixes #7471
Anthony Romano 8 years ago
parent
commit
1b1fabef8f
1 changed files with 32 additions and 31 deletions
  1. 32 31
      auth/simple_token.go

+ 32 - 31
auth/simple_token.go

@@ -32,27 +32,26 @@ import (
 const (
 const (
 	letters                  = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
 	letters                  = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
 	defaultSimpleTokenLength = 16
 	defaultSimpleTokenLength = 16
+)
+
+// var for testing purposes
+var (
 	simpleTokenTTL           = 5 * time.Minute
 	simpleTokenTTL           = 5 * time.Minute
 	simpleTokenTTLResolution = 1 * time.Second
 	simpleTokenTTLResolution = 1 * time.Second
 )
 )
 
 
 type simpleTokenTTLKeeper struct {
 type simpleTokenTTLKeeper struct {
-	tokens              map[string]time.Time
-	addSimpleTokenCh    chan string
-	resetSimpleTokenCh  chan string
-	deleteSimpleTokenCh chan string
-	stopCh              chan chan struct{}
-	deleteTokenFunc     func(string)
+	tokensMu        sync.Mutex
+	tokens          map[string]time.Time
+	stopCh          chan chan struct{}
+	deleteTokenFunc func(string)
 }
 }
 
 
 func NewSimpleTokenTTLKeeper(deletefunc func(string)) *simpleTokenTTLKeeper {
 func NewSimpleTokenTTLKeeper(deletefunc func(string)) *simpleTokenTTLKeeper {
 	stk := &simpleTokenTTLKeeper{
 	stk := &simpleTokenTTLKeeper{
-		tokens:              make(map[string]time.Time),
-		addSimpleTokenCh:    make(chan string, 1),
-		resetSimpleTokenCh:  make(chan string, 1),
-		deleteSimpleTokenCh: make(chan string, 1),
-		stopCh:              make(chan chan struct{}),
-		deleteTokenFunc:     deletefunc,
+		tokens:          make(map[string]time.Time),
+		stopCh:          make(chan chan struct{}),
+		deleteTokenFunc: deletefunc,
 	}
 	}
 	go stk.run()
 	go stk.run()
 	return stk
 	return stk
@@ -66,37 +65,34 @@ func (tm *simpleTokenTTLKeeper) stop() {
 }
 }
 
 
 func (tm *simpleTokenTTLKeeper) addSimpleToken(token string) {
 func (tm *simpleTokenTTLKeeper) addSimpleToken(token string) {
-	tm.addSimpleTokenCh <- token
+	tm.tokens[token] = time.Now().Add(simpleTokenTTL)
 }
 }
 
 
 func (tm *simpleTokenTTLKeeper) resetSimpleToken(token string) {
 func (tm *simpleTokenTTLKeeper) resetSimpleToken(token string) {
-	tm.resetSimpleTokenCh <- token
+	if _, ok := tm.tokens[token]; ok {
+		tm.tokens[token] = time.Now().Add(simpleTokenTTL)
+	}
 }
 }
 
 
 func (tm *simpleTokenTTLKeeper) deleteSimpleToken(token string) {
 func (tm *simpleTokenTTLKeeper) deleteSimpleToken(token string) {
-	tm.deleteSimpleTokenCh <- token
+	delete(tm.tokens, token)
 }
 }
+
 func (tm *simpleTokenTTLKeeper) run() {
 func (tm *simpleTokenTTLKeeper) run() {
 	tokenTicker := time.NewTicker(simpleTokenTTLResolution)
 	tokenTicker := time.NewTicker(simpleTokenTTLResolution)
 	defer tokenTicker.Stop()
 	defer tokenTicker.Stop()
 	for {
 	for {
 		select {
 		select {
-		case t := <-tm.addSimpleTokenCh:
-			tm.tokens[t] = time.Now().Add(simpleTokenTTL)
-		case t := <-tm.resetSimpleTokenCh:
-			if _, ok := tm.tokens[t]; ok {
-				tm.tokens[t] = time.Now().Add(simpleTokenTTL)
-			}
-		case t := <-tm.deleteSimpleTokenCh:
-			delete(tm.tokens, t)
 		case <-tokenTicker.C:
 		case <-tokenTicker.C:
 			nowtime := time.Now()
 			nowtime := time.Now()
+			tm.tokensMu.Lock()
 			for t, tokenendtime := range tm.tokens {
 			for t, tokenendtime := range tm.tokens {
 				if nowtime.After(tokenendtime) {
 				if nowtime.After(tokenendtime) {
 					tm.deleteTokenFunc(t)
 					tm.deleteTokenFunc(t)
 					delete(tm.tokens, t)
 					delete(tm.tokens, t)
 				}
 				}
 			}
 			}
+			tm.tokensMu.Unlock()
 		case waitCh := <-tm.stopCh:
 		case waitCh := <-tm.stopCh:
 			tm.tokens = make(map[string]time.Time)
 			tm.tokens = make(map[string]time.Time)
 			waitCh <- struct{}{}
 			waitCh <- struct{}{}
@@ -108,7 +104,7 @@ func (tm *simpleTokenTTLKeeper) run() {
 type tokenSimple struct {
 type tokenSimple struct {
 	indexWaiter       func(uint64) <-chan struct{}
 	indexWaiter       func(uint64) <-chan struct{}
 	simpleTokenKeeper *simpleTokenTTLKeeper
 	simpleTokenKeeper *simpleTokenTTLKeeper
-	simpleTokensMu    sync.RWMutex
+	simpleTokensMu    sync.Mutex
 	simpleTokens      map[string]string // token -> username
 	simpleTokens      map[string]string // token -> username
 }
 }
 
 
@@ -128,6 +124,7 @@ func (t *tokenSimple) genTokenPrefix() (string, error) {
 }
 }
 
 
 func (t *tokenSimple) assignSimpleTokenToUser(username, token string) {
 func (t *tokenSimple) assignSimpleTokenToUser(username, token string) {
+	t.simpleTokenKeeper.tokensMu.Lock()
 	t.simpleTokensMu.Lock()
 	t.simpleTokensMu.Lock()
 
 
 	_, ok := t.simpleTokens[token]
 	_, ok := t.simpleTokens[token]
@@ -138,18 +135,23 @@ func (t *tokenSimple) assignSimpleTokenToUser(username, token string) {
 	t.simpleTokens[token] = username
 	t.simpleTokens[token] = username
 	t.simpleTokenKeeper.addSimpleToken(token)
 	t.simpleTokenKeeper.addSimpleToken(token)
 	t.simpleTokensMu.Unlock()
 	t.simpleTokensMu.Unlock()
+	t.simpleTokenKeeper.tokensMu.Unlock()
 }
 }
 
 
 func (t *tokenSimple) invalidateUser(username string) {
 func (t *tokenSimple) invalidateUser(username string) {
+	if t.simpleTokenKeeper == nil {
+		return
+	}
+	t.simpleTokenKeeper.tokensMu.Lock()
 	t.simpleTokensMu.Lock()
 	t.simpleTokensMu.Lock()
-	defer t.simpleTokensMu.Unlock()
-
 	for token, name := range t.simpleTokens {
 	for token, name := range t.simpleTokens {
 		if strings.Compare(name, username) == 0 {
 		if strings.Compare(name, username) == 0 {
 			delete(t.simpleTokens, token)
 			delete(t.simpleTokens, token)
 			t.simpleTokenKeeper.deleteSimpleToken(token)
 			t.simpleTokenKeeper.deleteSimpleToken(token)
 		}
 		}
 	}
 	}
+	t.simpleTokensMu.Unlock()
+	t.simpleTokenKeeper.tokensMu.Unlock()
 }
 }
 
 
 func newDeleterFunc(t *tokenSimple) func(string) {
 func newDeleterFunc(t *tokenSimple) func(string) {
@@ -172,7 +174,6 @@ func (t *tokenSimple) disable() {
 		t.simpleTokenKeeper.stop()
 		t.simpleTokenKeeper.stop()
 		t.simpleTokenKeeper = nil
 		t.simpleTokenKeeper = nil
 	}
 	}
-
 	t.simpleTokensMu.Lock()
 	t.simpleTokensMu.Lock()
 	t.simpleTokens = make(map[string]string) // invalidate all tokens
 	t.simpleTokens = make(map[string]string) // invalidate all tokens
 	t.simpleTokensMu.Unlock()
 	t.simpleTokensMu.Unlock()
@@ -182,14 +183,14 @@ func (t *tokenSimple) info(ctx context.Context, token string, revision uint64) (
 	if !t.isValidSimpleToken(ctx, token) {
 	if !t.isValidSimpleToken(ctx, token) {
 		return nil, false
 		return nil, false
 	}
 	}
-
-	t.simpleTokensMu.RLock()
-	defer t.simpleTokensMu.RUnlock()
+	t.simpleTokenKeeper.tokensMu.Lock()
+	t.simpleTokensMu.Lock()
 	username, ok := t.simpleTokens[token]
 	username, ok := t.simpleTokens[token]
 	if ok {
 	if ok {
 		t.simpleTokenKeeper.resetSimpleToken(token)
 		t.simpleTokenKeeper.resetSimpleToken(token)
 	}
 	}
-
+	t.simpleTokensMu.Unlock()
+	t.simpleTokenKeeper.tokensMu.Unlock()
 	return &AuthInfo{Username: username, Revision: revision}, ok
 	return &AuthInfo{Username: username, Revision: revision}, ok
 }
 }