Browse Source

etcdmain, proxy: handle authed watch in grpcproxy

This commit lets grpcproxy handle authed watch. The main changes are:
1. forwrading a token of a new broadcast client
2. checking permission of a new client that participates to an
   existing broadcast
Hitoshi Mitake 8 năm trước cách đây
mục cha
commit
94b5071c30

+ 2 - 0
etcdmain/grpc_proxy.go

@@ -206,6 +206,8 @@ func mustNewClient() *clientv3.Client {
 	}
 	}
 	cfg.DialOptions = append(cfg.DialOptions,
 	cfg.DialOptions = append(cfg.DialOptions,
 		grpc.WithUnaryInterceptor(grpcproxy.AuthUnaryClientInterceptor))
 		grpc.WithUnaryInterceptor(grpcproxy.AuthUnaryClientInterceptor))
+	cfg.DialOptions = append(cfg.DialOptions,
+		grpc.WithStreamInterceptor(grpcproxy.AuthStreamClientInterceptor))
 	client, err := clientv3.New(*cfg)
 	client, err := clientv3.New(*cfg)
 	if err != nil {
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
 		fmt.Fprintln(os.Stderr, err)

+ 9 - 0
proxy/grpcproxy/util.go

@@ -54,3 +54,12 @@ func AuthUnaryClientInterceptor(ctx context.Context, method string, req, reply i
 	}
 	}
 	return invoker(ctx, method, req, reply, cc, opts...)
 	return invoker(ctx, method, req, reply, cc, opts...)
 }
 }
+
+func AuthStreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+	tokenif := ctx.Value("token")
+	if tokenif != nil {
+		tokenCred := &proxyTokenCredential{tokenif.(string)}
+		opts = append(opts, grpc.PerRPCCredentials(tokenCred))
+	}
+	return streamer(ctx, desc, cc, method, opts...)
+}

+ 36 - 0
proxy/grpcproxy/watch.go

@@ -40,6 +40,9 @@ type watchProxy struct {
 
 
 	// wg waits until all outstanding watch servers quit.
 	// wg waits until all outstanding watch servers quit.
 	wg sync.WaitGroup
 	wg sync.WaitGroup
+
+	// kv is used for permission checking
+	kv clientv3.KV
 }
 }
 
 
 func NewWatchProxy(c *clientv3.Client) (pb.WatchServer, <-chan struct{}) {
 func NewWatchProxy(c *clientv3.Client) (pb.WatchServer, <-chan struct{}) {
@@ -48,6 +51,8 @@ func NewWatchProxy(c *clientv3.Client) (pb.WatchServer, <-chan struct{}) {
 		cw:     c.Watcher,
 		cw:     c.Watcher,
 		ctx:    cctx,
 		ctx:    cctx,
 		leader: newLeader(c.Ctx(), c.Watcher),
 		leader: newLeader(c.Ctx(), c.Watcher),
+
+		kv: c.KV, // for permission checking
 	}
 	}
 	wp.ranges = newWatchRanges(wp)
 	wp.ranges = newWatchRanges(wp)
 	ch := make(chan struct{})
 	ch := make(chan struct{})
@@ -92,6 +97,7 @@ func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
 		watchCh:  make(chan *pb.WatchResponse, 1024),
 		watchCh:  make(chan *pb.WatchResponse, 1024),
 		ctx:      ctx,
 		ctx:      ctx,
 		cancel:   cancel,
 		cancel:   cancel,
+		kv:       wp.kv,
 	}
 	}
 
 
 	var lostLeaderC <-chan struct{}
 	var lostLeaderC <-chan struct{}
@@ -171,6 +177,9 @@ type watchProxyStream struct {
 
 
 	ctx    context.Context
 	ctx    context.Context
 	cancel context.CancelFunc
 	cancel context.CancelFunc
+
+	// kv is used for permission checking
+	kv clientv3.KV
 }
 }
 
 
 func (wps *watchProxyStream) close() {
 func (wps *watchProxyStream) close() {
@@ -192,6 +201,24 @@ func (wps *watchProxyStream) close() {
 	close(wps.watchCh)
 	close(wps.watchCh)
 }
 }
 
 
+func (wps *watchProxyStream) checkPermissionForWatch(key, rangeEnd []byte) error {
+	if len(key) == 0 {
+		// If the length of the key is 0, we need to obtain full range.
+		// look at clientv3.WithPrefix()
+		key = []byte{0}
+		rangeEnd = []byte{0}
+	}
+	req := &pb.RangeRequest{
+		Serializable: true,
+		Key:          key,
+		RangeEnd:     rangeEnd,
+		CountOnly:    true,
+		Limit:        1,
+	}
+	_, err := wps.kv.Do(wps.ctx, RangeRequestToOp(req))
+	return err
+}
+
 func (wps *watchProxyStream) recvLoop() error {
 func (wps *watchProxyStream) recvLoop() error {
 	for {
 	for {
 		req, err := wps.stream.Recv()
 		req, err := wps.stream.Recv()
@@ -201,6 +228,15 @@ func (wps *watchProxyStream) recvLoop() error {
 		switch uv := req.RequestUnion.(type) {
 		switch uv := req.RequestUnion.(type) {
 		case *pb.WatchRequest_CreateRequest:
 		case *pb.WatchRequest_CreateRequest:
 			cr := uv.CreateRequest
 			cr := uv.CreateRequest
+
+			if err = wps.checkPermissionForWatch(cr.Key, cr.RangeEnd); err != nil && err == rpctypes.ErrPermissionDenied {
+				// Return WatchResponse which is caused by permission checking if and only if
+				// the error is permission denied. For other errors (e.g. timeout or connection closed),
+				// the permission checking mechanism should do nothing for preserving error code.
+				wps.watchCh <- &pb.WatchResponse{Header: &pb.ResponseHeader{}, WatchId: -1, Created: true, Canceled: true}
+				continue
+			}
+
 			w := &watcher{
 			w := &watcher{
 				wr:  watchRange{string(cr.Key), string(cr.RangeEnd)},
 				wr:  watchRange{string(cr.Key), string(cr.RangeEnd)},
 				id:  wps.nextWatcherID,
 				id:  wps.nextWatcherID,

+ 6 - 0
proxy/grpcproxy/watch_broadcast.go

@@ -58,6 +58,12 @@ func newWatchBroadcast(wp *watchProxy, w *watcher, update func(*watchBroadcast))
 			clientv3.WithCreatedNotify(),
 			clientv3.WithCreatedNotify(),
 		}
 		}
 
 
+		// Forward a token from client to server.
+		token := getAuthTokenFromClient(w.wps.stream.Context())
+		if token != "" {
+			cctx = context.WithValue(cctx, "token", token)
+		}
+
 		wch := wp.cw.Watch(cctx, w.wr.key, opts...)
 		wch := wp.cw.Watch(cctx, w.wr.key, opts...)
 
 
 		for wr := range wch {
 		for wr := range wch {