sheddinginterceptor_test.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/tal-tech/go-zero/core/load"
  7. "github.com/tal-tech/go-zero/core/stat"
  8. "google.golang.org/grpc"
  9. )
  10. func TestUnarySheddingInterceptor(t *testing.T) {
  11. tests := []struct {
  12. name string
  13. allow bool
  14. handleErr error
  15. expect error
  16. }{
  17. {
  18. name: "allow",
  19. allow: true,
  20. handleErr: nil,
  21. expect: nil,
  22. },
  23. {
  24. name: "allow",
  25. allow: true,
  26. handleErr: context.DeadlineExceeded,
  27. expect: context.DeadlineExceeded,
  28. },
  29. {
  30. name: "reject",
  31. allow: false,
  32. handleErr: nil,
  33. expect: load.ErrServiceOverloaded,
  34. },
  35. }
  36. for _, test := range tests {
  37. test := test
  38. t.Run(test.name, func(t *testing.T) {
  39. t.Parallel()
  40. shedder := mockedShedder{allow: test.allow}
  41. metrics := stat.NewMetrics("mock")
  42. interceptor := UnarySheddingInterceptor(shedder, metrics)
  43. _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
  44. FullMethod: "/",
  45. }, func(ctx context.Context, req interface{}) (interface{}, error) {
  46. return nil, test.handleErr
  47. })
  48. assert.Equal(t, test.expect, err)
  49. })
  50. }
  51. }
  52. type mockedShedder struct {
  53. allow bool
  54. }
  55. func (m mockedShedder) Allow() (load.Promise, error) {
  56. if m.allow {
  57. return mockedPromise{}, nil
  58. } else {
  59. return nil, load.ErrServiceOverloaded
  60. }
  61. }
  // mockedPromise is the promise handed out by mockedShedder.Allow when a
// request is admitted; it discards both outcome reports.
type mockedPromise struct{}

// Pass does nothing.
func (mockedPromise) Pass() {}

// Fail does nothing.
func (mockedPromise) Fail() {}