Ver Fonte

add tracing logs in server side and client side

kevin há 3 anos atrás
pai
commit
be9c48da7f

+ 18 - 13
core/logx/customlogger.go → core/logx/durationlogger.go

@@ -8,55 +8,60 @@ import (
 	"github.com/tal-tech/go-zero/core/timex"
 )
 
-const customCallerDepth = 3
+const durationCallerDepth = 3
 
-type customLog logEntry
+type durationLogger logEntry
 
 func WithDuration(d time.Duration) Logger {
-	return customLog{
+	return &durationLogger{
 		Duration: timex.ReprOfDuration(d),
 	}
 }
 
-func (l customLog) Error(v ...interface{}) {
+func (l *durationLogger) Error(v ...interface{}) {
 	if shouldLog(ErrorLevel) {
-		l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth))
+		l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth))
 	}
 }
 
-func (l customLog) Errorf(format string, v ...interface{}) {
+func (l *durationLogger) Errorf(format string, v ...interface{}) {
 	if shouldLog(ErrorLevel) {
-		l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth))
+		l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth))
 	}
 }
 
-func (l customLog) Info(v ...interface{}) {
+func (l *durationLogger) Info(v ...interface{}) {
 	if shouldLog(InfoLevel) {
 		l.write(infoLog, levelInfo, fmt.Sprint(v...))
 	}
 }
 
-func (l customLog) Infof(format string, v ...interface{}) {
+func (l *durationLogger) Infof(format string, v ...interface{}) {
 	if shouldLog(InfoLevel) {
 		l.write(infoLog, levelInfo, fmt.Sprintf(format, v...))
 	}
 }
 
-func (l customLog) Slow(v ...interface{}) {
+func (l *durationLogger) Slow(v ...interface{}) {
 	if shouldLog(ErrorLevel) {
 		l.write(slowLog, levelSlow, fmt.Sprint(v...))
 	}
 }
 
-func (l customLog) Slowf(format string, v ...interface{}) {
+func (l *durationLogger) Slowf(format string, v ...interface{}) {
 	if shouldLog(ErrorLevel) {
 		l.write(slowLog, levelSlow, fmt.Sprintf(format, v...))
 	}
 }
 
-func (l customLog) write(writer io.Writer, level, content string) {
+func (l *durationLogger) WithDuration(duration time.Duration) Logger {
+	l.Duration = timex.ReprOfDuration(duration)
+	return l
+}
+
+func (l *durationLogger) write(writer io.Writer, level, content string) {
 	l.Timestamp = getTimestamp()
 	l.Level = level
 	l.Content = content
-	outputJson(writer, logEntry(l))
+	outputJson(writer, logEntry(*l))
 }

+ 52 - 0
core/logx/durationlogger_test.go

@@ -0,0 +1,52 @@
+package logx
+
+import (
+	"log"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestWithDurationError(t *testing.T) {
+	var builder strings.Builder
+	log.SetOutput(&builder)
+	WithDuration(time.Second).Error("foo")
+	assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
+}
+
+func TestWithDurationErrorf(t *testing.T) {
+	var builder strings.Builder
+	log.SetOutput(&builder)
+	WithDuration(time.Second).Errorf("foo")
+	assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
+}
+
+func TestWithDurationInfo(t *testing.T) {
+	var builder strings.Builder
+	log.SetOutput(&builder)
+	WithDuration(time.Second).Info("foo")
+	assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
+}
+
+func TestWithDurationInfof(t *testing.T) {
+	var builder strings.Builder
+	log.SetOutput(&builder)
+	WithDuration(time.Second).Infof("foo")
+	assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
+}
+
+func TestWithDurationSlow(t *testing.T) {
+	var builder strings.Builder
+	log.SetOutput(&builder)
+	WithDuration(time.Second).Slow("foo")
+	assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
+}
+
+func TestWithDurationSlowf(t *testing.T) {
+	var builder strings.Builder
+	log.SetOutput(&builder)
+	WithDuration(time.Second).WithDuration(time.Hour).Slowf("foo")
+	assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
+}

+ 2 - 0
core/logx/logs.go

@@ -15,6 +15,7 @@ import (
 	"strings"
 	"sync"
 	"sync/atomic"
+	"time"
 
 	"github.com/tal-tech/go-zero/core/iox"
 	"github.com/tal-tech/go-zero/core/sysx"
@@ -96,6 +97,7 @@ type (
 		Infof(string, ...interface{})
 		Slow(...interface{})
 		Slowf(string, ...interface{})
+		WithDuration(time.Duration) Logger
 	}
 )
 

+ 18 - 11
core/logx/tracelog.go → core/logx/tracelogger.go

@@ -4,54 +4,61 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"time"
 
+	"github.com/tal-tech/go-zero/core/timex"
 	"github.com/tal-tech/go-zero/core/trace/tracespec"
 )
 
-type tracingEntry struct {
+type traceLogger struct {
 	logEntry
 	Trace string `json:"trace,omitempty"`
 	Span  string `json:"span,omitempty"`
 	ctx   context.Context
 }
 
-func (l tracingEntry) Error(v ...interface{}) {
+func (l *traceLogger) Error(v ...interface{}) {
 	if shouldLog(ErrorLevel) {
-		l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth))
+		l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth))
 	}
 }
 
-func (l tracingEntry) Errorf(format string, v ...interface{}) {
+func (l *traceLogger) Errorf(format string, v ...interface{}) {
 	if shouldLog(ErrorLevel) {
-		l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth))
+		l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth))
 	}
 }
 
-func (l tracingEntry) Info(v ...interface{}) {
+func (l *traceLogger) Info(v ...interface{}) {
 	if shouldLog(InfoLevel) {
 		l.write(infoLog, levelInfo, fmt.Sprint(v...))
 	}
 }
 
-func (l tracingEntry) Infof(format string, v ...interface{}) {
+func (l *traceLogger) Infof(format string, v ...interface{}) {
 	if shouldLog(InfoLevel) {
 		l.write(infoLog, levelInfo, fmt.Sprintf(format, v...))
 	}
 }
 
-func (l tracingEntry) Slow(v ...interface{}) {
+func (l *traceLogger) Slow(v ...interface{}) {
 	if shouldLog(ErrorLevel) {
 		l.write(slowLog, levelSlow, fmt.Sprint(v...))
 	}
 }
 
-func (l tracingEntry) Slowf(format string, v ...interface{}) {
+func (l *traceLogger) Slowf(format string, v ...interface{}) {
 	if shouldLog(ErrorLevel) {
 		l.write(slowLog, levelSlow, fmt.Sprintf(format, v...))
 	}
 }
 
-func (l tracingEntry) write(writer io.Writer, level, content string) {
+func (l *traceLogger) WithDuration(duration time.Duration) Logger {
+	l.Duration = timex.ReprOfDuration(duration)
+	return l
+}
+
+func (l *traceLogger) write(writer io.Writer, level, content string) {
 	l.Timestamp = getTimestamp()
 	l.Level = level
 	l.Content = content
@@ -61,7 +68,7 @@ func (l tracingEntry) write(writer io.Writer, level, content string) {
 }
 
 func WithContext(ctx context.Context) Logger {
-	return tracingEntry{
+	return &traceLogger{
 		ctx: ctx,
 	}
 }

+ 1 - 1
core/logx/tracelog_test.go → core/logx/tracelogger_test.go

@@ -19,7 +19,7 @@ var mock tracespec.Trace = new(mockTrace)
 func TestTraceLog(t *testing.T) {
 	var buf strings.Builder
 	ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
-	WithContext(ctx).(tracingEntry).write(&buf, levelInfo, testlog)
+	WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
 	assert.True(t, strings.Contains(buf.String(), mockTraceId))
 	assert.True(t, strings.Contains(buf.String(), mockSpanId))
 }

BIN
doc/images/wechat.jpg


+ 11 - 11
doc/jwt.md

@@ -1,4 +1,4 @@
-### 基于go-zero实现JWT认证
+# 基于go-zero实现JWT认证
 
 关于JWT是什么,大家可以看看[官网](https://jwt.io/),一句话介绍下:是可以实现服务器无状态的鉴权认证方案,也是目前最流行的跨域认证解决方案。
 
@@ -7,7 +7,7 @@
 * 客户端获取JWT token。
 * 服务器对客户端带来的JWT token认证。
 
-### 1.  客户端获取JWT Token
+## 1.  客户端获取JWT Token
 
 我们定义一个协议供客户端调用获取JWT token,我们新建一个目录jwt然后在目录中执行 `goctl api -o jwt.api`,将生成的jwt.api改成如下:
 
@@ -61,7 +61,11 @@ func (l *JwtLogic) Jwt(req types.JwtTokenRequest) (*types.JwtTokenResponse, erro
 		return nil, err
 	}
 
-	return &types.JwtTokenResponse{AccessToken: accessToken, AccessExpire: now + accessExpire, RefreshAfter: now + accessExpire/2}, nil
+	return &types.JwtTokenResponse{
+    AccessToken:  accessToken,
+    AccessExpire: now + accessExpire,
+    RefreshAfter: now + accessExpire/2,
+  }, nil
 }
 
 func (l *JwtLogic) GenToken(iat int64, secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
@@ -91,13 +95,11 @@ JwtAuth:
 启动服务器,然后测试下获取到的token。
 
 ```sh
-➜  jwt curl --location --request POST '127.0.0.1:8888/user/token'
+➜ curl --location --request POST '127.0.0.1:8888/user/token'
 {"access_token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDEyNjE0MjksImlhdCI6MTYwMDY1NjYyOX0.6u_hpE_4m5gcI90taJLZtvfekwUmjrbNJ-5saaDGeQc","access_expire":1601261429,"refresh_after":1600959029}
 ```
 
-
-
-### 2 服务器验证JWT token
+## 2. 服务器验证JWT token
 
 1. 在api文件中通过`jwt: JwtAuth`标记的service表示激活了jwt认证。
 2. 可以阅读rest/handler/authhandler.go文件了解服务器jwt实现。
@@ -112,7 +114,7 @@ func (l *GetUserLogic) GetUser(req types.GetUserRequest) (*types.GetUserResponse
 * 我们先不带JWT Authorization header请求头测试下,返回http status code是401,符合预期。
 
 ```sh
-➜  jwt curl -w  "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \
+➜ curl -w  "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \
 --header 'Content-Type: application/json' \
 --data-raw '{
     "userId": "a"
@@ -124,7 +126,7 @@ http: 401
 * 加上Authorization header请求头测试。
 
 ```sh
-➜  jwt curl -w  "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \
+➜ curl -w  "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \
 --header 'Authorization: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDEyNjE0MjksImlhdCI6MTYwMDY1NjYyOX0.6u_hpE_4m5gcI90taJLZtvfekwUmjrbNJ-5saaDGeQc' \
 --header 'Content-Type: application/json' \
 --data-raw '{
@@ -134,7 +136,5 @@ http: 401
 http: 200
 ```
 
-
-
 综上所述:基于go-zero的JWT认证完成,在真实生产环境部署时候,AccessSecret, AccessExpire, RefreshAfter根据业务场景通过配置文件配置,RefreshAfter 是告诉客户端什么时候该刷新JWT token了,一般都需要设置过期时间前几天。
 

+ 12 - 0
rest/httpx/requests_test.go

@@ -109,6 +109,18 @@ func TestParseRequired(t *testing.T) {
 	assert.NotNil(t, err)
 }
 
+func TestParseOptions(t *testing.T) {
+	v := struct {
+		Position int8 `form:"pos,options=1|2"`
+	}{}
+
+	r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?pos=4", nil)
+	assert.Nil(t, err)
+
+	err = Parse(r, &v)
+	assert.NotNil(t, err)
+}
+
 func BenchmarkParseRaw(b *testing.B) {
 	r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", nil)
 	if err != nil {

+ 2 - 2
zrpc/internal/client.go

@@ -66,11 +66,11 @@ func buildDialOptions(opts ...ClientOption) []grpc.DialOption {
 		grpc.WithInsecure(),
 		grpc.WithBlock(),
 		WithUnaryClientInterceptors(
-			clientinterceptors.BreakerInterceptor,
+			clientinterceptors.TracingInterceptor,
 			clientinterceptors.DurationInterceptor,
+			clientinterceptors.BreakerInterceptor,
 			clientinterceptors.PromMetricInterceptor,
 			clientinterceptors.TimeoutInterceptor(clientOptions.Timeout),
-			clientinterceptors.TracingInterceptor,
 		),
 	}
 

+ 4 - 2
zrpc/internal/clientinterceptors/durationinterceptor.go

@@ -18,11 +18,13 @@ func DurationInterceptor(ctx context.Context, method string, req, reply interfac
 	start := timex.Now()
 	err := invoker(ctx, method, req, reply, cc, opts...)
 	if err != nil {
-		logx.WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s", serverName, req, err.Error())
+		logx.WithContext(ctx).WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s",
+			serverName, req, err.Error())
 	} else {
 		elapsed := timex.Since(start)
 		if elapsed > slowThreshold {
-			logx.WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v", serverName, req, reply)
+			logx.WithContext(ctx).WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v",
+				serverName, req, reply)
 		}
 	}
 

+ 3 - 0
zrpc/internal/rpcserver.go

@@ -17,6 +17,7 @@ type (
 	}
 
 	rpcServer struct {
+		name string
 		*baseRpcServer
 	}
 )
@@ -40,6 +41,7 @@ func NewRpcServer(address string, opts ...ServerOption) Server {
 }
 
 func (s *rpcServer) SetName(name string) {
+	s.name = name
 	s.baseRpcServer.SetName(name)
 }
 
@@ -50,6 +52,7 @@ func (s *rpcServer) Start(register RegisterFn) error {
 	}
 
 	unaryInterceptors := []grpc.UnaryServerInterceptor{
+		serverinterceptors.UnaryTracingInterceptor(s.name),
 		serverinterceptors.UnaryCrashInterceptor(),
 		serverinterceptors.UnaryStatInterceptor(s.metrics),
 		serverinterceptors.UnaryPromMetricInterceptor(),

+ 4 - 3
zrpc/internal/serverinterceptors/statinterceptor.go

@@ -42,10 +42,11 @@ func logDuration(ctx context.Context, method string, req interface{}, duration t
 	}
 	content, err := json.Marshal(req)
 	if err != nil {
-		logx.Errorf("%s - %s", addr, err.Error())
+		logx.WithContext(ctx).Errorf("%s - %s", addr, err.Error())
 	} else if duration > serverSlowThreshold {
-		logx.WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s", addr, method, string(content))
+		logx.WithContext(ctx).WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s",
+			addr, method, string(content))
 	} else {
-		logx.WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content))
+		logx.WithContext(ctx).WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content))
 	}
 }

+ 0 - 2
zrpc/server.go

@@ -109,8 +109,6 @@ func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Me
 			time.Duration(c.Timeout) * time.Millisecond))
 	}
 
-	server.AddUnaryInterceptors(serverinterceptors.UnaryTracingInterceptor(c.Name))
-
 	if c.Auth {
 		authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl)
 		if err != nil {