Browse Source

auth: fix panic using WithRoot and improve JWT coverage

Sam Batschelet 7 years ago
parent
commit
b30a1166e0
4 changed files with 38 additions and 14 deletions
  1. 6 0
      auth/jwt_test.go
  2. 6 3
      auth/store.go
  3. 18 9
      auth/store_test.go
  4. 8 2
      tests/e2e/ctl_v3_auth_test.go

+ 6 - 0
auth/jwt_test.go

@@ -16,6 +16,7 @@ package auth
 
 
 import (
 import (
 	"context"
 	"context"
+	"fmt"
 	"testing"
 	"testing"
 
 
 	"go.uber.org/zap"
 	"go.uber.org/zap"
@@ -94,3 +95,8 @@ func TestJWTBad(t *testing.T) {
 	}
 	}
 	opts["priv-key"] = jwtPrivKey
 	opts["priv-key"] = jwtPrivKey
 }
 }
+
+// testJWTOpts is useful for passing to NewTokenProvider which requires a string.
+func testJWTOpts() string {
+	return fmt.Sprintf("%s,pub-key=%s,priv-key=%s,sign-method=RS256", tokenTypeJWT, jwtPubKey, jwtPrivKey)
+}

+ 6 - 3
auth/store.go

@@ -72,6 +72,9 @@ const (
 	rootUser = "root"
 	rootUser = "root"
 	rootRole = "root"
 	rootRole = "root"
 
 
+	tokenTypeSimple = "simple"
+	tokenTypeJWT    = "jwt"
+
 	revBytesLen = 8
 	revBytesLen = 8
 )
 )
 
 
@@ -1255,7 +1258,7 @@ func NewTokenProvider(
 	}
 	}
 
 
 	switch tokenType {
 	switch tokenType {
-	case "simple":
+	case tokenTypeSimple:
 		if lg != nil {
 		if lg != nil {
 			lg.Warn("simple token is not cryptographically signed")
 			lg.Warn("simple token is not cryptographically signed")
 		} else {
 		} else {
@@ -1263,7 +1266,7 @@ func NewTokenProvider(
 		}
 		}
 		return newTokenProviderSimple(lg, indexWaiter), nil
 		return newTokenProviderSimple(lg, indexWaiter), nil
 
 
-	case "jwt":
+	case tokenTypeJWT:
 		return newTokenProviderJWT(lg, typeSpecificOpts)
 		return newTokenProviderJWT(lg, typeSpecificOpts)
 
 
 	case "":
 	case "":
@@ -1289,7 +1292,7 @@ func (as *authStore) WithRoot(ctx context.Context) context.Context {
 	}
 	}
 
 
 	var ctxForAssign context.Context
 	var ctxForAssign context.Context
-	if ts := as.tokenProvider.(*tokenSimple); ts != nil {
+	if ts, ok := as.tokenProvider.(*tokenSimple); ok && ts != nil {
 		ctx1 := context.WithValue(ctx, AuthenticateParamIndex{}, uint64(0))
 		ctx1 := context.WithValue(ctx, AuthenticateParamIndex{}, uint64(0))
 		prefix, err := ts.genTokenPrefix()
 		prefix, err := ts.genTokenPrefix()
 		if err != nil {
 		if err != nil {

+ 18 - 9
auth/store_test.go

@@ -48,7 +48,7 @@ func TestNewAuthStoreRevision(t *testing.T) {
 	b, tPath := backend.NewDefaultTmpBackend()
 	b, tPath := backend.NewDefaultTmpBackend()
 	defer os.Remove(tPath)
 	defer os.Remove(tPath)
 
 
-	tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter)
+	tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -78,7 +78,7 @@ func TestNewAuthStoreBcryptCost(t *testing.T) {
 	b, tPath := backend.NewDefaultTmpBackend()
 	b, tPath := backend.NewDefaultTmpBackend()
 	defer os.Remove(tPath)
 	defer os.Remove(tPath)
 
 
-	tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter)
+	tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -98,7 +98,7 @@ func TestNewAuthStoreBcryptCost(t *testing.T) {
 func setupAuthStore(t *testing.T) (store *authStore, teardownfunc func(t *testing.T)) {
 func setupAuthStore(t *testing.T) (store *authStore, teardownfunc func(t *testing.T)) {
 	b, tPath := backend.NewDefaultTmpBackend()
 	b, tPath := backend.NewDefaultTmpBackend()
 
 
-	tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter)
+	tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -535,7 +535,7 @@ func TestAuthInfoFromCtxRace(t *testing.T) {
 	b, tPath := backend.NewDefaultTmpBackend()
 	b, tPath := backend.NewDefaultTmpBackend()
 	defer os.Remove(tPath)
 	defer os.Remove(tPath)
 
 
-	tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter)
+	tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -601,7 +601,7 @@ func TestRecoverFromSnapshot(t *testing.T) {
 
 
 	as.Close()
 	as.Close()
 
 
-	tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter)
+	tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -683,7 +683,7 @@ func TestRolesOrder(t *testing.T) {
 	b, tPath := backend.NewDefaultTmpBackend()
 	b, tPath := backend.NewDefaultTmpBackend()
 	defer os.Remove(tPath)
 	defer os.Remove(tPath)
 
 
-	tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter)
+	tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -724,12 +724,21 @@ func TestRolesOrder(t *testing.T) {
 	}
 	}
 }
 }
 
 
-// TestAuthInfoFromCtxWithRoot ensures "WithRoot" properly embeds token in the context.
-func TestAuthInfoFromCtxWithRoot(t *testing.T) {
+func TestAuthInfoFromCtxWithRootSimple(t *testing.T) {
+	testAuthInfoFromCtxWithRoot(t, tokenTypeSimple)
+}
+
+func TestAuthInfoFromCtxWithRootJWT(t *testing.T) {
+	opts := testJWTOpts()
+	testAuthInfoFromCtxWithRoot(t, opts)
+}
+
+// testAuthInfoFromCtxWithRoot ensures "WithRoot" properly embeds token in the context.
+func testAuthInfoFromCtxWithRoot(t *testing.T, opts string) {
 	b, tPath := backend.NewDefaultTmpBackend()
 	b, tPath := backend.NewDefaultTmpBackend()
 	defer os.Remove(tPath)
 	defer os.Remove(tPath)
 
 
-	tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter)
+	tp, err := NewTokenProvider(zap.NewExample(), opts, dummyIndexWaiter)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}

+ 8 - 2
tests/e2e/ctl_v3_auth_test.go

@@ -30,6 +30,7 @@ func TestCtlV3AuthRoleUpdate(t *testing.T)          { testCtl(t, authRoleUpdateT
 func TestCtlV3AuthUserDeleteDuringOps(t *testing.T) { testCtl(t, authUserDeleteDuringOpsTest) }
 func TestCtlV3AuthUserDeleteDuringOps(t *testing.T) { testCtl(t, authUserDeleteDuringOpsTest) }
 func TestCtlV3AuthRoleRevokeDuringOps(t *testing.T) { testCtl(t, authRoleRevokeDuringOpsTest) }
 func TestCtlV3AuthRoleRevokeDuringOps(t *testing.T) { testCtl(t, authRoleRevokeDuringOpsTest) }
 func TestCtlV3AuthTxn(t *testing.T)                 { testCtl(t, authTestTxn) }
 func TestCtlV3AuthTxn(t *testing.T)                 { testCtl(t, authTestTxn) }
+func TestCtlV3AuthTxnJWT(t *testing.T)              { testCtl(t, authTestTxn, withCfg(configJWT)) }
 func TestCtlV3AuthPrefixPerm(t *testing.T)          { testCtl(t, authTestPrefixPerm) }
 func TestCtlV3AuthPrefixPerm(t *testing.T)          { testCtl(t, authTestPrefixPerm) }
 func TestCtlV3AuthMemberAdd(t *testing.T)           { testCtl(t, authTestMemberAdd) }
 func TestCtlV3AuthMemberAdd(t *testing.T)           { testCtl(t, authTestMemberAdd) }
 func TestCtlV3AuthMemberRemove(t *testing.T) {
 func TestCtlV3AuthMemberRemove(t *testing.T) {
@@ -41,11 +42,15 @@ func TestCtlV3AuthRevokeWithDelete(t *testing.T) { testCtl(t, authTestRevokeWith
 func TestCtlV3AuthInvalidMgmt(t *testing.T)      { testCtl(t, authTestInvalidMgmt) }
 func TestCtlV3AuthInvalidMgmt(t *testing.T)      { testCtl(t, authTestInvalidMgmt) }
 func TestCtlV3AuthFromKeyPerm(t *testing.T)      { testCtl(t, authTestFromKeyPerm) }
 func TestCtlV3AuthFromKeyPerm(t *testing.T)      { testCtl(t, authTestFromKeyPerm) }
 func TestCtlV3AuthAndWatch(t *testing.T)         { testCtl(t, authTestWatch) }
 func TestCtlV3AuthAndWatch(t *testing.T)         { testCtl(t, authTestWatch) }
+func TestCtlV3AuthAndWatchJWT(t *testing.T)      { testCtl(t, authTestWatch, withCfg(configJWT)) }
 
 
 func TestCtlV3AuthLeaseTestKeepAlive(t *testing.T)         { testCtl(t, authLeaseTestKeepAlive) }
 func TestCtlV3AuthLeaseTestKeepAlive(t *testing.T)         { testCtl(t, authLeaseTestKeepAlive) }
 func TestCtlV3AuthLeaseTestTimeToLiveExpired(t *testing.T) { testCtl(t, authLeaseTestTimeToLiveExpired) }
 func TestCtlV3AuthLeaseTestTimeToLiveExpired(t *testing.T) { testCtl(t, authLeaseTestTimeToLiveExpired) }
 func TestCtlV3AuthLeaseGrantLeases(t *testing.T)           { testCtl(t, authLeaseTestLeaseGrantLeases) }
 func TestCtlV3AuthLeaseGrantLeases(t *testing.T)           { testCtl(t, authLeaseTestLeaseGrantLeases) }
-func TestCtlV3AuthLeaseRevoke(t *testing.T)                { testCtl(t, authLeaseTestLeaseRevoke) }
+func TestCtlV3AuthLeaseGrantLeasesJWT(t *testing.T) {
+	testCtl(t, authLeaseTestLeaseGrantLeases, withCfg(configJWT))
+}
+func TestCtlV3AuthLeaseRevoke(t *testing.T) { testCtl(t, authLeaseTestLeaseRevoke) }
 
 
 func TestCtlV3AuthRoleGet(t *testing.T)  { testCtl(t, authTestRoleGet) }
 func TestCtlV3AuthRoleGet(t *testing.T)  { testCtl(t, authTestRoleGet) }
 func TestCtlV3AuthUserGet(t *testing.T)  { testCtl(t, authTestUserGet) }
 func TestCtlV3AuthUserGet(t *testing.T)  { testCtl(t, authTestUserGet) }
@@ -55,7 +60,8 @@ func TestCtlV3AuthDefrag(t *testing.T) { testCtl(t, authTestDefrag) }
 func TestCtlV3AuthEndpointHealth(t *testing.T) {
 func TestCtlV3AuthEndpointHealth(t *testing.T) {
 	testCtl(t, authTestEndpointHealth, withQuorum())
 	testCtl(t, authTestEndpointHealth, withQuorum())
 }
 }
-func TestCtlV3AuthSnapshot(t *testing.T) { testCtl(t, authTestSnapshot) }
+func TestCtlV3AuthSnapshot(t *testing.T)    { testCtl(t, authTestSnapshot) }
+func TestCtlV3AuthSnapshotJWT(t *testing.T) { testCtl(t, authTestSnapshot, withCfg(configJWT)) }
 func TestCtlV3AuthCertCNAndUsername(t *testing.T) {
 func TestCtlV3AuthCertCNAndUsername(t *testing.T) {
 	testCtl(t, authTestCertCNAndUsername, withCfg(configClientTLSCertAuth))
 	testCtl(t, authTestCertCNAndUsername, withCfg(configClientTLSCertAuth))
 }
 }