Переглянути джерело

Merge pull request #5510 from gyuho/clientv3_fix

clientv3: watch resp with error when client close
Gyu-Ho Lee 9 роки тому
батько
коміт
47ef5f7ca5
2 змінених файлів з 66 додано та 1 видалено
  1. 58 0
      clientv3/integration/watch_test.go
  2. 8 1
      clientv3/watch.go

+ 58 - 0
clientv3/integration/watch_test.go

@@ -558,3 +558,61 @@ func TestWatchEventType(t *testing.T) {
 		}
 	}
 }
+
+func TestWatchErrConnClosed(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+
+	cli := clus.Client(0)
+	wc := clientv3.NewWatcher(cli)
+
+	donec := make(chan struct{})
+	go func() {
+		defer close(donec)
+		wc.Watch(context.TODO(), "foo")
+		if err := wc.Close(); err != nil && err != rpctypes.ErrConnClosed {
+			t.Fatalf("expected %v, got %v", rpctypes.ErrConnClosed, err)
+		}
+	}()
+
+	if err := cli.Close(); err != nil {
+		t.Fatal(err)
+	}
+	clus.TakeClient(0)
+
+	select {
+	case <-time.After(3 * time.Second):
+		t.Fatal("wc.Watch took too long")
+	case <-donec:
+	}
+}
+
+func TestWatchAfterClose(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+
+	cli := clus.Client(0)
+	clus.TakeClient(0)
+	if err := cli.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	donec := make(chan struct{})
+	go func() {
+		wc := clientv3.NewWatcher(cli)
+		wc.Watch(context.TODO(), "foo")
+		if err := wc.Close(); err != nil && err != rpctypes.ErrConnClosed {
+			t.Fatalf("expected %v, got %v", rpctypes.ErrConnClosed, err)
+		}
+		close(donec)
+	}()
+	select {
+	case <-time.After(3 * time.Second):
+		t.Fatal("wc.Watch took too long")
+	case <-donec:
+	}
+}

+ 8 - 1
clientv3/watch.go

@@ -500,20 +500,27 @@ func (w *watcher) resume() (ws pb.Watch_WatchClient, err error) {
 // openWatchClient retries opening a watchclient until retryConnection fails
 func (w *watcher) openWatchClient() (ws pb.Watch_WatchClient, err error) {
 	for {
+		if err = w.rc.acquire(w.ctx); err != nil {
+			return nil, err
+		}
+
 		select {
 		case <-w.stopc:
 			if err == nil {
 				err = context.Canceled
 			}
+			w.rc.release()
 			return nil, err
 		default:
 		}
 		if ws, err = w.remote.Watch(w.ctx); ws != nil && err == nil {
+			w.rc.release()
 			break
 		} else if isHaltErr(w.ctx, err) {
+			w.rc.release()
 			return nil, v3rpc.Error(err)
 		}
-		err = w.rc.reconnectWait(w.ctx, nil)
+		w.rc.release()
 	}
 	return ws, nil
 }