Browse Source

Merge pull request #5361 from mitake/auth-v3-token-credential

RFC: *: attach auth token as a gRPC credential
Xiang Li 9 years ago
parent
commit
7014f6861d

+ 26 - 8
clientv3/auth.go

@@ -52,9 +52,6 @@ type Auth interface {
 	// AuthDisable disables auth of an etcd cluster.
 	AuthDisable(ctx context.Context) (*AuthDisableResponse, error)
 
-	// Authenticate does authenticate with given user name and password.
-	Authenticate(ctx context.Context, name string, password string) (*AuthenticateResponse, error)
-
 	// UserAdd adds a new user to an etcd cluster.
 	UserAdd(ctx context.Context, name string, password string) (*AuthUserAddResponse, error)
 
@@ -100,11 +97,6 @@ func (auth *auth) AuthDisable(ctx context.Context) (*AuthDisableResponse, error)
 	return (*AuthDisableResponse)(resp), rpctypes.Error(err)
 }
 
-func (auth *auth) Authenticate(ctx context.Context, name string, password string) (*AuthenticateResponse, error) {
-	resp, err := auth.remote.Authenticate(ctx, &pb.AuthenticateRequest{Name: name, Password: password})
-	return (*AuthenticateResponse)(resp), rpctypes.Error(err)
-}
-
 func (auth *auth) UserAdd(ctx context.Context, name string, password string) (*AuthUserAddResponse, error) {
 	resp, err := auth.remote.UserAdd(ctx, &pb.AuthUserAddRequest{Name: name, Password: password})
 	return (*AuthUserAddResponse)(resp), rpctypes.Error(err)
@@ -146,3 +138,29 @@ func StrToPermissionType(s string) (PermissionType, error) {
 	}
 	return PermissionType(-1), fmt.Errorf("invalid permission type: %s", s)
 }
+
+type authenticator struct {
+	conn   *grpc.ClientConn // conn in-use
+	remote pb.AuthClient
+}
+
+func (auth *authenticator) authenticate(ctx context.Context, name string, password string) (*AuthenticateResponse, error) {
+	resp, err := auth.remote.Authenticate(ctx, &pb.AuthenticateRequest{Name: name, Password: password})
+	return (*AuthenticateResponse)(resp), rpctypes.Error(err)
+}
+
+func (auth *authenticator) close() {
+	auth.conn.Close()
+}
+
+func newAuthenticator(endpoint string, opts []grpc.DialOption) (*authenticator, error) {
+	conn, err := grpc.Dial(endpoint, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	return &authenticator{
+		conn:   conn,
+		remote: pb.NewAuthClient(conn),
+	}, nil
+}

+ 41 - 1
clientv3/client.go

@@ -65,6 +65,11 @@ type Client struct {
 	// newconnc is closed on successful connect and set to a fresh channel
 	newconnc    chan struct{}
 	lastConnErr error
+
+	// Username is a username of authentication
+	Username string
+	// Password is a password of authentication
+	Password string
 }
 
 // New creates a new etcdv3 client from a given configuration.
@@ -132,6 +137,20 @@ func (c *Client) Errors() (errs []error) {
 	return errs
 }
 
+type authTokenCredential struct {
+	token string
+}
+
+func (cred authTokenCredential) RequireTransportSecurity() bool {
+	return false
+}
+
+func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
+	return map[string]string{
+		"token": cred.token,
+	}, nil
+}
+
 // Dial establishes a connection for a given endpoint using the client's config
 func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
 	opts := []grpc.DialOption{
@@ -160,6 +179,21 @@ func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
 	}
 	opts = append(opts, grpc.WithDialer(f))
 
+	if c.Username != "" && c.Password != "" {
+		auth, err := newAuthenticator(endpoint, opts)
+		if err != nil {
+			return nil, err
+		}
+		defer auth.close()
+
+		resp, err := auth.authenticate(c.ctx, c.Username, c.Password)
+		if err != nil {
+			return nil, err
+		}
+
+		opts = append(opts, grpc.WithPerRPCCredentials(authTokenCredential{token: resp.Token}))
+	}
+
 	conn, err := grpc.Dial(endpoint, opts...)
 	if err != nil {
 		return nil, err
@@ -183,9 +217,10 @@ func newClient(cfg *Config) (*Client, error) {
 		c := credentials.NewTLS(cfg.TLS)
 		creds = &c
 	}
+
 	// use a temporary skeleton client to bootstrap first connection
 	ctx, cancel := context.WithCancel(context.TODO())
-	conn, err := cfg.RetryDialer(&Client{cfg: *cfg, creds: creds, ctx: ctx})
+	conn, err := cfg.RetryDialer(&Client{cfg: *cfg, creds: creds, ctx: ctx, Username: cfg.Username, Password: cfg.Password})
 	if err != nil {
 		return nil, err
 	}
@@ -199,6 +234,11 @@ func newClient(cfg *Config) (*Client, error) {
 		newconnc: make(chan struct{}),
 	}
 
+	if cfg.Username != "" && cfg.Password != "" {
+		client.Username = cfg.Username
+		client.Password = cfg.Password
+	}
+
 	go client.connMonitor()
 
 	client.Cluster = NewCluster(client)

+ 6 - 0
clientv3/config.go

@@ -43,6 +43,12 @@ type Config struct {
 
 	// Logger is the logger used by client library.
 	Logger Logger
+
+	// Username is a username of authentication
+	Username string
+
+	// Password is a password of authentication
+	Password string
 }
 
 type yamlConfig struct {

+ 0 - 44
clientv3/integration/auth_test.go

@@ -1,44 +0,0 @@
-// Copyright 2016 The etcd Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package integration
-
-import (
-	"testing"
-
-	"github.com/coreos/etcd/clientv3"
-	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
-	"github.com/coreos/etcd/integration"
-	"github.com/coreos/etcd/pkg/testutil"
-	"golang.org/x/net/context"
-)
-
-func TestAuthError(t *testing.T) {
-	defer testutil.AfterTest(t)
-
-	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
-	defer clus.Terminate(t)
-
-	authapi := clientv3.NewAuth(clus.RandClient())
-
-	_, err := authapi.UserAdd(context.TODO(), "foo", "bar")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	_, err = authapi.Authenticate(context.TODO(), "foo", "bar111")
-	if err != rpctypes.ErrAuthFailed {
-		t.Fatalf("expected %v, got %v", rpctypes.ErrAuthFailed, err)
-	}
-}

+ 1 - 1
etcdctl/ctlv3/command/ep_command.go

@@ -70,7 +70,7 @@ func epHealthCommandFunc(cmd *cobra.Command, args []string) {
 	dt := dialTimeoutFromCmd(cmd)
 	cfgs := []*v3.Config{}
 	for _, ep := range endpoints {
-		cfg, err := newClientCfg([]string{ep}, dt, sec)
+		cfg, err := newClientCfg([]string{ep}, dt, sec, nil)
 		if err != nil {
 			ExitWithError(ExitBadArgs, err)
 		}

+ 47 - 4
etcdctl/ctlv3/command/global.go

@@ -19,8 +19,10 @@ import (
 	"errors"
 	"io"
 	"io/ioutil"
+	"strings"
 	"time"
 
+	"github.com/bgentry/speakeasy"
 	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/pkg/flags"
 	"github.com/coreos/etcd/pkg/transport"
@@ -40,6 +42,8 @@ type GlobalFlags struct {
 
 	OutputFormat string
 	IsHex        bool
+
+	User string
 }
 
 type secureCfg struct {
@@ -51,6 +55,11 @@ type secureCfg struct {
 	insecureSkipVerify bool
 }
 
+type authCfg struct {
+	username string
+	password string
+}
+
 var display printer = &simplePrinter{}
 
 func initDisplayFromCmd(cmd *cobra.Command) {
@@ -76,14 +85,15 @@ func mustClientFromCmd(cmd *cobra.Command) *clientv3.Client {
 	}
 	dialTimeout := dialTimeoutFromCmd(cmd)
 	sec := secureCfgFromCmd(cmd)
+	auth := authCfgFromCmd(cmd)
 
 	initDisplayFromCmd(cmd)
 
-	return mustClient(endpoints, dialTimeout, sec)
+	return mustClient(endpoints, dialTimeout, sec, auth)
 }
 
-func mustClient(endpoints []string, dialTimeout time.Duration, scfg *secureCfg) *clientv3.Client {
-	cfg, err := newClientCfg(endpoints, dialTimeout, scfg)
+func mustClient(endpoints []string, dialTimeout time.Duration, scfg *secureCfg, acfg *authCfg) *clientv3.Client {
+	cfg, err := newClientCfg(endpoints, dialTimeout, scfg, acfg)
 	if err != nil {
 		ExitWithError(ExitBadArgs, err)
 	}
@@ -96,7 +106,7 @@ func mustClient(endpoints []string, dialTimeout time.Duration, scfg *secureCfg)
 	return client
 }
 
-func newClientCfg(endpoints []string, dialTimeout time.Duration, scfg *secureCfg) (*clientv3.Config, error) {
+func newClientCfg(endpoints []string, dialTimeout time.Duration, scfg *secureCfg, acfg *authCfg) (*clientv3.Config, error) {
 	// set tls if any one tls option set
 	var cfgtls *transport.TLSInfo
 	tlsinfo := transport.TLSInfo{}
@@ -138,6 +148,12 @@ func newClientCfg(endpoints []string, dialTimeout time.Duration, scfg *secureCfg
 	if scfg.insecureSkipVerify && cfg.TLS != nil {
 		cfg.TLS.InsecureSkipVerify = true
 	}
+
+	if acfg != nil {
+		cfg.Username = acfg.username
+		cfg.Password = acfg.password
+	}
+
 	return cfg, nil
 }
 
@@ -213,3 +229,30 @@ func keyAndCertFromCmd(cmd *cobra.Command) (cert, key, cacert string) {
 
 	return cert, key, cacert
 }
+
+func authCfgFromCmd(cmd *cobra.Command) *authCfg {
+	userFlag, err := cmd.Flags().GetString("user")
+	if err != nil {
+		ExitWithError(ExitBadArgs, err)
+	}
+
+	if userFlag == "" {
+		return nil
+	}
+
+	var cfg authCfg
+
+	splitted := strings.SplitN(userFlag, ":", 2)
+	if len(splitted) == 0 {
+		cfg.username = userFlag
+		cfg.password, err = speakeasy.Ask("Password: ")
+		if err != nil {
+			ExitWithError(ExitError, err)
+		}
+	} else {
+		cfg.username = splitted[0]
+		cfg.password = splitted[1]
+	}
+
+	return &cfg
+}

+ 1 - 1
etcdctl/ctlv3/command/make_mirror_command.go

@@ -68,7 +68,7 @@ func makeMirrorCommandFunc(cmd *cobra.Command, args []string) {
 		insecureTransport: mminsecureTr,
 	}
 
-	dc := mustClient([]string{args[0]}, dialTimeout, sec)
+	dc := mustClient([]string{args[0]}, dialTimeout, sec, nil)
 	c := mustClientFromCmd(cmd)
 
 	err := makeMirror(context.TODO(), c, dc)

+ 1 - 0
etcdctl/ctlv3/ctl.go

@@ -57,6 +57,7 @@ func init() {
 	rootCmd.PersistentFlags().StringVar(&globalFlags.TLS.CertFile, "cert", "", "identify secure client using this TLS certificate file")
 	rootCmd.PersistentFlags().StringVar(&globalFlags.TLS.KeyFile, "key", "", "identify secure client using this TLS key file")
 	rootCmd.PersistentFlags().StringVar(&globalFlags.TLS.CAFile, "cacert", "", "verify certificates of TLS-enabled secure servers using this CA bundle")
+	rootCmd.PersistentFlags().StringVar(&globalFlags.User, "user", "", "username[:password] for authentication (prompt if password is not supplied)")
 
 	rootCmd.AddCommand(
 		command.NewGetCommand(),