authinterceptor_test.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/tal-tech/go-zero/core/stores/redis/redistest"
  7. "github.com/tal-tech/go-zero/zrpc/internal/auth"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/metadata"
  10. )
  11. func TestStreamAuthorizeInterceptor(t *testing.T) {
  12. tests := []struct {
  13. name string
  14. app string
  15. token string
  16. strict bool
  17. hasError bool
  18. }{
  19. {
  20. name: "strict=false",
  21. strict: false,
  22. hasError: false,
  23. },
  24. {
  25. name: "strict=true",
  26. strict: true,
  27. hasError: true,
  28. },
  29. {
  30. name: "strict=true,with token",
  31. app: "foo",
  32. token: "bar",
  33. strict: true,
  34. hasError: false,
  35. },
  36. {
  37. name: "strict=true,with error token",
  38. app: "foo",
  39. token: "error",
  40. strict: true,
  41. hasError: true,
  42. },
  43. }
  44. store, clean, err := redistest.CreateRedis()
  45. assert.Nil(t, err)
  46. defer clean()
  47. for _, test := range tests {
  48. t.Run(test.name, func(t *testing.T) {
  49. if len(test.app) > 0 {
  50. assert.Nil(t, store.Hset("apps", test.app, test.token))
  51. defer store.Hdel("apps", test.app)
  52. }
  53. authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
  54. assert.Nil(t, err)
  55. interceptor := StreamAuthorizeInterceptor(authenticator)
  56. md := metadata.New(map[string]string{
  57. "app": "foo",
  58. "token": "bar",
  59. })
  60. ctx := metadata.NewIncomingContext(context.Background(), md)
  61. stream := mockedStream{ctx: ctx}
  62. err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
  63. return nil
  64. })
  65. if test.hasError {
  66. assert.NotNil(t, err)
  67. } else {
  68. assert.Nil(t, err)
  69. }
  70. })
  71. }
  72. }
  73. func TestUnaryAuthorizeInterceptor(t *testing.T) {
  74. tests := []struct {
  75. name string
  76. app string
  77. token string
  78. strict bool
  79. hasError bool
  80. }{
  81. {
  82. name: "strict=false",
  83. strict: false,
  84. hasError: false,
  85. },
  86. {
  87. name: "strict=true",
  88. strict: true,
  89. hasError: true,
  90. },
  91. {
  92. name: "strict=true,with token",
  93. app: "foo",
  94. token: "bar",
  95. strict: true,
  96. hasError: false,
  97. },
  98. {
  99. name: "strict=true,with error token",
  100. app: "foo",
  101. token: "error",
  102. strict: true,
  103. hasError: true,
  104. },
  105. }
  106. store, clean, err := redistest.CreateRedis()
  107. assert.Nil(t, err)
  108. defer clean()
  109. for _, test := range tests {
  110. t.Run(test.name, func(t *testing.T) {
  111. if len(test.app) > 0 {
  112. assert.Nil(t, store.Hset("apps", test.app, test.token))
  113. defer store.Hdel("apps", test.app)
  114. }
  115. authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
  116. assert.Nil(t, err)
  117. interceptor := UnaryAuthorizeInterceptor(authenticator)
  118. md := metadata.New(map[string]string{
  119. "app": "foo",
  120. "token": "bar",
  121. })
  122. ctx := metadata.NewIncomingContext(context.Background(), md)
  123. _, err = interceptor(ctx, nil, nil,
  124. func(ctx context.Context, req interface{}) (interface{}, error) {
  125. return nil, nil
  126. })
  127. if test.hasError {
  128. assert.NotNil(t, err)
  129. } else {
  130. assert.Nil(t, err)
  131. }
  132. if test.strict {
  133. _, err = interceptor(context.Background(), nil, nil,
  134. func(ctx context.Context, req interface{}) (interface{}, error) {
  135. return nil, nil
  136. })
  137. assert.NotNil(t, err)
  138. var md metadata.MD
  139. ctx := metadata.NewIncomingContext(context.Background(), md)
  140. _, err = interceptor(ctx, nil, nil,
  141. func(ctx context.Context, req interface{}) (interface{}, error) {
  142. return nil, nil
  143. })
  144. assert.NotNil(t, err)
  145. md = metadata.New(map[string]string{
  146. "app": "",
  147. "token": "",
  148. })
  149. ctx = metadata.NewIncomingContext(context.Background(), md)
  150. _, err = interceptor(ctx, nil, nil,
  151. func(ctx context.Context, req interface{}) (interface{}, error) {
  152. return nil, nil
  153. })
  154. assert.NotNil(t, err)
  155. }
  156. })
  157. }
  158. }
  159. type mockedStream struct {
  160. ctx context.Context
  161. }
  162. func (m mockedStream) SetHeader(md metadata.MD) error {
  163. return nil
  164. }
  165. func (m mockedStream) SendHeader(md metadata.MD) error {
  166. return nil
  167. }
  168. func (m mockedStream) SetTrailer(md metadata.MD) {
  169. }
  170. func (m mockedStream) Context() context.Context {
  171. return m.ctx
  172. }
  173. func (m mockedStream) SendMsg(v interface{}) error {
  174. return nil
  175. }
  176. func (m mockedStream) RecvMsg(v interface{}) error {
  177. return nil
  178. }