Browse Source

context/ctxhttp: allow cancellation after Do returns

Fixes #13325.

Change-Id: I17f35232cd0ea43e50ea12db09272195789426e9
Reviewed-on: https://go-review.googlesource.com/18188
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Dave Day 10 years ago
parent
commit
3b90a77d28
3 changed files with 85 additions and 1 deletions
  1. 1 0
      context/ctxhttp/cancelreq.go
  2. 45 1
      context/ctxhttp/ctxhttp.go
  3. 39 0
      context/ctxhttp/ctxhttp_test.go

+ 1 - 0
context/ctxhttp/cancelreq.go

@@ -9,6 +9,7 @@ package ctxhttp
 import "net/http"
 
 func canceler(client *http.Client, req *http.Request) func() {
+	// TODO(djd): Respect any existing value of req.Cancel.
 	ch := make(chan struct{})
 	req.Cancel = ch
 

+ 45 - 1
context/ctxhttp/ctxhttp.go

@@ -36,13 +36,32 @@ func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Resp
 		result <- responseAndError{resp, err}
 	}()
 
+	var resp *http.Response
+
 	select {
 	case <-ctx.Done():
 		cancel()
 		return nil, ctx.Err()
 	case r := <-result:
-		return r.resp, r.err
+		var err error
+		resp, err = r.resp, r.err
+		if err != nil {
+			return resp, err
+		}
 	}
+
+	c := make(chan struct{})
+	go func() {
+		select {
+		case <-ctx.Done():
+			cancel()
+		case <-c:
+			// The response's Body is closed.
+		}
+	}()
+	resp.Body = &notifyingReader{resp.Body, c}
+
+	return resp, nil
 }
 
 // Get issues a GET request via the Do function.
@@ -77,3 +96,28 @@ func Post(ctx context.Context, client *http.Client, url string, bodyType string,
 func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
 	return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
 }
+
+// notifyingReader is an io.ReadCloser that closes the notify channel after
+// Close is called or a Read fails on the underlying ReadCloser.
+type notifyingReader struct {
+	io.ReadCloser
+	notify chan<- struct{}
+}
+
+func (r *notifyingReader) Read(p []byte) (int, error) {
+	n, err := r.ReadCloser.Read(p)
+	if err != nil && r.notify != nil {
+		close(r.notify)
+		r.notify = nil
+	}
+	return n, err
+}
+
+func (r *notifyingReader) Close() error {
+	err := r.ReadCloser.Close()
+	if r.notify != nil {
+		close(r.notify)
+		r.notify = nil
+	}
+	return err
+}

+ 39 - 0
context/ctxhttp/ctxhttp_test.go

@@ -27,6 +27,7 @@ func TestNoTimeout(t *testing.T) {
 		t.Fatalf("error received from client: %v %v", err, resp)
 	}
 }
+
 func TestCancel(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	go func() {
@@ -59,6 +60,44 @@ func TestCancelAfterRequest(t *testing.T) {
 	}
 }
 
+func TestCancelAfterHangingRequest(t *testing.T) {
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		w.(http.Flusher).Flush()
+		<-w.(http.CloseNotifier).CloseNotify()
+	})
+
+	serv := httptest.NewServer(handler)
+	defer serv.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	resp, err := Get(ctx, nil, serv.URL)
+	if err != nil {
+		t.Fatalf("unexpected error in Get: %v", err)
+	}
+
+	// Cancel befer reading the body.
+	// Reading Request.Body should fail, since the request was
+	// canceled before anything was written.
+	cancel()
+
+	done := make(chan struct{})
+
+	go func() {
+		b, err := ioutil.ReadAll(resp.Body)
+		if len(b) != 0 || err == nil {
+			t.Errorf(`Read got (%q, %v); want ("", error)`, b, err)
+		}
+		close(done)
+	}()
+
+	select {
+	case <-time.After(1 * time.Second):
+		t.Errorf("Test timed out")
+	case <-done:
+	}
+}
+
 func doRequest(ctx context.Context) (*http.Response, error) {
 	var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		time.Sleep(requestDuration)