Browse Source

Merge pull request #7110 from mitake/reauth

etcdserver, clientv3: handle a case of expired auth token
Hitoshi Mitake 9 years ago
parent
commit
c89eae790d
4 changed files with 105 additions and 27 deletions
  1. 49 13
      clientv3/client.go
  2. 51 14
      clientv3/retry.go
  3. 3 0
      etcdserver/api/v3rpc/rpctypes/error.go
  4. 2 0
      etcdserver/api/v3rpc/util.go

+ 49 - 13
clientv3/client.go

@@ -21,6 +21,7 @@ import (
 	"net"
 	"net/url"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
@@ -46,11 +47,12 @@ type Client struct {
 	Auth
 	Maintenance
 
-	conn         *grpc.ClientConn
-	cfg          Config
-	creds        *credentials.TransportCredentials
-	balancer     *simpleBalancer
-	retryWrapper retryRpcFunc
+	conn             *grpc.ClientConn
+	cfg              Config
+	creds            *credentials.TransportCredentials
+	balancer         *simpleBalancer
+	retryWrapper     retryRpcFunc
+	retryAuthWrapper retryRpcFunc
 
 	ctx    context.Context
 	cancel context.CancelFunc
@@ -59,6 +61,8 @@ type Client struct {
 	Username string
 	// Password is a password for authentication
 	Password string
+	// tokenCred is an instance of WithPerRPCCredentials()'s argument
+	tokenCred *authTokenCredential
 }
 
 // New creates a new etcdv3 client from a given configuration.
@@ -144,7 +148,8 @@ func (c *Client) autoSync() {
 }
 
 type authTokenCredential struct {
-	token string
+	token   string
+	tokenMu *sync.RWMutex
 }
 
 func (cred authTokenCredential) RequireTransportSecurity() bool {
@@ -152,6 +157,8 @@ func (cred authTokenCredential) RequireTransportSecurity() bool {
 }
 
 func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
+	cred.tokenMu.RLock()
+	defer cred.tokenMu.RUnlock()
 	return map[string]string{
 		"token": cred.token,
 	}, nil
@@ -236,22 +243,50 @@ func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
 	return c.dial(endpoint)
 }
 
+func (c *Client) getToken(ctx context.Context) error {
+	var err error // return last error in a case of fail
+	var auth *authenticator
+
+	for i := 0; i < len(c.cfg.Endpoints); i++ {
+		endpoint := c.cfg.Endpoints[i]
+		host := getHost(endpoint)
+		// use dial options without dopts to avoid reusing the client balancer
+		auth, err = newAuthenticator(host, c.dialSetupOpts(endpoint))
+		if err != nil {
+			continue
+		}
+		defer auth.close()
+
+		var resp *AuthenticateResponse
+		resp, err = auth.authenticate(ctx, c.Username, c.Password)
+		if err != nil {
+			continue
+		}
+
+		c.tokenCred.tokenMu.Lock()
+		c.tokenCred.token = resp.Token
+		c.tokenCred.tokenMu.Unlock()
+
+		return nil
+	}
+
+	return err
+}
+
 func (c *Client) dial(endpoint string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
 	opts := c.dialSetupOpts(endpoint, dopts...)
 	host := getHost(endpoint)
 	if c.Username != "" && c.Password != "" {
-		// use dial options without dopts to avoid reusing the client balancer
-		auth, err := newAuthenticator(host, c.dialSetupOpts(endpoint))
-		if err != nil {
-			return nil, err
+		c.tokenCred = &authTokenCredential{
+			tokenMu: &sync.RWMutex{},
 		}
-		defer auth.close()
 
-		resp, err := auth.authenticate(c.ctx, c.Username, c.Password)
+		err := c.getToken(context.TODO())
 		if err != nil {
 			return nil, err
 		}
-		opts = append(opts, grpc.WithPerRPCCredentials(authTokenCredential{token: resp.Token}))
+
+		opts = append(opts, grpc.WithPerRPCCredentials(c.tokenCred))
 	}
 
 	// add metrics options
@@ -303,6 +338,7 @@ func newClient(cfg *Config) (*Client, error) {
 	}
 	client.conn = conn
 	client.retryWrapper = client.newRetryWrapper()
+	client.retryAuthWrapper = client.newAuthRetryWrapper()
 
 	// wait for a connection
 	if cfg.DialTimeout > 0 {

+ 51 - 14
clientv3/retry.go

@@ -33,16 +33,17 @@ func (c *Client) newRetryWrapper() retryRpcFunc {
 				return nil
 			}
 
-			// only retry if unavailable
-			if grpc.Code(err) != codes.Unavailable {
-				return err
-			}
-			// always stop retry on etcd errors
 			eErr := rpctypes.Error(err)
+			// always stop retry on etcd errors
 			if _, ok := eErr.(rpctypes.EtcdError); ok {
 				return err
 			}
 
+			// only retry if unavailable
+			if grpc.Code(err) != codes.Unavailable {
+				return err
+			}
+
 			select {
 			case <-c.balancer.ConnectNotify():
 			case <-rpcCtx.Done():
@@ -54,17 +55,52 @@ func (c *Client) newRetryWrapper() retryRpcFunc {
 	}
 }
 
-type retryKVClient struct {
-	pb.KVClient
-	retryf retryRpcFunc
+func (c *Client) newAuthRetryWrapper() retryRpcFunc {
+	return func(rpcCtx context.Context, f rpcFunc) error {
+		for {
+			err := f(rpcCtx)
+			if err == nil {
+				return nil
+			}
+
+			// always stop retry on etcd errors other than invalid auth token
+			if rpctypes.Error(err) == rpctypes.ErrInvalidAuthToken {
+				gterr := c.getToken(rpcCtx)
+				if gterr != nil {
+					return err // return the original error for simplicity
+				}
+				continue
+			}
+
+			return err
+		}
+	}
 }
 
 // RetryKVClient implements a KVClient that uses the client's FailFast retry policy.
 func RetryKVClient(c *Client) pb.KVClient {
-	return &retryKVClient{pb.NewKVClient(c.conn), c.retryWrapper}
+	retryWrite := &retryWriteKVClient{pb.NewKVClient(c.conn), c.retryWrapper}
+	return &retryKVClient{&retryWriteKVClient{retryWrite, c.retryAuthWrapper}}
+}
+
+type retryKVClient struct {
+	*retryWriteKVClient
+}
+
+func (rkv *retryKVClient) Range(ctx context.Context, in *pb.RangeRequest, opts ...grpc.CallOption) (resp *pb.RangeResponse, err error) {
+	err = rkv.retryf(ctx, func(rctx context.Context) error {
+		resp, err = rkv.retryWriteKVClient.Range(rctx, in, opts...)
+		return err
+	})
+	return resp, err
+}
+
+type retryWriteKVClient struct {
+	pb.KVClient
+	retryf retryRpcFunc
 }
 
-func (rkv *retryKVClient) Put(ctx context.Context, in *pb.PutRequest, opts ...grpc.CallOption) (resp *pb.PutResponse, err error) {
+func (rkv *retryWriteKVClient) Put(ctx context.Context, in *pb.PutRequest, opts ...grpc.CallOption) (resp *pb.PutResponse, err error) {
 	err = rkv.retryf(ctx, func(rctx context.Context) error {
 		resp, err = rkv.KVClient.Put(rctx, in, opts...)
 		return err
@@ -72,7 +108,7 @@ func (rkv *retryKVClient) Put(ctx context.Context, in *pb.PutRequest, opts ...gr
 	return resp, err
 }
 
-func (rkv *retryKVClient) DeleteRange(ctx context.Context, in *pb.DeleteRangeRequest, opts ...grpc.CallOption) (resp *pb.DeleteRangeResponse, err error) {
+func (rkv *retryWriteKVClient) DeleteRange(ctx context.Context, in *pb.DeleteRangeRequest, opts ...grpc.CallOption) (resp *pb.DeleteRangeResponse, err error) {
 	err = rkv.retryf(ctx, func(rctx context.Context) error {
 		resp, err = rkv.KVClient.DeleteRange(rctx, in, opts...)
 		return err
@@ -80,7 +116,7 @@ func (rkv *retryKVClient) DeleteRange(ctx context.Context, in *pb.DeleteRangeReq
 	return resp, err
 }
 
-func (rkv *retryKVClient) Txn(ctx context.Context, in *pb.TxnRequest, opts ...grpc.CallOption) (resp *pb.TxnResponse, err error) {
+func (rkv *retryWriteKVClient) Txn(ctx context.Context, in *pb.TxnRequest, opts ...grpc.CallOption) (resp *pb.TxnResponse, err error) {
 	err = rkv.retryf(ctx, func(rctx context.Context) error {
 		resp, err = rkv.KVClient.Txn(rctx, in, opts...)
 		return err
@@ -88,7 +124,7 @@ func (rkv *retryKVClient) Txn(ctx context.Context, in *pb.TxnRequest, opts ...gr
 	return resp, err
 }
 
-func (rkv *retryKVClient) Compact(ctx context.Context, in *pb.CompactionRequest, opts ...grpc.CallOption) (resp *pb.CompactionResponse, err error) {
+func (rkv *retryWriteKVClient) Compact(ctx context.Context, in *pb.CompactionRequest, opts ...grpc.CallOption) (resp *pb.CompactionResponse, err error) {
 	err = rkv.retryf(ctx, func(rctx context.Context) error {
 		resp, err = rkv.KVClient.Compact(rctx, in, opts...)
 		return err
@@ -103,7 +139,8 @@ type retryLeaseClient struct {
 
 // RetryLeaseClient implements a LeaseClient that uses the client's FailFast retry policy.
 func RetryLeaseClient(c *Client) pb.LeaseClient {
-	return &retryLeaseClient{pb.NewLeaseClient(c.conn), c.retryWrapper}
+	retry := &retryLeaseClient{pb.NewLeaseClient(c.conn), c.retryWrapper}
+	return &retryLeaseClient{retry, c.retryAuthWrapper}
 }
 
 func (rlc *retryLeaseClient) LeaseGrant(ctx context.Context, in *pb.LeaseGrantRequest, opts ...grpc.CallOption) (resp *pb.LeaseGrantResponse, err error) {

+ 3 - 0
etcdserver/api/v3rpc/rpctypes/error.go

@@ -52,6 +52,7 @@ var (
 	ErrGRPCRoleNotGranted       = grpc.Errorf(codes.FailedPrecondition, "etcdserver: role is not granted to the user")
 	ErrGRPCPermissionNotGranted = grpc.Errorf(codes.FailedPrecondition, "etcdserver: permission is not granted to the role")
 	ErrGRPCAuthNotEnabled       = grpc.Errorf(codes.FailedPrecondition, "etcdserver: authentication is not enabled")
+	ErrGRPCInvalidAuthToken     = grpc.Errorf(codes.Unauthenticated, "etcdserver: invalid auth token")
 
 	ErrGRPCNoLeader                   = grpc.Errorf(codes.Unavailable, "etcdserver: no leader")
 	ErrGRPCNotCapable                 = grpc.Errorf(codes.Unavailable, "etcdserver: not capable")
@@ -93,6 +94,7 @@ var (
 		grpc.ErrorDesc(ErrGRPCRoleNotGranted):       ErrGRPCRoleNotGranted,
 		grpc.ErrorDesc(ErrGRPCPermissionNotGranted): ErrGRPCPermissionNotGranted,
 		grpc.ErrorDesc(ErrGRPCAuthNotEnabled):       ErrGRPCAuthNotEnabled,
+		grpc.ErrorDesc(ErrGRPCInvalidAuthToken):     ErrGRPCInvalidAuthToken,
 
 		grpc.ErrorDesc(ErrGRPCNoLeader):                   ErrGRPCNoLeader,
 		grpc.ErrorDesc(ErrGRPCNotCapable):                 ErrGRPCNotCapable,
@@ -135,6 +137,7 @@ var (
 	ErrRoleNotGranted       = Error(ErrGRPCRoleNotGranted)
 	ErrPermissionNotGranted = Error(ErrGRPCPermissionNotGranted)
 	ErrAuthNotEnabled       = Error(ErrGRPCAuthNotEnabled)
+	ErrInvalidAuthToken     = Error(ErrGRPCInvalidAuthToken)
 
 	ErrNoLeader                   = Error(ErrGRPCNoLeader)
 	ErrNotCapable                 = Error(ErrGRPCNotCapable)

+ 2 - 0
etcdserver/api/v3rpc/util.go

@@ -93,6 +93,8 @@ func togRPCError(err error) error {
 		return rpctypes.ErrGRPCPermissionNotGranted
 	case auth.ErrAuthNotEnabled:
 		return rpctypes.ErrGRPCAuthNotEnabled
+	case etcdserver.ErrInvalidAuthToken:
+		return rpctypes.ErrGRPCInvalidAuthToken
 	default:
 		return grpc.Errorf(codes.Unknown, err.Error())
 	}