Browse Source

http2: call httptrace.ClientTrace.GetConn in Transport when needed

Tests in CL 122591 in the standard library, to be submitted after this
CL.

Updates golang/go#23041

Change-Id: I3538cc7d2a71e3a26ab4c2f47bb220a25404cddb
Reviewed-on: https://go-review.googlesource.com/122590
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Brad Fitzpatrick 7 years ago
parent
commit
6a8eb5e2b1
4 changed files with 59 additions and 4 deletions
  1. 27 1
      http2/client_conn_pool.go
  2. 8 0
      http2/go17.go
  3. 1 0
      http2/not_go17.go
  4. 23 3
      http2/transport.go

+ 27 - 1
http2/client_conn_pool.go

@@ -52,9 +52,31 @@ const (
 	noDialOnMiss = false
 )
 
+// shouldTraceGetConn reports whether getClientConn should call any
+// ClientTrace.GetConn hook associated with the http.Request.
+//
+// This complexity is needed to avoid double calls of the GetConn hook
+// during the back-and-forth between net/http and x/net/http2 (when the
+// net/http.Transport is upgraded to also speak http2), as well as support
+// the case where x/net/http2 is being used directly.
+func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
+	// If our Transport wasn't made via ConfigureTransport, always
+	// trace the GetConn hook if provided, because that means the
+	// http2 package is being used directly and it's the one
+	// dialing, as opposed to net/http.
+	if _, ok := p.t.ConnPool.(noDialClientConnPool); !ok {
+		return true
+	}
+	// Otherwise, only use the GetConn hook if this connection has
+	// been used previously for other requests. For fresh
+	// connections, the net/http package does the dialing.
+	return !st.freshConn
+}
+
 func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
 	if isConnectionCloseRequest(req) && dialOnMiss {
 		// It gets its own connection.
+		traceGetConn(req, addr)
 		const singleUse = true
 		cc, err := p.t.dialClientConn(addr, singleUse)
 		if err != nil {
@@ -64,7 +86,10 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
 	}
 	p.mu.Lock()
 	for _, cc := range p.conns[addr] {
-		if cc.CanTakeNewRequest() {
+		if st := cc.idleState(); st.canTakeNewRequest {
+			if p.shouldTraceGetConn(st) {
+				traceGetConn(req, addr)
+			}
 			p.mu.Unlock()
 			return cc, nil
 		}
@@ -73,6 +98,7 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
 		p.mu.Unlock()
 		return nil, ErrNoCachedConn
 	}
+	traceGetConn(req, addr)
 	call := p.getStartDialLocked(addr)
 	p.mu.Unlock()
 	<-call.done

+ 8 - 0
http2/go17.go

@@ -50,6 +50,14 @@ func (t *Transport) idleConnTimeout() time.Duration {
 
 func setResponseUncompressed(res *http.Response) { res.Uncompressed = true }
 
+func traceGetConn(req *http.Request, hostPort string) {
+	trace := httptrace.ContextClientTrace(req.Context())
+	if trace == nil || trace.GetConn == nil {
+		return
+	}
+	trace.GetConn(hostPort)
+}
+
 func traceGotConn(req *http.Request, cc *ClientConn) {
 	trace := httptrace.ContextClientTrace(req.Context())
 	if trace == nil || trace.GotConn == nil {

+ 1 - 0
http2/not_go17.go

@@ -37,6 +37,7 @@ func setResponseUncompressed(res *http.Response) {
 type clientTrace struct{}
 
 func requestTrace(*http.Request) *clientTrace { return nil }
+func traceGetConn(*http.Request, string)      {}
 func traceGotConn(*http.Request, *ClientConn) {}
 func traceFirstResponseByte(*clientTrace)     {}
 func traceWroteHeaders(*clientTrace)          {}

+ 23 - 3
http2/transport.go

@@ -631,12 +631,32 @@ func (cc *ClientConn) CanTakeNewRequest() bool {
 	return cc.canTakeNewRequestLocked()
 }
 
-func (cc *ClientConn) canTakeNewRequestLocked() bool {
+// clientConnIdleState describes the suitability of a client
+// connection to initiate a new RoundTrip request.
+type clientConnIdleState struct {
+	canTakeNewRequest bool
+	freshConn         bool // whether it's unused by any previous request
+}
+
+func (cc *ClientConn) idleState() clientConnIdleState {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	return cc.idleStateLocked()
+}
+
+func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
 	if cc.singleUse && cc.nextStreamID > 1 {
-		return false
+		return
 	}
-	return cc.goAway == nil && !cc.closed && !cc.closing &&
+	st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing &&
 		int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
+	st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
+	return
+}
+
+func (cc *ClientConn) canTakeNewRequestLocked() bool {
+	st := cc.idleStateLocked()
+	return st.canTakeNewRequest
 }
 
 // onIdleTimeout is called from a time.AfterFunc goroutine. It will