Browse Source

zrpc timeout & unit tests (#573)

* zrpc timeout & unit tests
Kevin Wan 3 years ago
parent
commit
4884a7b3c6

+ 22 - 1
zrpc/internal/clientinterceptors/timeoutinterceptor.go

@@ -18,6 +18,27 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
 
 		ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
 		defer cancel()
-		return invoker(ctx, method, req, reply, cc, opts...)
+
+		done := make(chan error)
+		panicChan := make(chan interface{}, 1)
+		go func() {
+			defer func() {
+				if p := recover(); p != nil {
+					panicChan <- p
+				}
+			}()
+
+			done <- invoker(ctx, method, req, reply, cc, opts...)
+			close(done)
+		}()
+
+		select {
+		case p := <-panicChan:
+			panic(p)
+		case err := <-done:
+			return err
+		case <-ctx.Done():
+			return ctx.Err()
+		}
 	}
 }

+ 37 - 0
zrpc/internal/clientinterceptors/timeoutinterceptor_test.go

@@ -48,3 +48,40 @@ func TestTimeoutInterceptor_timeout(t *testing.T) {
 	wg.Wait()
 	assert.Nil(t, err)
 }
+
+func TestTimeoutInterceptor_timeoutExpire(t *testing.T) {
+	const timeout = time.Millisecond * 10
+	interceptor := TimeoutInterceptor(timeout)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	var wg sync.WaitGroup
+	wg.Add(1)
+	cc := new(grpc.ClientConn)
+	err := interceptor(ctx, "/foo", nil, nil, cc,
+		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
+			opts ...grpc.CallOption) error {
+			defer wg.Done()
+			time.Sleep(time.Millisecond * 50)
+			return nil
+		})
+	wg.Wait()
+	assert.Equal(t, context.DeadlineExceeded, err)
+}
+
+func TestTimeoutInterceptor_panic(t *testing.T) {
+	timeouts := []time.Duration{0, time.Millisecond * 10}
+	for _, timeout := range timeouts {
+		t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
+			interceptor := TimeoutInterceptor(timeout)
+			cc := new(grpc.ClientConn)
+			assert.Panics(t, func() {
+				_ = interceptor(context.Background(), "/foo", nil, nil, cc,
+					func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
+						opts ...grpc.CallOption) error {
+						panic("any")
+					},
+				)
+			})
+		})
+	}
+}

+ 31 - 2
zrpc/internal/serverinterceptors/timeoutinterceptor.go

@@ -2,6 +2,7 @@ package serverinterceptors
 
 import (
 	"context"
+	"sync"
 	"time"
 
 	"github.com/tal-tech/go-zero/core/contextx"
@@ -11,9 +12,37 @@ import (
 // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
 func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
 	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
-		handler grpc.UnaryHandler) (resp interface{}, err error) {
+		handler grpc.UnaryHandler) (interface{}, error) {
 		ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
 		defer cancel()
-		return handler(ctx, req)
+
+		var resp interface{}
+		var err error
+		var lock sync.Mutex
+		done := make(chan struct{})
+		panicChan := make(chan interface{}, 1)
+		go func() {
+			defer func() {
+				if p := recover(); p != nil {
+					panicChan <- p
+				}
+			}()
+
+			lock.Lock()
+			defer lock.Unlock()
+			resp, err = handler(ctx, req)
+			close(done)
+		}()
+
+		select {
+		case p := <-panicChan:
+			panic(p)
+		case <-done:
+			lock.Lock()
+			defer lock.Unlock()
+			return resp, err
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
 	}
 }

+ 29 - 0
zrpc/internal/serverinterceptors/timeoutinterceptor_test.go

@@ -20,6 +20,17 @@ func TestUnaryTimeoutInterceptor(t *testing.T) {
 	assert.Nil(t, err)
 }
 
+func TestUnaryTimeoutInterceptor_panic(t *testing.T) {
+	interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
+	assert.Panics(t, func() {
+		_, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
+			FullMethod: "/",
+		}, func(ctx context.Context, req interface{}) (interface{}, error) {
+			panic("any")
+		})
+	})
+}
+
 func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
 	const timeout = time.Millisecond * 10
 	interceptor := UnaryTimeoutInterceptor(timeout)
@@ -39,3 +50,21 @@ func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
 	wg.Wait()
 	assert.Nil(t, err)
 }
+
+func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
+	const timeout = time.Millisecond * 10
+	interceptor := UnaryTimeoutInterceptor(timeout)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	var wg sync.WaitGroup
+	wg.Add(1)
+	_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
+		FullMethod: "/",
+	}, func(ctx context.Context, req interface{}) (interface{}, error) {
+		defer wg.Done()
+		time.Sleep(time.Millisecond * 50)
+		return nil, nil
+	})
+	wg.Wait()
+	assert.Equal(t, context.DeadlineExceeded, err)
+}