Browse Source

http2: delay sending request body in Transport if 100-continue is set

In Go 1.6, the HTTP/1 client got Transport.ExpectContinueTimeout.

This makes the HTTP/2 client respect a Request's "Expect:
100-continue" field and the Transport.ExpectContinueTimeout
configuration.

This also makes sure to call the traceWroteRequest hook if the server
replied while we're still writing the request, since that code was
in the same spot and it couldn't be trivially separated.

Updates golang/go#13851 (fixed after integrating it into std)
Updates golang/go#15744

Change-Id: I67dfd68532daa6c4a0c026549c6e5cbfce50e1ea
Reviewed-on: https://go-review.googlesource.com/23235
Reviewed-by: Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick cách đây 9 năm
mục cha
commit
202ff482f7
5 tập tin đã thay đổi với 167 bổ sung và 35 xóa
  1. 16 0
      http2/go16.go
  2. 14 2
      http2/go17.go
  3. 8 1
      http2/not_go16.go
  4. 2 0
      http2/not_go17.go
  5. 127 32
      http2/transport.go

+ 16 - 0
http2/go16.go

@@ -0,0 +1,16 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.6
+
+package http2
+
+import (
+	"net/http"
+	"time"
+)
+
+func transportExpectContinueTimeout(t1 *http.Transport) time.Duration {
+	return t1.ExpectContinueTimeout
+}

+ 14 - 2
http2/go17.go

@@ -49,8 +49,8 @@ func traceGotConn(req *http.Request, cc *ClientConn) {
 	ci := httptrace.GotConnInfo{Conn: cc.tconn}
 	ci := httptrace.GotConnInfo{Conn: cc.tconn}
 	cc.mu.Lock()
 	cc.mu.Lock()
 	ci.Reused = cc.nextStreamID > 1
 	ci.Reused = cc.nextStreamID > 1
-	ci.WasIdle = len(cc.streams) == 0
-	if ci.WasIdle {
+	ci.WasIdle = len(cc.streams) == 0 && ci.Reused
+	if ci.WasIdle && !cc.lastActive.IsZero() {
 		ci.IdleTime = time.Now().Sub(cc.lastActive)
 		ci.IdleTime = time.Now().Sub(cc.lastActive)
 	}
 	}
 	cc.mu.Unlock()
 	cc.mu.Unlock()
@@ -64,6 +64,18 @@ func traceWroteHeaders(trace *clientTrace) {
 	}
 	}
 }
 }
 
 
+func traceGot100Continue(trace *clientTrace) {
+	if trace != nil && trace.Got100Continue != nil {
+		trace.Got100Continue()
+	}
+}
+
+func traceWait100Continue(trace *clientTrace) {
+	if trace != nil && trace.Wait100Continue != nil {
+		trace.Wait100Continue()
+	}
+}
+
 func traceWroteRequest(trace *clientTrace, err error) {
 func traceWroteRequest(trace *clientTrace, err error) {
 	if trace != nil && trace.WroteRequest != nil {
 	if trace != nil && trace.WroteRequest != nil {
 		trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
 		trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})

+ 8 - 1
http2/not_go16.go

@@ -6,8 +6,15 @@
 
 
 package http2
 package http2
 
 
-import "net/http"
+import (
+	"net/http"
+	"time"
+)
 
 
 func configureTransport(t1 *http.Transport) (*Transport, error) {
 func configureTransport(t1 *http.Transport) (*Transport, error) {
 	return nil, errTransportVersion
 	return nil, errTransportVersion
 }
 }
+
+func transportExpectContinueTimeout(t1 *http.Transport) time.Duration {
+	return 0
+}

+ 2 - 0
http2/not_go17.go

@@ -33,6 +33,8 @@ func traceGotConn(*http.Request, *ClientConn) {}
 func traceFirstResponseByte(*clientTrace)     {}
 func traceFirstResponseByte(*clientTrace)     {}
 func traceWroteHeaders(*clientTrace)          {}
 func traceWroteHeaders(*clientTrace)          {}
 func traceWroteRequest(*clientTrace, error)   {}
 func traceWroteRequest(*clientTrace, error)   {}
+func traceGot100Continue(trace *clientTrace)  {}
+func traceWait100Continue(trace *clientTrace) {}
 
 
 func nop() {}
 func nop() {}
 
 

+ 127 - 32
http2/transport.go

@@ -178,6 +178,7 @@ type clientStream struct {
 	resc          chan resAndError
 	resc          chan resAndError
 	bufPipe       pipe // buffered pipe with the flow-controlled response payload
 	bufPipe       pipe // buffered pipe with the flow-controlled response payload
 	requestedGzip bool
 	requestedGzip bool
+	on100         func() // optional code to run if get a 100 continue response
 
 
 	flow        flow  // guarded by cc.mu
 	flow        flow  // guarded by cc.mu
 	inflow      flow  // guarded by cc.mu
 	inflow      flow  // guarded by cc.mu
@@ -387,6 +388,13 @@ func (t *Transport) disableKeepAlives() bool {
 	return t.t1 != nil && t.t1.DisableKeepAlives
 	return t.t1 != nil && t.t1.DisableKeepAlives
 }
 }
 
 
+func (t *Transport) expectContinueTimeout() time.Duration {
+	if t.t1 == nil {
+		return 0
+	}
+	return transportExpectContinueTimeout(t.t1)
+}
+
 func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
 func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
 	if VerboseLogs {
 	if VerboseLogs {
 		t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
 		t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
@@ -593,6 +601,33 @@ func checkConnHeaders(req *http.Request) error {
 	return nil
 	return nil
 }
 }
 
 
+func bodyAndLength(req *http.Request) (body io.Reader, contentLen int64) {
+	body = req.Body
+	if body == nil {
+		return nil, 0
+	}
+	if req.ContentLength != 0 {
+		return req.Body, req.ContentLength
+	}
+
+	// We have a body but a zero content length. Test to see if
+	// it's actually zero or just unset.
+	var buf [1]byte
+	n, rerr := io.ReadFull(body, buf[:])
+	if rerr != nil && rerr != io.EOF {
+		return errorReader{rerr}, -1
+	}
+	if n == 1 {
+		// Oh, guess there is data in this Body Reader after all.
+		// The ContentLength field just wasn't set.
+		// Stich the Body back together again, re-attaching our
+		// consumed byte.
+		return io.MultiReader(bytes.NewReader(buf[:]), body), -1
+	}
+	// Body is actually zero bytes.
+	return nil, 0
+}
+
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	if err := checkConnHeaders(req); err != nil {
 	if err := checkConnHeaders(req); err != nil {
 		return nil, err
 		return nil, err
@@ -604,27 +639,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 	}
 	hasTrailers := trailers != ""
 	hasTrailers := trailers != ""
 
 
-	var body io.Reader = req.Body
-	contentLen := req.ContentLength
-	if req.Body != nil && contentLen == 0 {
-		// Test to see if it's actually zero or just unset.
-		var buf [1]byte
-		n, rerr := io.ReadFull(body, buf[:])
-		if rerr != nil && rerr != io.EOF {
-			contentLen = -1
-			body = errorReader{rerr}
-		} else if n == 1 {
-			// Oh, guess there is data in this Body Reader after all.
-			// The ContentLength field just wasn't set.
-			// Stich the Body back together again, re-attaching our
-			// consumed byte.
-			contentLen = -1
-			body = io.MultiReader(bytes.NewReader(buf[:]), body)
-		} else {
-			// Body is actually empty.
-			body = nil
-		}
-	}
+	body, contentLen := bodyAndLength(req)
+	hasBody := body != nil
 
 
 	cc.mu.Lock()
 	cc.mu.Lock()
 	cc.lastActive = time.Now()
 	cc.lastActive = time.Now()
@@ -666,8 +682,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	cs := cc.newStream()
 	cs := cc.newStream()
 	cs.req = req
 	cs.req = req
 	cs.trace = requestTrace(req)
 	cs.trace = requestTrace(req)
-	hasBody := body != nil
 	cs.requestedGzip = requestedGzip
 	cs.requestedGzip = requestedGzip
+	bodyWriter := cc.t.getBodyWriterState(cs, body)
+	cs.on100 = bodyWriter.on100
 
 
 	cc.wmu.Lock()
 	cc.wmu.Lock()
 	endStream := !hasBody && !hasTrailers
 	endStream := !hasBody && !hasTrailers
@@ -679,6 +696,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	if werr != nil {
 	if werr != nil {
 		if hasBody {
 		if hasBody {
 			req.Body.Close() // per RoundTripper contract
 			req.Body.Close() // per RoundTripper contract
+			bodyWriter.cancel()
 		}
 		}
 		cc.forgetStreamID(cs.ID)
 		cc.forgetStreamID(cs.ID)
 		// Don't bother sending a RST_STREAM (our write already failed;
 		// Don't bother sending a RST_STREAM (our write already failed;
@@ -688,12 +706,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 	}
 
 
 	var respHeaderTimer <-chan time.Time
 	var respHeaderTimer <-chan time.Time
-	var bodyCopyErrc chan error // result of body copy
 	if hasBody {
 	if hasBody {
-		bodyCopyErrc = make(chan error, 1)
-		go func() {
-			bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
-		}()
+		bodyWriter.scheduleBodyWrite()
 	} else {
 	} else {
 		traceWroteRequest(cs.trace, nil)
 		traceWroteRequest(cs.trace, nil)
 		if d := cc.responseHeaderTimeout(); d != 0 {
 		if d := cc.responseHeaderTimeout(); d != 0 {
@@ -721,6 +735,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 				// doesn't, they'll RST_STREAM us soon enough.  This is a
 				// doesn't, they'll RST_STREAM us soon enough.  This is a
 				// heuristic to avoid adding knobs to Transport.  Hopefully
 				// heuristic to avoid adding knobs to Transport.  Hopefully
 				// we can keep it.
 				// we can keep it.
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWrite)
 				cs.abortRequestBodyWrite(errStopReqBodyWrite)
 			}
 			}
 			if re.err != nil {
 			if re.err != nil {
@@ -735,6 +750,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			if !hasBody || bodyWritten {
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 			} else {
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			}
 			return nil, errTimeout
 			return nil, errTimeout
@@ -743,6 +759,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			if !hasBody || bodyWritten {
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 			} else {
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			}
 			return nil, ctx.Err()
 			return nil, ctx.Err()
@@ -751,6 +768,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			if !hasBody || bodyWritten {
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 			} else {
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			}
 			return nil, errRequestCanceled
 			return nil, errRequestCanceled
@@ -759,8 +777,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			// stream from the streams map; no need for
 			// stream from the streams map; no need for
 			// forgetStreamID.
 			// forgetStreamID.
 			return nil, cs.resetErr
 			return nil, cs.resetErr
-		case err := <-bodyCopyErrc:
-			traceWroteRequest(cs.trace, err)
+		case err := <-bodyWriter.resc:
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
@@ -821,6 +838,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
 	defer cc.putFrameScratchBuffer(buf)
 	defer cc.putFrameScratchBuffer(buf)
 
 
 	defer func() {
 	defer func() {
+		traceWroteRequest(cs.trace, err)
 		// TODO: write h12Compare test showing whether
 		// TODO: write h12Compare test showing whether
 		// Request.Body is closed by the Transport,
 		// Request.Body is closed by the Transport,
 		// and in multiple cases: server replies <=299 and >299
 		// and in multiple cases: server replies <=299 and >299
@@ -1281,9 +1299,10 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
 	}
 	}
 
 
 	if statusCode == 100 {
 	if statusCode == 100 {
-		// Just skip 100-continue response headers for now.
-		// TODO: golang.org/issue/13851 for doing it properly.
-		// TODO: also call the httptrace.ClientTrace hooks
+		traceGot100Continue(cs.trace)
+		if cs.on100 != nil {
+			cs.on100() // forces any write delay timer to fire
+		}
 		cs.pastHeaders = false // do it all again
 		cs.pastHeaders = false // do it all again
 		return nil, nil
 		return nil, nil
 	}
 	}
@@ -1716,3 +1735,79 @@ func (gz *gzipReader) Close() error {
 type errorReader struct{ err error }
 type errorReader struct{ err error }
 
 
 func (r errorReader) Read(p []byte) (int, error) { return 0, r.err }
 func (r errorReader) Read(p []byte) (int, error) { return 0, r.err }
+
+// bodyWriterState encapsulates various state around the Transport's writing
+// of the request body, particularly regarding doing delayed writes of the body
+// when the request contains "Expect: 100-continue".
+type bodyWriterState struct {
+	cs     *clientStream
+	timer  *time.Timer   // if non-nil, we're doing a delayed write
+	fnonce *sync.Once    // to call fn with
+	fn     func()        // the code to run in the goroutine, writing the body
+	resc   chan error    // result of fn's execution
+	delay  time.Duration // how long we should delay a delayed write for
+}
+
+func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s bodyWriterState) {
+	s.cs = cs
+	if body == nil {
+		return
+	}
+	resc := make(chan error, 1)
+	s.resc = resc
+	s.fn = func() {
+		resc <- cs.writeRequestBody(body, cs.req.Body)
+	}
+	s.delay = t.expectContinueTimeout()
+	if s.delay == 0 ||
+		!httplex.HeaderValuesContainsToken(
+			cs.req.Header["Expect"],
+			"100-continue") {
+		return
+	}
+	s.fnonce = new(sync.Once)
+
+	// Arm the timer with a very large duration, which we'll
+	// intentionally lower later. It has to be large now because
+	// we need a handle to it before writing the headers, but the
+	// s.delay value is defined to not start until after the
+	// request headers were written.
+	const hugeDuration = 365 * 24 * time.Hour
+	s.timer = time.AfterFunc(hugeDuration, func() {
+		s.fnonce.Do(s.fn)
+	})
+	return
+}
+
+func (s bodyWriterState) cancel() {
+	if s.timer != nil {
+		s.timer.Stop()
+	}
+}
+
+func (s bodyWriterState) on100() {
+	if s.timer == nil {
+		// If we didn't do a delayed write, ignore the server's
+		// bogus 100 continue response.
+		return
+	}
+	s.timer.Stop()
+	go func() { s.fnonce.Do(s.fn) }()
+}
+
+// scheduleBodyWrite starts writing the body, either immediately (in
+// the common case) or after the delay timeout. It should not be
+// called until after the headers have been written.
+func (s bodyWriterState) scheduleBodyWrite() {
+	if s.timer == nil {
+		// We're not doing a delayed write (see
+		// getBodyWriterState), so just start the writing
+		// goroutine immediately.
+		go s.fn()
+		return
+	}
+	traceWait100Continue(s.cs.trace)
+	if s.timer.Stop() {
+		s.timer.Reset(s.delay)
+	}
+}