
Merge pull request #6816 from heyitsanthony/fix-disconn-cancel

clientv3: let watchers cancel when reconnecting
Anthony Romano
commit 510676fea9
2 changed files with 89 additions and 28 deletions:

	clientv3/integration/watch_test.go  +19 -0
	clientv3/watch.go                    +70 -28
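
Context for the change: a watch is canceled through its context, but before this commit a cancelation issued while the grpc stream was disconnected could stay blocked until the stream reconnected. Below is a minimal sketch of the user-facing pattern the fix unblocks; the endpoint address and timings are placeholders, not part of this commit:

	package main

	import (
		"fmt"
		"time"

		"github.com/coreos/etcd/clientv3"
		"golang.org/x/net/context"
	)

	func main() {
		// Placeholder endpoint; point this at a real cluster.
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   []string{"127.0.0.1:2379"},
			DialTimeout: 5 * time.Second,
		})
		if err != nil {
			panic(err)
		}
		defer cli.Close()

		ctx, cancel := context.WithCancel(context.Background())
		wch := cli.Watch(ctx, "abc")

		// Suppose the endpoint goes down here and the stream starts
		// reconnecting; before this fix the watcher could stay blocked.
		cancel()

		select {
		case _, ok := <-wch:
			fmt.Println("watch channel closed:", !ok)
		case <-time.After(time.Second):
			fmt.Println("cancel did not take effect promptly")
		}
	}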

clientv3/integration/watch_test.go  (+19 -0)

@@ -961,3 +961,22 @@ func TestWatchStressResumeClose(t *testing.T) {
 	}
 	clus.TakeClient(0)
 }
+
+// TestWatchCancelDisconnected ensures canceling a watcher works when
+// its grpc stream is disconnected / reconnecting.
+func TestWatchCancelDisconnected(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+	cli := clus.Client(0)
+	ctx, cancel := context.WithCancel(context.Background())
+	// add more watches than can be resumed before the cancel
+	wch := cli.Watch(ctx, "abc")
+	clus.Members[0].Stop(t)
+	cancel()
+	select {
+	case <-wch:
+	case <-time.After(time.Second):
+		t.Fatal("took too long to cancel disconnected watcher")
+	}
+}

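The core of the watch.go change below is a stop/done rendezvous: while openWatchClient blocks on a reconnect, the new waitCancelSubstreams helper runs one goroutine per resuming substream that races the watcher's context against a stop signal, so a cancelation is honored immediately instead of after the dial returns. A distilled, self-contained sketch of that pattern; the names here are illustrative, not etcd's:

	package main

	import (
		"fmt"
		"sync"
		"time"

		"golang.org/x/net/context"
	)

	// raceCancel starts one monitor per pending context; each monitor
	// exits when its context is canceled or when stopc closes. The
	// returned channel closes once every monitor has exited.
	func raceCancel(ctxs []context.Context, stopc <-chan struct{}) <-chan struct{} {
		var wg sync.WaitGroup
		wg.Add(len(ctxs))
		donec := make(chan struct{})
		for _, ctx := range ctxs {
			go func(ctx context.Context) {
				defer wg.Done()
				select {
				case <-ctx.Done():
					fmt.Println("canceled while reconnecting")
				case <-stopc:
				}
			}(ctx)
		}
		go func() {
			defer close(donec)
			wg.Wait()
		}()
		return donec
	}

	func main() {
		ctx, cancel := context.WithCancel(context.Background())
		stopc := make(chan struct{})
		donec := raceCancel([]context.Context{ctx}, stopc)

		cancel()                          // user cancels mid-reconnect
		time.Sleep(50 * time.Millisecond) // stand-in for the blocking dial
		close(stopc)                      // dial returned; retire the monitors
		<-donec                           // all monitors have drained
	}
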
clientv3/watch.go  (+70 -28)

@@ -126,8 +126,6 @@ type watchGrpcStream struct {
 	reqc chan *watchRequest
 	// respc receives data from the watch client
 	respc chan *pb.WatchResponse
-	// stopc is sent to the main goroutine to stop all processing
-	stopc chan struct{}
 	// donec closes to broadcast shutdown
 	donec chan struct{}
 	// errc transmits errors from grpc Recv to the watch stream reconn logic
@@ -213,7 +211,6 @@ func (w *watcher) newWatcherGrpcStream(inctx context.Context) *watchGrpcStream {
 
 		respc:    make(chan *pb.WatchResponse),
 		reqc:     make(chan *watchRequest),
-		stopc:    make(chan struct{}),
 		donec:    make(chan struct{}),
 		errc:     make(chan error, 1),
 		closingc: make(chan *watcherStream),
@@ -319,7 +316,7 @@ func (w *watcher) Close() (err error) {
 }
 
 func (w *watchGrpcStream) Close() (err error) {
-	close(w.stopc)
+	w.cancel()
 	<-w.donec
 	select {
 	case err = <-w.errc:
@@ -366,7 +363,7 @@ func (w *watchGrpcStream) closeSubstream(ws *watcherStream) {
 	// close subscriber's channel
 	if closeErr := w.closeErr; closeErr != nil && ws.initReq.ctx.Err() == nil {
 		go w.sendCloseSubstream(ws, &WatchResponse{closeErr: w.closeErr})
-	} else {
+	} else if ws.outc != nil {
 		close(ws.outc)
 	}
 	if ws.id != -1 {
@@ -493,7 +490,7 @@ func (w *watchGrpcStream) run() {
 				wc.Send(ws.initReq.toPB())
 			}
 			cancelSet = make(map[int64]struct{})
-		case <-w.stopc:
+		case <-w.ctx.Done():
 			return
 		case ws := <-w.closingc:
 			w.closeSubstream(ws)
@@ -632,7 +629,8 @@ func (w *watchGrpcStream) serveSubstream(ws *watcherStream, resumec chan struct{
 
 			// TODO pause channel if buffer gets too large
 			ws.buf = append(ws.buf, wr)
-
+		case <-w.ctx.Done():
+			return
 		case <-ws.initReq.ctx.Done():
 			return
 		case <-resumec:
@@ -644,34 +642,78 @@ func (w *watchGrpcStream) serveSubstream(ws *watcherStream, resumec chan struct{
 }
 
 func (w *watchGrpcStream) newWatchClient() (pb.Watch_WatchClient, error) {
-	// connect to grpc stream
-	wc, err := w.openWatchClient()
-	if err != nil {
-		return nil, v3rpc.Error(err)
-	}
 	// mark all substreams as resuming
-	if len(w.substreams)+len(w.resuming) > 0 {
-		close(w.resumec)
-		w.resumec = make(chan struct{})
-		w.joinSubstreams()
-		for _, ws := range w.substreams {
-			ws.id = -1
-			w.resuming = append(w.resuming, ws)
-		}
-		for _, ws := range w.resuming {
-			if ws == nil || ws.closing {
-				continue
-			}
-			ws.donec = make(chan struct{})
-			go w.serveSubstream(ws, w.resumec)
+	close(w.resumec)
+	w.resumec = make(chan struct{})
+	w.joinSubstreams()
+	for _, ws := range w.substreams {
+		ws.id = -1
+		w.resuming = append(w.resuming, ws)
+	}
+	// strip out nils, if any
+	var resuming []*watcherStream
+	for _, ws := range w.resuming {
+		if ws != nil {
+			resuming = append(resuming, ws)
 		}
 	}
+	w.resuming = resuming
 	w.substreams = make(map[int64]*watcherStream)
+
+	// connect to grpc stream while accepting watcher cancelation
+	stopc := make(chan struct{})
+	donec := w.waitCancelSubstreams(stopc)
+	wc, err := w.openWatchClient()
+	close(stopc)
+	<-donec
+
+	// serve all non-closing streams, even if there's a client error
+	// so that the teardown path can shutdown the streams as expected.
+	for _, ws := range w.resuming {
+		if ws.closing {
+			continue
+		}
+		ws.donec = make(chan struct{})
+		go w.serveSubstream(ws, w.resumec)
+	}
+
+	if err != nil {
+		return nil, v3rpc.Error(err)
+	}
+
 	// receive data from new grpc stream
 	go w.serveWatchClient(wc)
 	return wc, nil
 }
 
+func (w *watchGrpcStream) waitCancelSubstreams(stopc <-chan struct{}) <-chan struct{} {
+	var wg sync.WaitGroup
+	wg.Add(len(w.resuming))
+	donec := make(chan struct{})
+	for i := range w.resuming {
+		go func(ws *watcherStream) {
+			defer wg.Done()
+			if ws.closing {
+				return
+			}
+			select {
+			case <-ws.initReq.ctx.Done():
+				// closed ws will be removed from resuming
+				ws.closing = true
+				close(ws.outc)
+				ws.outc = nil
+				go func() { w.closingc <- ws }()
+			case <-stopc:
+			}
+		}(w.resuming[i])
+	}
+	go func() {
+		defer close(donec)
+		wg.Wait()
+	}()
+	return donec
+}
+
 // joinSubstreams waits for all substream goroutines to complete
 func (w *watchGrpcStream) joinSubstreams() {
 	for _, ws := range w.substreams {
@@ -688,9 +730,9 @@ func (w *watchGrpcStream) joinSubstreams() {
 func (w *watchGrpcStream) openWatchClient() (ws pb.Watch_WatchClient, err error) {
 	for {
 		select {
-		case <-w.stopc:
+		case <-w.ctx.Done():
 			if err == nil {
-				return nil, context.Canceled
+				return nil, w.ctx.Err()
 			}
 			return nil, err
 		default:
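
A pattern note on the hunks above: the dedicated stopc channel is removed entirely; Close() now calls w.cancel(), and every select that used to wait on stopc waits on w.ctx.Done() instead, so a single context tears down the run loop, each substream, and the reconnect loop alike. A generic sketch of that consolidation follows; the type and names are hypothetical, not etcd's:

	package main

	import (
		"fmt"

		"golang.org/x/net/context"
	)

	// stream ties every shutdown path to one context, as watch.go does
	// after this commit; the fields here are hypothetical.
	type stream struct {
		ctx    context.Context
		cancel context.CancelFunc
		donec  chan struct{}
	}

	func newStream() *stream {
		ctx, cancel := context.WithCancel(context.Background())
		s := &stream{ctx: ctx, cancel: cancel, donec: make(chan struct{})}
		go s.run()
		return s
	}

	func (s *stream) run() {
		defer close(s.donec)
		for {
			select {
			// other cases would handle requests and responses
			case <-s.ctx.Done(): // replaces the old <-stopc
				return
			}
		}
	}

	// Close mirrors watchGrpcStream.Close: cancel, then wait for run.
	func (s *stream) Close() {
		s.cancel()
		<-s.donec
		fmt.Println("stream torn down")
	}

	func main() {
		newStream().Close()
	}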