123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- package internal
- import (
- "context"
- "google.golang.org/grpc"
- )
- func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
- return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...))
- }
- func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
- return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...))
- }
- func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
- switch len(interceptors) {
- case 0:
- return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
- handler grpc.StreamHandler) error {
- return handler(srv, stream)
- }
- case 1:
- return interceptors[0]
- default:
- last := len(interceptors) - 1
- return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
- handler grpc.StreamHandler) error {
- var chainHandler grpc.StreamHandler
- var current int
- chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error {
- if current == last {
- return handler(curSrv, curStream)
- }
- current++
- err := interceptors[current](curSrv, curStream, info, chainHandler)
- current--
- return err
- }
- return interceptors[0](srv, stream, info, chainHandler)
- }
- }
- }
- func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
- switch len(interceptors) {
- case 0:
- return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
- interface{}, error) {
- return handler(ctx, req)
- }
- case 1:
- return interceptors[0]
- default:
- last := len(interceptors) - 1
- return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
- interface{}, error) {
- var chainHandler grpc.UnaryHandler
- var current int
- chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) {
- if current == last {
- return handler(curCtx, curReq)
- }
- current++
- resp, err := interceptors[current](curCtx, curReq, info, chainHandler)
- current--
- return resp, err
- }
- return interceptors[0](ctx, req, info, chainHandler)
- }
- }
- }
|