Browse Source

v3rpc: make watcher wait for its send goroutine to finish

Anthony Romano 9 years ago
parent
commit
c438310634
1 changed files with 21 additions and 10 deletions
  1. 21 10
      etcdserver/api/v3rpc/watch.go

+ 21 - 10
etcdserver/api/v3rpc/watch.go

@@ -94,9 +94,12 @@ type serverWatchStream struct {
 
 
 	// closec indicates the stream is closed.
 	// closec indicates the stream is closed.
 	closec chan struct{}
 	closec chan struct{}
+
+	// wg waits for the send loop to complete
+	wg sync.WaitGroup
 }
 }
 
 
-func (ws *watchServer) Watch(stream pb.Watch_WatchServer) error {
+func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
 	sws := serverWatchStream{
 	sws := serverWatchStream{
 		clusterID:   ws.clusterID,
 		clusterID:   ws.clusterID,
 		memberID:    ws.memberID,
 		memberID:    ws.memberID,
@@ -109,23 +112,30 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) error {
 		closec:     make(chan struct{}),
 		closec:     make(chan struct{}),
 	}
 	}
 
 
-	go sws.sendLoop()
-	errc := make(chan error, 1)
+	sws.wg.Add(1)
 	go func() {
 	go func() {
-		errc <- sws.recvLoop()
-		sws.close()
+		sws.sendLoop()
+		sws.wg.Done()
 	}()
 	}()
+
+	errc := make(chan error, 1)
+	// Ideally recvLoop would also use sws.wg to signal its completion
+	// but when stream.Context().Done() is closed, the stream's recv
+	// may continue to block since it uses a different context, leading to
+	// deadlock when calling sws.close().
+	go func() { errc <- sws.recvLoop() }()
+
 	select {
 	select {
-	case err := <-errc:
-		return err
+	case err = <-errc:
 	case <-stream.Context().Done():
 	case <-stream.Context().Done():
-		err := stream.Context().Err()
+		err = stream.Context().Err()
 		// the only server-side cancellation is noleader for now.
 		// the only server-side cancellation is noleader for now.
 		if err == context.Canceled {
 		if err == context.Canceled {
-			return rpctypes.ErrGRPCNoLeader
+			err = rpctypes.ErrGRPCNoLeader
 		}
 		}
-		return err
 	}
 	}
+	sws.close()
+	return err
 }
 }
 
 
 func (sws *serverWatchStream) recvLoop() error {
 func (sws *serverWatchStream) recvLoop() error {
@@ -292,6 +302,7 @@ func (sws *serverWatchStream) close() {
 	sws.watchStream.Close()
 	sws.watchStream.Close()
 	close(sws.closec)
 	close(sws.closec)
 	close(sws.ctrlStream)
 	close(sws.ctrlStream)
+	sws.wg.Wait()
 }
 }
 
 
 func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
 func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {