chainserverinterceptors.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package internal
  2. import (
  3. "context"
  4. "google.golang.org/grpc"
  5. )
  6. func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
  7. return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...))
  8. }
  9. func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
  10. return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...))
  11. }
  12. func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
  13. switch len(interceptors) {
  14. case 0:
  15. return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
  16. handler grpc.StreamHandler) error {
  17. return handler(srv, stream)
  18. }
  19. case 1:
  20. return interceptors[0]
  21. default:
  22. last := len(interceptors) - 1
  23. return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
  24. handler grpc.StreamHandler) error {
  25. var chainHandler grpc.StreamHandler
  26. var current int
  27. chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error {
  28. if current == last {
  29. return handler(curSrv, curStream)
  30. }
  31. current++
  32. err := interceptors[current](curSrv, curStream, info, chainHandler)
  33. current--
  34. return err
  35. }
  36. return interceptors[0](srv, stream, info, chainHandler)
  37. }
  38. }
  39. }
  40. func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
  41. switch len(interceptors) {
  42. case 0:
  43. return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
  44. interface{}, error) {
  45. return handler(ctx, req)
  46. }
  47. case 1:
  48. return interceptors[0]
  49. default:
  50. last := len(interceptors) - 1
  51. return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
  52. interface{}, error) {
  53. var chainHandler grpc.UnaryHandler
  54. var current int
  55. chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) {
  56. if current == last {
  57. return handler(curCtx, curReq)
  58. }
  59. current++
  60. resp, err := interceptors[current](curCtx, curReq, info, chainHandler)
  61. current--
  62. return resp, err
  63. }
  64. return interceptors[0](ctx, req, info, chainHandler)
  65. }
  66. }
  67. }