chainclientinterceptors.go 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package internal
  2. import (
  3. "context"
  4. "google.golang.org/grpc"
  5. )
  6. func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption {
  7. return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...))
  8. }
  9. func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption {
  10. return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...))
  11. }
  12. func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
  13. switch len(interceptors) {
  14. case 0:
  15. return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
  16. streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
  17. return streamer(ctx, desc, cc, method, opts...)
  18. }
  19. case 1:
  20. return interceptors[0]
  21. default:
  22. last := len(interceptors) - 1
  23. return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
  24. method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
  25. var chainStreamer grpc.Streamer
  26. var current int
  27. chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn,
  28. curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) {
  29. if current == last {
  30. return streamer(curCtx, curDesc, curCc, curMethod, curOpts...)
  31. }
  32. current++
  33. clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...)
  34. current--
  35. return clientStream, err
  36. }
  37. return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...)
  38. }
  39. }
  40. }
  41. func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
  42. switch len(interceptors) {
  43. case 0:
  44. return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
  45. invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
  46. return invoker(ctx, method, req, reply, cc, opts...)
  47. }
  48. case 1:
  49. return interceptors[0]
  50. default:
  51. last := len(interceptors) - 1
  52. return func(ctx context.Context, method string, req, reply interface{},
  53. cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
  54. var chainInvoker grpc.UnaryInvoker
  55. var current int
  56. chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{},
  57. curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error {
  58. if current == last {
  59. return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...)
  60. }
  61. current++
  62. err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...)
  63. current--
  64. return err
  65. }
  66. return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...)
  67. }
  68. }
  69. }