Browse Source

http2: remove afterReqBodyWriteError wrapper

There was a case where we forgot to undo this wrapper. Instead of fixing
that case, I moved the implementation of ClientConn.RoundTrip into an
unexported method that returns the same info as a bool.

Fixes golang/go#22136

Change-Id: I7e5fc467f9c26fb74b9b83f2b3b7f8882645e34c
Reviewed-on: https://go-review.googlesource.com/75252
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Tom Bergan 8 years ago
parent
commit
ea0da6f35c
1 changed files with 27 additions and 36 deletions
  1. 27 36
      http2/transport.go

+ 27 - 36
http2/transport.go

@@ -274,6 +274,13 @@ func (cs *clientStream) checkResetOrDone() error {
 	}
 }
 
+func (cs *clientStream) getStartedWrite() bool {
+	cc := cs.cc
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	return cs.startedWrite
+}
+
 func (cs *clientStream) abortRequestBodyWrite(err error) {
 	if err == nil {
 		panic("nil error")
@@ -349,14 +356,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
 			return nil, err
 		}
 		traceGotConn(req, cc)
-		res, err := cc.RoundTrip(req)
+		res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
 		if err != nil && retry <= 6 {
-			afterBodyWrite := false
-			if e, ok := err.(afterReqBodyWriteError); ok {
-				err = e
-				afterBodyWrite = true
-			}
-			if req, err = shouldRetryRequest(req, err, afterBodyWrite); err == nil {
+			if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
 				// After the first retry, do exponential backoff with 10% jitter.
 				if retry == 0 {
 					continue
@@ -394,16 +396,6 @@ var (
 	errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
 )
 
-// afterReqBodyWriteError is a wrapper around errors returned by ClientConn.RoundTrip.
-// It is used to signal that err happened after part of Request.Body was sent to the server.
-type afterReqBodyWriteError struct {
-	err error
-}
-
-func (e afterReqBodyWriteError) Error() string {
-	return e.err.Error() + "; some request body already written"
-}
-
 // shouldRetryRequest is called by RoundTrip when a request fails to get
 // response headers. It is always called with a non-nil error.
 // It returns either a request to retry (either the same request, or a
@@ -752,8 +744,13 @@ func actualContentLength(req *http.Request) int64 {
 }
 
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
+	resp, _, err := cc.roundTrip(req)
+	return resp, err
+}
+
+func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
 	if err := checkConnHeaders(req); err != nil {
-		return nil, err
+		return nil, false, err
 	}
 	if cc.idleTimer != nil {
 		cc.idleTimer.Stop()
@@ -761,14 +758,14 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 
 	trailers, err := commaSeparatedTrailers(req)
 	if err != nil {
-		return nil, err
+		return nil, false, err
 	}
 	hasTrailers := trailers != ""
 
 	cc.mu.Lock()
 	if err := cc.awaitOpenSlotForRequest(req); err != nil {
 		cc.mu.Unlock()
-		return nil, err
+		return nil, false, err
 	}
 
 	body := req.Body
@@ -802,7 +799,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
 	if err != nil {
 		cc.mu.Unlock()
-		return nil, err
+		return nil, false, err
 	}
 
 	cs := cc.newStream()
@@ -828,7 +825,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 		// Don't bother sending a RST_STREAM (our write already failed;
 		// no need to keep writing)
 		traceWroteRequest(cs.trace, werr)
-		return nil, werr
+		return nil, false, werr
 	}
 
 	var respHeaderTimer <-chan time.Time
@@ -847,7 +844,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	bodyWritten := false
 	ctx := reqContext(req)
 
-	handleReadLoopResponse := func(re resAndError) (*http.Response, error) {
+	handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) {
 		res := re.res
 		if re.err != nil || res.StatusCode > 299 {
 			// On error or status code 3xx, 4xx, 5xx, etc abort any
@@ -863,18 +860,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			cs.abortRequestBodyWrite(errStopReqBodyWrite)
 		}
 		if re.err != nil {
-			cc.mu.Lock()
-			afterBodyWrite := cs.startedWrite
-			cc.mu.Unlock()
 			cc.forgetStreamID(cs.ID)
-			if afterBodyWrite {
-				return nil, afterReqBodyWriteError{re.err}
-			}
-			return nil, re.err
+			return nil, cs.getStartedWrite(), re.err
 		}
 		res.Request = req
 		res.TLS = cc.tlsState
-		return res, nil
+		return res, false, nil
 	}
 
 	for {
@@ -889,7 +880,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			cc.forgetStreamID(cs.ID)
-			return nil, errTimeout
+			return nil, cs.getStartedWrite(), errTimeout
 		case <-ctx.Done():
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
@@ -898,7 +889,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			cc.forgetStreamID(cs.ID)
-			return nil, ctx.Err()
+			return nil, cs.getStartedWrite(), ctx.Err()
 		case <-req.Cancel:
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
@@ -907,12 +898,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			cc.forgetStreamID(cs.ID)
-			return nil, errRequestCanceled
+			return nil, cs.getStartedWrite(), errRequestCanceled
 		case <-cs.peerReset:
 			// processResetStream already removed the
 			// stream from the streams map; no need for
 			// forgetStreamID.
-			return nil, cs.resetErr
+			return nil, cs.getStartedWrite(), cs.resetErr
 		case err := <-bodyWriter.resc:
 			// Prefer the read loop's response, if available. Issue 16102.
 			select {
@@ -921,7 +912,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			default:
 			}
 			if err != nil {
-				return nil, err
+				return nil, cs.getStartedWrite(), err
 			}
 			bodyWritten = true
 			if d := cc.responseHeaderTimeout(); d != 0 {