authinterceptor_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "testing"
  5. "github.com/alicebob/miniredis"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/tal-tech/go-zero/core/stores/redis"
  8. "github.com/tal-tech/go-zero/zrpc/internal/auth"
  9. "google.golang.org/grpc"
  10. "google.golang.org/grpc/metadata"
  11. )
  12. func TestStreamAuthorizeInterceptor(t *testing.T) {
  13. tests := []struct {
  14. name string
  15. app string
  16. token string
  17. strict bool
  18. hasError bool
  19. }{
  20. {
  21. name: "strict=false",
  22. strict: false,
  23. hasError: false,
  24. },
  25. {
  26. name: "strict=true",
  27. strict: true,
  28. hasError: true,
  29. },
  30. {
  31. name: "strict=true,with token",
  32. app: "foo",
  33. token: "bar",
  34. strict: true,
  35. hasError: false,
  36. },
  37. {
  38. name: "strict=true,with error token",
  39. app: "foo",
  40. token: "error",
  41. strict: true,
  42. hasError: true,
  43. },
  44. }
  45. r := miniredis.NewMiniRedis()
  46. assert.Nil(t, r.Start())
  47. defer r.Close()
  48. for _, test := range tests {
  49. t.Run(test.name, func(t *testing.T) {
  50. store := redis.NewRedis(r.Addr(), redis.NodeType)
  51. if len(test.app) > 0 {
  52. assert.Nil(t, store.Hset("apps", test.app, test.token))
  53. defer store.Hdel("apps", test.app)
  54. }
  55. authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
  56. assert.Nil(t, err)
  57. interceptor := StreamAuthorizeInterceptor(authenticator)
  58. md := metadata.New(map[string]string{
  59. "app": "foo",
  60. "token": "bar",
  61. })
  62. ctx := metadata.NewIncomingContext(context.Background(), md)
  63. stream := mockedStream{ctx: ctx}
  64. err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
  65. return nil
  66. })
  67. if test.hasError {
  68. assert.NotNil(t, err)
  69. } else {
  70. assert.Nil(t, err)
  71. }
  72. })
  73. }
  74. }
  75. func TestUnaryAuthorizeInterceptor(t *testing.T) {
  76. tests := []struct {
  77. name string
  78. app string
  79. token string
  80. strict bool
  81. hasError bool
  82. }{
  83. {
  84. name: "strict=false",
  85. strict: false,
  86. hasError: false,
  87. },
  88. {
  89. name: "strict=true",
  90. strict: true,
  91. hasError: true,
  92. },
  93. {
  94. name: "strict=true,with token",
  95. app: "foo",
  96. token: "bar",
  97. strict: true,
  98. hasError: false,
  99. },
  100. {
  101. name: "strict=true,with error token",
  102. app: "foo",
  103. token: "error",
  104. strict: true,
  105. hasError: true,
  106. },
  107. }
  108. r := miniredis.NewMiniRedis()
  109. assert.Nil(t, r.Start())
  110. defer r.Close()
  111. for _, test := range tests {
  112. t.Run(test.name, func(t *testing.T) {
  113. store := redis.NewRedis(r.Addr(), redis.NodeType)
  114. if len(test.app) > 0 {
  115. assert.Nil(t, store.Hset("apps", test.app, test.token))
  116. defer store.Hdel("apps", test.app)
  117. }
  118. authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
  119. assert.Nil(t, err)
  120. interceptor := UnaryAuthorizeInterceptor(authenticator)
  121. md := metadata.New(map[string]string{
  122. "app": "foo",
  123. "token": "bar",
  124. })
  125. ctx := metadata.NewIncomingContext(context.Background(), md)
  126. _, err = interceptor(ctx, nil, nil,
  127. func(ctx context.Context, req interface{}) (interface{}, error) {
  128. return nil, nil
  129. })
  130. if test.hasError {
  131. assert.NotNil(t, err)
  132. } else {
  133. assert.Nil(t, err)
  134. }
  135. if test.strict {
  136. _, err = interceptor(context.Background(), nil, nil,
  137. func(ctx context.Context, req interface{}) (interface{}, error) {
  138. return nil, nil
  139. })
  140. assert.NotNil(t, err)
  141. var md metadata.MD
  142. ctx := metadata.NewIncomingContext(context.Background(), md)
  143. _, err = interceptor(ctx, nil, nil,
  144. func(ctx context.Context, req interface{}) (interface{}, error) {
  145. return nil, nil
  146. })
  147. assert.NotNil(t, err)
  148. md = metadata.New(map[string]string{
  149. "app": "",
  150. "token": "",
  151. })
  152. ctx = metadata.NewIncomingContext(context.Background(), md)
  153. _, err = interceptor(ctx, nil, nil,
  154. func(ctx context.Context, req interface{}) (interface{}, error) {
  155. return nil, nil
  156. })
  157. assert.NotNil(t, err)
  158. }
  159. })
  160. }
  161. }
  162. type mockedStream struct {
  163. ctx context.Context
  164. }
  165. func (m mockedStream) SetHeader(md metadata.MD) error {
  166. return nil
  167. }
  168. func (m mockedStream) SendHeader(md metadata.MD) error {
  169. return nil
  170. }
  171. func (m mockedStream) SetTrailer(md metadata.MD) {
  172. }
  173. func (m mockedStream) Context() context.Context {
  174. return m.ctx
  175. }
  176. func (m mockedStream) SendMsg(v interface{}) error {
  177. return nil
  178. }
  179. func (m mockedStream) RecvMsg(v interface{}) error {
  180. return nil
  181. }