浏览代码

Merge pull request #6136 from heyitsanthony/fix-watcher-leak

clientv3: close watcher stream once all watchers detach
Anthony Romano 9 年之前
父节点
当前提交
88a77f30e1
共有 3 个文件被更改，包括 68 次插入和 7 次删除
  1. 24 0
      clientv3/integration/watch_test.go
  2. 18 7
      clientv3/watch.go
  3. 26 0
      integration/cluster.go

+ 24 - 0
clientv3/integration/watch_test.go

@@ -727,3 +727,27 @@ func TestWatchWithCreatedNotification(t *testing.T) {
 		t.Fatalf("expected created event, got %v", resp)
 	}
 }
+
+// TestWatchCancelOnServer ensures client watcher cancels propagate back to the server: it opens ten watches on key "a", cancels each one right after creating it, and expects the member's watcher metric to read "0".
+func TestWatchCancelOnServer(t *testing.T) {
+	cluster := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
+	defer cluster.Terminate(t)
+
+	client := cluster.RandClient()
+
+	for i := 0; i < 10; i++ {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		client.Watch(ctx, "a", clientv3.WithCreatedNotify()) // returned watch channel is discarded; only the server-side metric below is inspected
+		cancel() // cancel right away instead of waiting for the one-second timeout
+	}
+	// wait for cancels to propagate (fixed one-second sleep; the test has no other synchronization with the server)
+	time.Sleep(time.Second)
+
+	watchers, err := cluster.Members[0].Metric("etcd_debugging_mvcc_watcher_total")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if watchers != "0" { // Metric returns the raw text value, hence the string comparison
+		t.Fatalf("expected 0 watchers, got %q", watchers)
+	}
+}

+ 18 - 7
clientv3/watch.go

@@ -311,7 +311,12 @@ func (w *watcher) Close() (err error) {
 }
 
 func (w *watchGrpcStream) Close() (err error) {
-	close(w.stopc)
+	w.mu.Lock()
+	if w.stopc != nil {
+		close(w.stopc)
+		w.stopc = nil
+	}
+	w.mu.Unlock()
 	<-w.donec
 	select {
 	case err = <-w.errc:
@@ -374,11 +379,17 @@ func (w *watchGrpcStream) addStream(resp *pb.WatchResponse, pendingReq *watchReq
 
 // closeStream closes the watcher's channels, removes it from w.streams and, once no streams remain, closes w.stopc; it acquires w.mu itself, so callers must not already hold it.
 func (w *watchGrpcStream) closeStream(ws *watcherStream) {
+	w.mu.Lock()
 	// cancels request stream; subscriber receives nil channel
 	close(ws.initReq.retc)
 	// close subscriber's channel
 	close(ws.outc)
 	delete(w.streams, ws.id)
+	if len(w.streams) == 0 && w.stopc != nil { // last watcher detached and stopc has not been closed yet
+		close(w.stopc)
+		w.stopc = nil // nil records that stopc is already closed, so it is never closed twice
+	}
+	w.mu.Unlock()
 }
 
 // run is the root of the goroutines for managing a watcher client
@@ -404,6 +415,7 @@ func (w *watchGrpcStream) run() {
 
 	var pendingReq, failedReq *watchRequest
 	curReqC := w.reqc
+	stopc := w.stopc
 	cancelSet := make(map[int64]struct{})
 
 	for {
@@ -473,7 +485,7 @@ func (w *watchGrpcStream) run() {
 				failedReq = pendingReq
 			}
 			cancelSet = make(map[int64]struct{})
-		case <-w.stopc:
+		case <-stopc:
 			return
 		}
 
@@ -625,9 +637,7 @@ func (w *watchGrpcStream) serveStream(ws *watcherStream) {
 		}
 	}
 
-	w.mu.Lock()
 	w.closeStream(ws)
-	w.mu.Unlock()
 	// lazily send cancel message if events on missing id
 }
 
@@ -655,13 +665,14 @@ func (w *watchGrpcStream) resume() (ws pb.Watch_WatchClient, err error) {
 // openWatchClient retries opening a watchclient until retryConnection fails
 func (w *watchGrpcStream) openWatchClient() (ws pb.Watch_WatchClient, err error) {
 	for {
-		select {
-		case <-w.stopc:
+		w.mu.Lock()
+		stopc := w.stopc
+		w.mu.Unlock()
+		if stopc == nil {
 			if err == nil {
 				err = context.Canceled
 			}
 			return nil, err
-		default:
 		}
 		if ws, err = w.remote.Watch(w.ctx, grpc.FailFast(false)); ws != nil && err == nil {
 			break

+ 26 - 0
integration/cluster.go

@@ -708,6 +708,32 @@ func (m *member) Terminate(t *testing.T) {
 	plog.Printf("terminated %s (%s)", m.Name, m.grpcAddr)
 }
 
+// Metric gets the metric value for a member: it fetches /metrics from the member's first client URL and returns the second space-separated field of the first line starting with metricName, or "" with a nil error when no line matches.
+func (m *member) Metric(metricName string) (string, error) {
+	cfgtls := transport.TLSInfo{} // empty TLS info is passed to the transport
+	tr, err := transport.NewTimeoutTransport(cfgtls, time.Second, time.Second, time.Second)
+	if err != nil {
+		return "", err
+	}
+	cli := &http.Client{Transport: tr}
+	resp, err := cli.Get(m.ClientURLs[0].String() + "/metrics") // NOTE(review): the HTTP status code is not checked before the body is parsed
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+	b, rerr := ioutil.ReadAll(resp.Body)
+	if rerr != nil {
+		return "", rerr
+	}
+	lines := strings.Split(string(b), "\n")
+	for _, l := range lines {
+		if strings.HasPrefix(l, metricName) { // prefix match: a longer metric name sharing this prefix would match too
+			return strings.Split(l, " ")[1], nil // NOTE(review): index 1 panics if the matching line contains no space
+		}
+	}
+	return "", nil // a missing metric is reported as an empty value, not as an error
+}
+
 func MustNewHTTPClient(t *testing.T, eps []string, tls *transport.TLSInfo) client.Client {
 	cfgtls := transport.TLSInfo{}
 	if tls != nil {