Sfoglia il codice sorgente

disable prometheus if not configured (#663)

Kevin Wan 3 anni fa
parent
commit
06eeef2cf3

+ 11 - 1
core/prometheus/agent.go

@@ -7,10 +7,19 @@ import (
 
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/tal-tech/go-zero/core/logx"
+	"github.com/tal-tech/go-zero/core/syncx"
 	"github.com/tal-tech/go-zero/core/threading"
 )
 
-var once sync.Once
+var (
+	once    sync.Once
+	enabled syncx.AtomicBool
+)
+
+// Enabled returns if prometheus is enabled.
+func Enabled() bool {
+	return enabled.True()
+}
 
 // StartAgent starts a prometheus agent.
 func StartAgent(c Config) {
@@ -19,6 +28,7 @@ func StartAgent(c Config) {
 			return
 		}
 
+		enabled.Set(true)
 		threading.GoSafe(func() {
 			http.Handle(c.Path, promhttp.Handler())
 			addr := fmt.Sprintf("%s:%d", c.Host, c.Port)

+ 5 - 0
rest/handler/prometheushandler.go

@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/tal-tech/go-zero/core/metric"
+	"github.com/tal-tech/go-zero/core/prometheus"
 	"github.com/tal-tech/go-zero/core/timex"
 	"github.com/tal-tech/go-zero/rest/internal/security"
 )
@@ -34,6 +35,10 @@ var (
 // PrometheusHandler returns a middleware that reports stats to prometheus.
 func PrometheusHandler(path string) func(http.Handler) http.Handler {
 	return func(next http.Handler) http.Handler {
+		if !prometheus.Enabled() {
+			return next
+		}
+
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			startTime := timex.Now()
 			cw := &security.WithCodeResponseWriter{Writer: w}

+ 18 - 1
rest/handler/prometheushandler_test.go

@@ -6,9 +6,26 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/prometheus"
 )
 
-func TestPromMetricHandler(t *testing.T) {
+func TestPromMetricHandler_Disabled(t *testing.T) {
+	promMetricHandler := PrometheusHandler("/user/login")
+	handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+
+	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
+	resp := httptest.NewRecorder()
+	handler.ServeHTTP(resp, req)
+	assert.Equal(t, http.StatusOK, resp.Code)
+}
+
+func TestPromMetricHandler_Enabled(t *testing.T) {
+	prometheus.StartAgent(prometheus.Config{
+		Host: "localhost",
+		Path: "/",
+	})
 	promMetricHandler := PrometheusHandler("/user/login")
 	handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusOK)

+ 5 - 0
zrpc/internal/clientinterceptors/prometheusinterceptor.go

@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/tal-tech/go-zero/core/metric"
+	"github.com/tal-tech/go-zero/core/prometheus"
 	"github.com/tal-tech/go-zero/core/timex"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/status"
@@ -35,6 +36,10 @@ var (
 // PrometheusInterceptor is an interceptor that reports to prometheus server.
 func PrometheusInterceptor(ctx context.Context, method string, req, reply interface{},
 	cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+	if !prometheus.Enabled() {
+		return invoker(ctx, method, req, reply, cc, opts...)
+	}
+
 	startTime := timex.Now()
 	err := invoker(ctx, method, req, reply, cc, opts...)
 	metricClientReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), method)

+ 19 - 6
zrpc/internal/clientinterceptors/prometheusinterceptor_test.go

@@ -6,25 +6,38 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/prometheus"
 	"google.golang.org/grpc"
 )
 
 func TestPromMetricInterceptor(t *testing.T) {
 	tests := []struct {
-		name string
-		err  error
+		name   string
+		enable bool
+		err    error
 	}{
 		{
-			name: "nil",
-			err:  nil,
+			name:   "nil",
+			enable: true,
+			err:    nil,
 		},
 		{
-			name: "with error",
-			err:  errors.New("mock"),
+			name:   "with error",
+			enable: true,
+			err:    errors.New("mock"),
+		},
+		{
+			name: "disabled",
 		},
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
+			if test.enable {
+				prometheus.StartAgent(prometheus.Config{
+					Host: "localhost",
+					Path: "/",
+				})
+			}
 			cc := new(grpc.ClientConn)
 			err := PrometheusInterceptor(context.Background(), "/foo", nil, nil, cc,
 				func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,

+ 5 - 0
zrpc/internal/serverinterceptors/prometheusinterceptor.go

@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/tal-tech/go-zero/core/metric"
+	"github.com/tal-tech/go-zero/core/prometheus"
 	"github.com/tal-tech/go-zero/core/timex"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/status"
@@ -36,6 +37,10 @@ var (
 func UnaryPrometheusInterceptor() grpc.UnaryServerInterceptor {
 	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
 		interface{}, error) {
+		if !prometheus.Enabled() {
+			return handler(ctx, req)
+		}
+
 		startTime := timex.Now()
 		resp, err := handler(ctx, req)
 		metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), info.FullMethod)

+ 16 - 1
zrpc/internal/serverinterceptors/prometheusinterceptor_test.go

@@ -5,10 +5,25 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/prometheus"
 	"google.golang.org/grpc"
 )
 
-func TestUnaryPromMetricInterceptor(t *testing.T) {
+func TestUnaryPromMetricInterceptor_Disabled(t *testing.T) {
+	interceptor := UnaryPrometheusInterceptor()
+	_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
+		FullMethod: "/",
+	}, func(ctx context.Context, req interface{}) (interface{}, error) {
+		return nil, nil
+	})
+	assert.Nil(t, err)
+}
+
+func TestUnaryPromMetricInterceptor_Enabled(t *testing.T) {
+	prometheus.StartAgent(prometheus.Config{
+		Host: "localhost",
+		Path: "/",
+	})
 	interceptor := UnaryPrometheusInterceptor()
 	_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
 		FullMethod: "/",