Browse Source

Merge pull request #8289 from mitake/auth-proxy

clientv3, etcdmain, proxy: support authed RPCs with grpcproxy
Hitoshi Mitake 8 years ago
parent
commit
6515a1dfd0

+ 0 - 3
e2e/ctl_v3_auth_test.go

@@ -12,9 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Skip proxy tests for now since auth is broken on grpcproxy.
-// +build !cluster_proxy
-
 package e2e
 
 import (

+ 4 - 0
etcdmain/grpc_proxy.go

@@ -204,6 +204,10 @@ func mustNewClient() *clientv3.Client {
 		fmt.Fprintln(os.Stderr, err)
 		os.Exit(1)
 	}
+	cfg.DialOptions = append(cfg.DialOptions,
+		grpc.WithUnaryInterceptor(grpcproxy.AuthUnaryClientInterceptor))
+	cfg.DialOptions = append(cfg.DialOptions,
+		grpc.WithStreamInterceptor(grpcproxy.AuthStreamClientInterceptor))
 	client, err := clientv3.New(*cfg)
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)

+ 2 - 0
proxy/grpcproxy/maintenance.go

@@ -42,6 +42,8 @@ func (mp *maintenanceProxy) Snapshot(sr *pb.SnapshotRequest, stream pb.Maintenan
 	ctx, cancel := context.WithCancel(stream.Context())
 	defer cancel()
 
+	ctx = withClientAuthToken(ctx, stream.Context())
+
 	sc, err := pb.NewMaintenanceClient(conn).Snapshot(ctx, sr)
 	if err != nil {
 		return err

+ 73 - 0
proxy/grpcproxy/util.go

@@ -0,0 +1,73 @@
+// Copyright 2017 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package grpcproxy
+
+import (
+	"context"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/metadata"
+)
+
+func getAuthTokenFromClient(ctx context.Context) string {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if ok {
+		ts, ok := md["token"]
+		if ok {
+			return ts[0]
+		}
+	}
+	return ""
+}
+
+func withClientAuthToken(ctx context.Context, ctxWithToken context.Context) context.Context {
+	token := getAuthTokenFromClient(ctxWithToken)
+	if token != "" {
+		ctx = context.WithValue(ctx, "token", token)
+	}
+	return ctx
+}
+
+type proxyTokenCredential struct {
+	token string
+}
+
+func (cred *proxyTokenCredential) RequireTransportSecurity() bool {
+	return false
+}
+
+func (cred *proxyTokenCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
+	return map[string]string{
+		"token": cred.token,
+	}, nil
+}
+
+func AuthUnaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+	token := getAuthTokenFromClient(ctx)
+	if token != "" {
+		tokenCred := &proxyTokenCredential{token}
+		opts = append(opts, grpc.PerRPCCredentials(tokenCred))
+	}
+	return invoker(ctx, method, req, reply, cc, opts...)
+}
+
+func AuthStreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+	tokenif := ctx.Value("token")
+	if tokenif != nil {
+		tokenCred := &proxyTokenCredential{tokenif.(string)}
+		opts = append(opts, grpc.PerRPCCredentials(tokenCred))
+	}
+	return streamer(ctx, desc, cc, method, opts...)
+}

+ 36 - 0
proxy/grpcproxy/watch.go

@@ -40,6 +40,9 @@ type watchProxy struct {
 
 	// wg waits until all outstanding watch servers quit.
 	wg sync.WaitGroup
+
+	// kv is used for permission checking
+	kv clientv3.KV
 }
 
 func NewWatchProxy(c *clientv3.Client) (pb.WatchServer, <-chan struct{}) {
@@ -48,6 +51,8 @@ func NewWatchProxy(c *clientv3.Client) (pb.WatchServer, <-chan struct{}) {
 		cw:     c.Watcher,
 		ctx:    cctx,
 		leader: newLeader(c.Ctx(), c.Watcher),
+
+		kv: c.KV, // for permission checking
 	}
 	wp.ranges = newWatchRanges(wp)
 	ch := make(chan struct{})
@@ -92,6 +97,7 @@ func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
 		watchCh:  make(chan *pb.WatchResponse, 1024),
 		ctx:      ctx,
 		cancel:   cancel,
+		kv:       wp.kv,
 	}
 
 	var lostLeaderC <-chan struct{}
@@ -171,6 +177,9 @@ type watchProxyStream struct {
 
 	ctx    context.Context
 	cancel context.CancelFunc
+
+	// kv is used for permission checking
+	kv clientv3.KV
 }
 
 func (wps *watchProxyStream) close() {
@@ -192,6 +201,24 @@ func (wps *watchProxyStream) close() {
 	close(wps.watchCh)
 }
 
+func (wps *watchProxyStream) checkPermissionForWatch(key, rangeEnd []byte) error {
+	if len(key) == 0 {
+		// If the length of the key is 0, we need to obtain full range.
+		// look at clientv3.WithPrefix()
+		key = []byte{0}
+		rangeEnd = []byte{0}
+	}
+	req := &pb.RangeRequest{
+		Serializable: true,
+		Key:          key,
+		RangeEnd:     rangeEnd,
+		CountOnly:    true,
+		Limit:        1,
+	}
+	_, err := wps.kv.Do(wps.ctx, RangeRequestToOp(req))
+	return err
+}
+
 func (wps *watchProxyStream) recvLoop() error {
 	for {
 		req, err := wps.stream.Recv()
@@ -201,6 +228,15 @@ func (wps *watchProxyStream) recvLoop() error {
 		switch uv := req.RequestUnion.(type) {
 		case *pb.WatchRequest_CreateRequest:
 			cr := uv.CreateRequest
+
+			if err = wps.checkPermissionForWatch(cr.Key, cr.RangeEnd); err != nil && err == rpctypes.ErrPermissionDenied {
+				// Return WatchResponse which is caused by permission checking if and only if
+				// the error is permission denied. For other errors (e.g. timeout or connection closed),
+				// the permission checking mechanism should do nothing for preserving error code.
+				wps.watchCh <- &pb.WatchResponse{Header: &pb.ResponseHeader{}, WatchId: -1, Created: true, Canceled: true}
+				continue
+			}
+
 			w := &watcher{
 				wr:  watchRange{string(cr.Key), string(cr.RangeEnd)},
 				id:  wps.nextWatcherID,

+ 2 - 0
proxy/grpcproxy/watch_broadcast.go

@@ -58,6 +58,8 @@ func newWatchBroadcast(wp *watchProxy, w *watcher, update func(*watchBroadcast))
 			clientv3.WithCreatedNotify(),
 		}
 
+		cctx = withClientAuthToken(cctx, w.wps.stream.Context())
+
 		wch := wp.cw.Watch(cctx, w.wr.key, opts...)
 
 		for wr := range wch {