# 自定义拦截器 zRPC提供自定义拦截器功能以满足不同的业务需求 ### 客户端拦截器定义 客户端拦截器需要实现 grpc.UnaryClientInterceptor,定义如下: ```go type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error ``` 在创建客户端的时候通过zrpc.WithUnaryClientInterceptor进行注册 ### 服务端拦截器定义 服务端拦截器需要实现grpc.UnaryServerInterceptor,定义如下: ```go type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) ``` 通过RpcServer.AddUnaryInterceptors进行注册 ### 自定义拦截器示例 自定义客户端和服务端拦截器,客户端拦截器输出请求方法耗时,服务端拦截器进行简单的限流 #### 客户端代码 ```go package main import ( "context" "fmt" "log" "time" "test/zrpc/pb" "github.com/tal-tech/go-zero/core/discov" "github.com/tal-tech/go-zero/zrpc" "google.golang.org/grpc" ) func timeInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { stime := time.Now() err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { return err } fmt.Printf("调用 %s 方法 耗时: %v\n", method, time.Now().Sub(stime)) return nil } func main() { client := zrpc.MustNewClient(zrpc.RpcClientConf{ Etcd: discov.EtcdConf{ Hosts: []string{"127.0.0.1:2379"}, Key: "hello.rpc", }, }, zrpc.WithUnaryClientInterceptor(timeInterceptor)) hello := pb.NewGreeterClient(client.Conn()) var count int for { reply, err := hello.SayHello(context.Background(), &pb.HelloRequest{Name: fmt.Sprintf("Hanmeimei%d", count)}) if err != nil { log.Fatal(err) } count++ log.Println(reply.Message) time.Sleep(time.Millisecond * 100) } } ``` #### 服务端代码 ```yaml Name: hello.rpc Log: Mode: console ListenOn: 127.0.0.1:9090 Etcd: Hosts: - 127.0.0.1:2379 Key: hello.rpc ``` ```go package main import ( "context" "flag" "fmt" "log" "time" "test/zrpc/pb" "github.com/tal-tech/go-zero/core/conf" "github.com/tal-tech/go-zero/zrpc" "golang.org/x/time/rate" "google.golang.org/grpc" ) type Config struct { zrpc.RpcServerConf } var cfgFile = flag.String("f", "./server.yaml", "cfg file") var limiter = rate.NewLimiter(rate.Limit(100), 100) func main() { flag.Parse() var cfg Config conf.MustLoad(*cfgFile, &cfg) srv := zrpc.MustNewServer(cfg.RpcServerConf, func(s *grpc.Server) { pb.RegisterGreeterServer(s, &Hello{}) }) srv.AddUnaryInterceptors(rateLimitInterceptor) srv.Start() } type Hello struct{} func (h *Hello) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { time.Sleep(time.Millisecond * 300) return &pb.HelloReply{Message: "hello " + in.Name}, nil } func rateLimitInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { if !limiter.Allow() { fmt.Println("限流了") return nil, nil } return handler(ctx, req) } ``` 运行代码可以看到客户端输出如下信息: ```go 调用 /pb.Greeter/SayHello 方法 耗时: 300.633257ms ``` 当请求速率加快服务端会输出: ```go 限流了 ``` 以上输出说明自定义的拦截器已经可以正常的工作了