Browse Source

Merge pull request #7729 from heyitsanthony/fix-auth-token-crash

auth: protect simpleToken with single mutex and check if enabled
Anthony Romano 8 years ago
parent
commit
0b19921ec0
2 changed files with 59 additions and 38 deletions
  1. 13 29
      auth/simple_token.go
  2. 46 9
      integration/v3_auth_test.go

+ 13 - 29
auth/simple_token.go

@@ -41,20 +41,10 @@ var (
 )
 
 type simpleTokenTTLKeeper struct {
-	tokensMu        sync.Mutex
 	tokens          map[string]time.Time
 	stopCh          chan chan struct{}
 	deleteTokenFunc func(string)
-}
-
-func NewSimpleTokenTTLKeeper(deletefunc func(string)) *simpleTokenTTLKeeper {
-	stk := &simpleTokenTTLKeeper{
-		tokens:          make(map[string]time.Time),
-		stopCh:          make(chan chan struct{}),
-		deleteTokenFunc: deletefunc,
-	}
-	go stk.run()
-	return stk
+	mu              *sync.Mutex
 }
 
 func (tm *simpleTokenTTLKeeper) stop() {
@@ -85,14 +75,14 @@ func (tm *simpleTokenTTLKeeper) run() {
 		select {
 		case <-tokenTicker.C:
 			nowtime := time.Now()
-			tm.tokensMu.Lock()
+			tm.mu.Lock()
 			for t, tokenendtime := range tm.tokens {
 				if nowtime.After(tokenendtime) {
 					tm.deleteTokenFunc(t)
 					delete(tm.tokens, t)
 				}
 			}
-			tm.tokensMu.Unlock()
+			tm.mu.Unlock()
 		case waitCh := <-tm.stopCh:
 			tm.tokens = make(map[string]time.Time)
 			waitCh <- struct{}{}
@@ -124,9 +114,7 @@ func (t *tokenSimple) genTokenPrefix() (string, error) {
 }
 
 func (t *tokenSimple) assignSimpleTokenToUser(username, token string) {
-	t.simpleTokenKeeper.tokensMu.Lock()
 	t.simpleTokensMu.Lock()
-
 	_, ok := t.simpleTokens[token]
 	if ok {
 		plog.Panicf("token %s is alredy used", token)
@@ -135,14 +123,12 @@ func (t *tokenSimple) assignSimpleTokenToUser(username, token string) {
 	t.simpleTokens[token] = username
 	t.simpleTokenKeeper.addSimpleToken(token)
 	t.simpleTokensMu.Unlock()
-	t.simpleTokenKeeper.tokensMu.Unlock()
 }
 
 func (t *tokenSimple) invalidateUser(username string) {
 	if t.simpleTokenKeeper == nil {
 		return
 	}
-	t.simpleTokenKeeper.tokensMu.Lock()
 	t.simpleTokensMu.Lock()
 	for token, name := range t.simpleTokens {
 		if strings.Compare(name, username) == 0 {
@@ -151,22 +137,22 @@ func (t *tokenSimple) invalidateUser(username string) {
 		}
 	}
 	t.simpleTokensMu.Unlock()
-	t.simpleTokenKeeper.tokensMu.Unlock()
 }
 
-func newDeleterFunc(t *tokenSimple) func(string) {
-	return func(tk string) {
-		t.simpleTokensMu.Lock()
-		defer t.simpleTokensMu.Unlock()
+func (t *tokenSimple) enable() {
+	delf := func(tk string) {
 		if username, ok := t.simpleTokens[tk]; ok {
 			plog.Infof("deleting token %s for user %s", tk, username)
 			delete(t.simpleTokens, tk)
 		}
 	}
-}
-
-func (t *tokenSimple) enable() {
-	t.simpleTokenKeeper = NewSimpleTokenTTLKeeper(newDeleterFunc(t))
+	t.simpleTokenKeeper = &simpleTokenTTLKeeper{
+		tokens:          make(map[string]time.Time),
+		stopCh:          make(chan chan struct{}),
+		deleteTokenFunc: delf,
+		mu:              &t.simpleTokensMu,
+	}
+	go t.simpleTokenKeeper.run()
 }
 
 func (t *tokenSimple) disable() {
@@ -183,14 +169,12 @@ func (t *tokenSimple) info(ctx context.Context, token string, revision uint64) (
 	if !t.isValidSimpleToken(ctx, token) {
 		return nil, false
 	}
-	t.simpleTokenKeeper.tokensMu.Lock()
 	t.simpleTokensMu.Lock()
 	username, ok := t.simpleTokens[token]
-	if ok {
+	if ok && t.simpleTokenKeeper != nil {
 		t.simpleTokenKeeper.resetSimpleToken(token)
 	}
 	t.simpleTokensMu.Unlock()
-	t.simpleTokenKeeper.tokensMu.Unlock()
 	return &AuthInfo{Username: username, Revision: revision}, ok
 }
 

+ 46 - 9
integration/v3_auth_test.go

@@ -20,6 +20,7 @@ import (
 
 	"golang.org/x/net/context"
 
+	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/pkg/testutil"
@@ -35,23 +36,59 @@ func TestV3AuthEmptyUserGet(t *testing.T) {
 	defer cancel()
 
 	api := toGRPC(clus.Client(0))
-	auth := api.Auth
+	authSetupRoot(t, api.Auth)
 
-	if _, err := auth.UserAdd(ctx, &pb.AuthUserAddRequest{Name: "root", Password: "123"}); err != nil {
+	_, err := api.KV.Range(ctx, &pb.RangeRequest{Key: []byte("abc")})
+	if !eqErrGRPC(err, rpctypes.ErrUserEmpty) {
+		t.Fatalf("got %v, expected %v", err, rpctypes.ErrUserEmpty)
+	}
+}
+
+// TestV3AuthTokenWithDisable tests that auth won't crash if
+// given a valid token when authentication is disabled
+func TestV3AuthTokenWithDisable(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := NewClusterV3(t, &ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+
+	authSetupRoot(t, toGRPC(clus.Client(0)).Auth)
+
+	c, cerr := clientv3.New(clientv3.Config{Endpoints: clus.Client(0).Endpoints(), Username: "root", Password: "123"})
+	if cerr != nil {
+		t.Fatal(cerr)
+	}
+	defer c.Close()
+
+	rctx, cancel := context.WithCancel(context.TODO())
+	donec := make(chan struct{})
+	go func() {
+		defer close(donec)
+		for rctx.Err() == nil {
+			c.Put(rctx, "abc", "def")
+		}
+	}()
+
+	time.Sleep(10 * time.Millisecond)
+	if _, err := c.AuthDisable(context.TODO()); err != nil {
 		t.Fatal(err)
 	}
-	if _, err := auth.RoleAdd(ctx, &pb.AuthRoleAddRequest{Name: "root"}); err != nil {
+	time.Sleep(10 * time.Millisecond)
+
+	cancel()
+	<-donec
+}
+
+func authSetupRoot(t *testing.T, auth pb.AuthClient) {
+	if _, err := auth.UserAdd(context.TODO(), &pb.AuthUserAddRequest{Name: "root", Password: "123"}); err != nil {
 		t.Fatal(err)
 	}
-	if _, err := auth.UserGrantRole(ctx, &pb.AuthUserGrantRoleRequest{User: "root", Role: "root"}); err != nil {
+	if _, err := auth.RoleAdd(context.TODO(), &pb.AuthRoleAddRequest{Name: "root"}); err != nil {
 		t.Fatal(err)
 	}
-	if _, err := auth.AuthEnable(ctx, &pb.AuthEnableRequest{}); err != nil {
+	if _, err := auth.UserGrantRole(context.TODO(), &pb.AuthUserGrantRoleRequest{User: "root", Role: "root"}); err != nil {
 		t.Fatal(err)
 	}
-
-	_, err := api.KV.Range(ctx, &pb.RangeRequest{Key: []byte("abc")})
-	if !eqErrGRPC(err, rpctypes.ErrUserEmpty) {
-		t.Fatalf("got %v, expected %v", err, rpctypes.ErrUserEmpty)
+	if _, err := auth.AuthEnable(context.TODO(), &pb.AuthEnableRequest{}); err != nil {
+		t.Fatal(err)
 	}
 }