Browse Source

grpcproxy: shut down watcher proxy when client context is done

Anthony Romano 9 years ago
parent
commit
d3ecebd14e
2 changed files with 25 additions and 12 deletions
  1. 21 11
      proxy/grpcproxy/watch.go
  2. 4 1
      proxy/grpcproxy/watcher_groups.go

+ 21 - 11
proxy/grpcproxy/watch.go

@@ -31,16 +31,25 @@ type watchProxy struct {
 
 	mu           sync.Mutex
 	nextStreamID int64
+
+	ctx context.Context
 }
 
 func NewWatchProxy(c *clientv3.Client) pb.WatchServer {
-	return &watchProxy{
+	wp := &watchProxy{
 		cw: c.Watcher,
 		wgs: watchergroups{
-			cw:     c.Watcher,
-			groups: make(map[watchRange]*watcherGroup),
+			cw:       c.Watcher,
+			groups:   make(map[watchRange]*watcherGroup),
+			proxyCtx: c.Ctx(),
 		},
+		ctx: c.Ctx(),
 	}
+	go func() {
+		<-wp.ctx.Done()
+		wp.wgs.stop()
+	}()
+	return wp
 }
 
 func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
@@ -58,13 +67,13 @@ func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
 
 		ctrlCh:  make(chan *pb.WatchResponse, 10),
 		watchCh: make(chan *pb.WatchResponse, 10),
+
+		proxyCtx: wp.ctx,
 	}
 
 	go sws.recvLoop()
-
 	sws.sendLoop()
-
-	return nil
+	return wp.ctx.Err()
 }
 
 type serverWatchStream struct {
@@ -81,6 +90,8 @@ type serverWatchStream struct {
 	watchCh chan *pb.WatchResponse
 
 	nextWatcherID int64
+
+	proxyCtx context.Context
 }
 
 func (sws *serverWatchStream) close() {
@@ -89,8 +100,8 @@ func (sws *serverWatchStream) close() {
 
 	var wg sync.WaitGroup
 	sws.mu.Lock()
+	wg.Add(len(sws.singles))
 	for _, ws := range sws.singles {
-		wg.Add(1)
 		ws.stop()
 		// copy the range variable to avoid race
 		copyws := ws
@@ -100,10 +111,7 @@ func (sws *serverWatchStream) close() {
 		}()
 	}
 	sws.mu.Unlock()
-
 	wg.Wait()
-
-	sws.groups.stop()
 }
 
 func (sws *serverWatchStream) recvLoop() error {
@@ -166,6 +174,8 @@ func (sws *serverWatchStream) sendLoop() {
 			if err := sws.gRPCStream.Send(c); err != nil {
 				return
 			}
+		case <-sws.proxyCtx.Done():
+			return
 		}
 	}
 }
@@ -182,7 +192,7 @@ func (sws *serverWatchStream) addDedicatedWatcher(w watcher, rev int64) {
 	sws.mu.Lock()
 	defer sws.mu.Unlock()
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := context.WithCancel(sws.proxyCtx)
 
 	wch := sws.cw.Watch(ctx,
 		w.wr.key, clientv3.WithRange(w.wr.end),

+ 4 - 1
proxy/grpcproxy/watcher_groups.go

@@ -27,6 +27,8 @@ type watchergroups struct {
 	mu        sync.Mutex
 	groups    map[watchRange]*watcherGroup
 	idToGroup map[receiverID]*watcherGroup
+
+	proxyCtx context.Context
 }
 
 func (wgs *watchergroups) addWatcher(rid receiverID, w watcher) {
@@ -40,7 +42,7 @@ func (wgs *watchergroups) addWatcher(rid receiverID, w watcher) {
 		return
 	}
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := context.WithCancel(wgs.proxyCtx)
 
 	wch := wgs.cw.Watch(ctx, w.wr.key,
 		clientv3.WithRange(w.wr.end),
@@ -98,4 +100,5 @@ func (wgs *watchergroups) stop() {
 	for _, wg := range wgs.groups {
 		wg.stop()
 	}
+	wgs.groups = nil
 }