Selaa lähdekoodia

context: fix removal of cancelled timer contexts from parent

Change-Id: Iee673c97e6a3b779c3d8ba6bb1b5f2b2e2032b86
Reviewed-on: https://go-review.googlesource.com/3911
Reviewed-by: Sameer Ajmani <sameer@golang.org>
Damien Neil 11 vuotta sitten
vanhempi
commit
f090b05f9b
2 muutettua tiedostoa jossa 41 lisäystä ja 8 poistoa
  1. 19 8
      context/context.go
  2. 22 0
      context/context_test.go

+ 19 - 8
context/context.go

@@ -262,6 +262,19 @@ func parentCancelCtx(parent Context) (*cancelCtx, bool) {
 	}
 }
 
+// removeChild removes a context from its parent.
+func removeChild(parent Context, child canceler) {
+	p, ok := parentCancelCtx(parent)
+	if !ok {
+		return
+	}
+	p.mu.Lock()
+	if p.children != nil {
+		delete(p.children, child)
+	}
+	p.mu.Unlock()
+}
+
 // A canceler is a context type that can be canceled directly.  The
 // implementations are *cancelCtx and *timerCtx.
 type canceler interface {
@@ -316,13 +329,7 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) {
 	c.mu.Unlock()
 
 	if removeFromParent {
-		if p, ok := parentCancelCtx(c.Context); ok {
-			p.mu.Lock()
-			if p.children != nil {
-				delete(p.children, c)
-			}
-			p.mu.Unlock()
-		}
+		removeChild(c.Context, c)
 	}
 }
 
@@ -380,7 +387,11 @@ func (c *timerCtx) String() string {
 }
 
 func (c *timerCtx) cancel(removeFromParent bool, err error) {
-	c.cancelCtx.cancel(removeFromParent, err)
+	c.cancelCtx.cancel(false, err)
+	if removeFromParent {
+		// Remove this timerCtx from its parent cancelCtx's children.
+		removeChild(c.cancelCtx.Context, c)
+	}
 	c.mu.Lock()
 	if c.timer != nil {
 		c.timer.Stop()

+ 22 - 0
context/context_test.go

@@ -551,3 +551,25 @@ func testLayers(t *testing.T, seed int64, testTimeout bool) {
 		checkValues("after cancel")
 	}
 }
+
+func TestCancelRemoves(t *testing.T) {
+	checkChildren := func(when string, ctx Context, want int) {
+		if got := len(ctx.(*cancelCtx).children); got != want {
+			t.Errorf("%s: context has %d children, want %d", when, got, want)
+		}
+	}
+
+	ctx, _ := WithCancel(Background())
+	checkChildren("after creation", ctx, 0)
+	_, cancel := WithCancel(ctx)
+	checkChildren("with WithCancel child ", ctx, 1)
+	cancel()
+	checkChildren("after cancelling WithCancel child", ctx, 0)
+
+	ctx, _ = WithCancel(Background())
+	checkChildren("after creation", ctx, 0)
+	_, cancel = WithTimeout(ctx, 60*time.Minute)
+	checkChildren("with WithTimeout child ", ctx, 1)
+	cancel()
+	checkChildren("after cancelling WithTimeout child", ctx, 0)
+}