浏览代码

client: fix cancel watch

ioutil.ReadAll is a blocking call, we need to wait cancelation
during the call.
Xiang Li 10 年之前
父节点
当前提交
15ac4f08f8
共有 2 个文件被更改,包括 59 次插入3 次删除
  1. 17 1
      client/client.go
  2. 42 2
      client/client_test.go

+ 17 - 1
client/client.go

@@ -293,7 +293,23 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
 		return nil, nil, err
 	}
 
-	body, err := ioutil.ReadAll(resp.Body)
+	var body []byte
+	done := make(chan struct{})
+	go func() {
+		body, err = ioutil.ReadAll(resp.Body)
+		done <- struct{}{}
+	}()
+
+	select {
+	case <-ctx.Done():
+		err = resp.Body.Close()
+		<-done
+		if err == nil {
+			err = ctx.Err()
+		}
+	case <-done:
+	}
+
 	return resp, body, err
 }
 

+ 42 - 2
client/client_test.go

@@ -186,8 +186,11 @@ type checkableReadCloser struct {
 }
 
 func (c *checkableReadCloser) Close() error {
-	c.closed = true
-	return c.ReadCloser.Close()
+	if !c.closed {
+		c.closed = true
+		return c.ReadCloser.Close()
+	}
+	return nil
 }
 
 func TestSimpleHTTPClientDoCancelContextResponseBodyClosed(t *testing.T) {
@@ -218,6 +221,43 @@ func TestSimpleHTTPClientDoCancelContextResponseBodyClosed(t *testing.T) {
 	}
 }
 
+type blockingBody struct {
+	c chan struct{}
+}
+
+func (bb *blockingBody) Read(p []byte) (n int, err error) {
+	<-bb.c
+	return 0, errors.New("closed")
+}
+
+func (bb *blockingBody) Close() error {
+	close(bb.c)
+	return nil
+}
+
+func TestSimpleHTTPClientDoCancelContextResponseBodyClosedWithBlockingBody(t *testing.T) {
+	tr := newFakeTransport()
+	c := &simpleHTTPClient{transport: tr}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	body := &checkableReadCloser{ReadCloser: &blockingBody{c: make(chan struct{})}}
+	go func() {
+		tr.respchan <- &http.Response{Body: body}
+		time.Sleep(2 * time.Millisecond)
+		// cancel after the body is received
+		cancel()
+	}()
+
+	_, _, err := c.Do(ctx, &fakeAction{})
+	if err == nil {
+		t.Fatalf("expected non-nil error, got nil")
+	}
+
+	if !body.closed {
+		t.Fatalf("expected closed body")
+	}
+}
+
 func TestSimpleHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) {
 	tr := newFakeTransport()
 	c := &simpleHTTPClient{transport: tr}