自定义拦截器.md 3.4 KB

自定义拦截器

zRPC提供自定义拦截器功能以满足不同的业务需求

客户端拦截器定义

客户端拦截器需要实现 grpc.UnaryClientInterceptor,定义如下:

type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error

在创建客户端的时候通过zrpc.WithUnaryClientInterceptor进行注册

服务端拦截器定义

服务端拦截器需要实现grpc.UnaryServerInterceptor,定义如下:

type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)

通过RpcServer.AddUnaryInterceptors进行注册

自定义拦截器示例

自定义客户端和服务端拦截器,客户端拦截器输出请求方法耗时,服务端拦截器进行简单的限流

客户端代码

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"test/zrpc/pb"

	"github.com/tal-tech/go-zero/core/discov"
	"github.com/tal-tech/go-zero/zrpc"
	"google.golang.org/grpc"
)

func timeInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	stime := time.Now()
	err := invoker(ctx, method, req, reply, cc, opts...)
	if err != nil {
		return err
	}

	fmt.Printf("调用 %s 方法 耗时: %v\n", method, time.Now().Sub(stime))
	return nil
}

func main() {
	client := zrpc.MustNewClient(zrpc.RpcClientConf{
		Etcd: discov.EtcdConf{
			Hosts: []string{"127.0.0.1:2379"},
			Key:   "hello.rpc",
		},
	}, zrpc.WithUnaryClientInterceptor(timeInterceptor))

	hello := pb.NewGreeterClient(client.Conn())

	var count int
	for {
		reply, err := hello.SayHello(context.Background(), &pb.HelloRequest{Name: fmt.Sprintf("Hanmeimei%d", count)})
		if err != nil {
			log.Fatal(err)
		}
		count++
		log.Println(reply.Message)
		time.Sleep(time.Millisecond * 100)
	}
}

服务端代码

Name: hello.rpc
Log:
  Mode: console
ListenOn: 127.0.0.1:9090
Etcd:
  Hosts:
    - 127.0.0.1:2379
  Key: hello.rpc
package main

import (
	"context"
	"flag"
	"fmt"
	"log"
	"time"

	"test/zrpc/pb"

	"github.com/tal-tech/go-zero/core/conf"
	"github.com/tal-tech/go-zero/zrpc"
    "golang.org/x/time/rate"
	"google.golang.org/grpc"
)

type Config struct {
	zrpc.RpcServerConf
}

var cfgFile = flag.String("f", "./server.yaml", "cfg file")

var limiter = rate.NewLimiter(rate.Limit(100), 100)

func main() {
	flag.Parse()

	var cfg Config
	conf.MustLoad(*cfgFile, &cfg)

	srv := zrpc.MustNewServer(cfg.RpcServerConf, func(s *grpc.Server) {
		pb.RegisterGreeterServer(s, &Hello{})
	})
	srv.AddUnaryInterceptors(rateLimitInterceptor)
	srv.Start()
}

type Hello struct{}

func (h *Hello) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
	time.Sleep(time.Millisecond * 300)
	return &pb.HelloReply{Message: "hello " + in.Name}, nil
}

func rateLimitInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	if !limiter.Allow() {
		fmt.Println("限流了")
		return nil, nil
	}
	return handler(ctx, req)
}

运行代码可以看到客户端输出如下信息:

调用 /pb.Greeter/SayHello 方法 耗时: 300.633257ms

当请求速率加快服务端会输出:

限流了

以上输出说明自定义的拦截器已经可以正常的工作了