123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- package serverinterceptors
- import (
- "context"
- "testing"
- "github.com/alicebob/miniredis"
- "github.com/stretchr/testify/assert"
- "github.com/tal-tech/go-zero/core/stores/redis"
- "github.com/tal-tech/go-zero/zrpc/internal/auth"
- "google.golang.org/grpc"
- "google.golang.org/grpc/metadata"
- )
- func TestStreamAuthorizeInterceptor(t *testing.T) {
- tests := []struct {
- name string
- app string
- token string
- strict bool
- hasError bool
- }{
- {
- name: "strict=false",
- strict: false,
- hasError: false,
- },
- {
- name: "strict=true",
- strict: true,
- hasError: true,
- },
- {
- name: "strict=true,with token",
- app: "foo",
- token: "bar",
- strict: true,
- hasError: false,
- },
- {
- name: "strict=true,with error token",
- app: "foo",
- token: "error",
- strict: true,
- hasError: true,
- },
- }
- r := miniredis.NewMiniRedis()
- assert.Nil(t, r.Start())
- defer r.Close()
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- store := redis.NewRedis(r.Addr(), redis.NodeType)
- if len(test.app) > 0 {
- assert.Nil(t, store.Hset("apps", test.app, test.token))
- defer store.Hdel("apps", test.app)
- }
- authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
- assert.Nil(t, err)
- interceptor := StreamAuthorizeInterceptor(authenticator)
- md := metadata.New(map[string]string{
- "app": "foo",
- "token": "bar",
- })
- ctx := metadata.NewIncomingContext(context.Background(), md)
- stream := mockedStream{ctx: ctx}
- err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
- return nil
- })
- if test.hasError {
- assert.NotNil(t, err)
- } else {
- assert.Nil(t, err)
- }
- })
- }
- }
- func TestUnaryAuthorizeInterceptor(t *testing.T) {
- tests := []struct {
- name string
- app string
- token string
- strict bool
- hasError bool
- }{
- {
- name: "strict=false",
- strict: false,
- hasError: false,
- },
- {
- name: "strict=true",
- strict: true,
- hasError: true,
- },
- {
- name: "strict=true,with token",
- app: "foo",
- token: "bar",
- strict: true,
- hasError: false,
- },
- {
- name: "strict=true,with error token",
- app: "foo",
- token: "error",
- strict: true,
- hasError: true,
- },
- }
- r := miniredis.NewMiniRedis()
- assert.Nil(t, r.Start())
- defer r.Close()
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- store := redis.NewRedis(r.Addr(), redis.NodeType)
- if len(test.app) > 0 {
- assert.Nil(t, store.Hset("apps", test.app, test.token))
- defer store.Hdel("apps", test.app)
- }
- authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
- assert.Nil(t, err)
- interceptor := UnaryAuthorizeInterceptor(authenticator)
- md := metadata.New(map[string]string{
- "app": "foo",
- "token": "bar",
- })
- ctx := metadata.NewIncomingContext(context.Background(), md)
- _, err = interceptor(ctx, nil, nil,
- func(ctx context.Context, req interface{}) (interface{}, error) {
- return nil, nil
- })
- if test.hasError {
- assert.NotNil(t, err)
- } else {
- assert.Nil(t, err)
- }
- if test.strict {
- _, err = interceptor(context.Background(), nil, nil,
- func(ctx context.Context, req interface{}) (interface{}, error) {
- return nil, nil
- })
- assert.NotNil(t, err)
- var md metadata.MD
- ctx := metadata.NewIncomingContext(context.Background(), md)
- _, err = interceptor(ctx, nil, nil,
- func(ctx context.Context, req interface{}) (interface{}, error) {
- return nil, nil
- })
- assert.NotNil(t, err)
- md = metadata.New(map[string]string{
- "app": "",
- "token": "",
- })
- ctx = metadata.NewIncomingContext(context.Background(), md)
- _, err = interceptor(ctx, nil, nil,
- func(ctx context.Context, req interface{}) (interface{}, error) {
- return nil, nil
- })
- assert.NotNil(t, err)
- }
- })
- }
- }
- type mockedStream struct {
- ctx context.Context
- }
- func (m mockedStream) SetHeader(md metadata.MD) error {
- return nil
- }
- func (m mockedStream) SendHeader(md metadata.MD) error {
- return nil
- }
- func (m mockedStream) SetTrailer(md metadata.MD) {
- }
- func (m mockedStream) Context() context.Context {
- return m.ctx
- }
- func (m mockedStream) SendMsg(v interface{}) error {
- return nil
- }
- func (m mockedStream) RecvMsg(v interface{}) error {
- return nil
- }
|