Browse Source

http2: client conn pool abstraction

Change-Id: Icbf40b26a25c7084efd062a0a66385450ec537aa
Reviewed-on: https://go-review.googlesource.com/16699
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
Brad Fitzpatrick 10 years ago
parent
commit
d62542d18c
2 changed files with 184 additions and 138 deletions
  1. 118 0
      http2/client_conn_pool.go
  2. 66 138
      http2/transport.go

+ 118 - 0
http2/client_conn_pool.go

@@ -0,0 +1,118 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Transport code's client connection pooling.
+
+package http2
+
+import (
+	"net/http"
+	"sync"
+)
+
+// ClientConnPool manages a pool of HTTP/2 client connections.
+type ClientConnPool interface {
+	GetClientConn(req *http.Request, addr string) (*ClientConn, error)
+	MarkDead(*ClientConn)
+}
+
+type clientConnPool struct {
+	t  *Transport
+	mu sync.Mutex // TODO: switch to RWMutex
+	// TODO: add support for sharing conns based on cert names
+	// (e.g. share conn for googleapis.com and appspot.com)
+	conns map[string][]*ClientConn // key is host:port
+	keys  map[*ClientConn][]string
+}
+
+func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
+	return p.getClientConn(req, addr, true)
+}
+
+func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
+	p.mu.Lock()
+	for _, cc := range p.conns[addr] {
+		if cc.CanTakeNewRequest() {
+			p.mu.Unlock()
+			return cc, nil
+		}
+	}
+	p.mu.Unlock()
+	if !dialOnMiss {
+		return nil, ErrNoCachedConn
+	}
+
+	// TODO(bradfitz): use a singleflight.Group to only lock once per 'key'.
+	// Probably need to vendor it in as github.com/golang/sync/singleflight
+	// though, since the net package already uses it? Also lines up with
+	// sameer, bcmills, et al wanting to open source some sync stuff.
+	cc, err := p.t.dialClientConn(addr)
+	if err != nil {
+		return nil, err
+	}
+	p.addConn(addr, cc)
+	return cc, nil
+}
+
+func (p *clientConnPool) addConn(key string, cc *ClientConn) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if p.conns == nil {
+		p.conns = make(map[string][]*ClientConn)
+	}
+	if p.keys == nil {
+		p.keys = make(map[*ClientConn][]string)
+	}
+	p.conns[key] = append(p.conns[key], cc)
+	p.keys[cc] = append(p.keys[cc], key)
+}
+
+func (p *clientConnPool) MarkDead(cc *ClientConn) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	for _, key := range p.keys[cc] {
+		vv, ok := p.conns[key]
+		if !ok {
+			continue
+		}
+		newList := filterOutClientConn(vv, cc)
+		if len(newList) > 0 {
+			p.conns[key] = newList
+		} else {
+			delete(p.conns, key)
+		}
+	}
+	delete(p.keys, cc)
+}
+
+func (p *clientConnPool) closeIdleConnections() {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	// TODO: don't close a cc if it was just added to the pool
+	// milliseconds ago and has never been used. There's currently
+	// a small race window with the HTTP/1 Transport's integration
+	// where it can add an idle conn just before using it, and
+	// somebody else can concurrently call CloseIdleConns and
+	// break some caller's RoundTrip.
+	for _, vv := range p.conns {
+		for _, cc := range vv {
+			cc.closeIfIdle()
+		}
+	}
+}
+
+func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
+	out := in[:0]
+	for _, v := range in {
+		if v != exclude {
+			out = append(out, v)
+		}
+	}
+	// If we filtered it out, zero out the last item to prevent
+	// the GC from seeing it.
+	if len(in) != len(out) {
+		in[len(in)-1] = nil
+	}
+	return out
+}

+ 66 - 138
http2/transport.go

@@ -57,20 +57,33 @@ type Transport struct {
 	// tls.Client. If nil, the default configuration is used.
 	TLSClientConfig *tls.Config
 
-	// TODO: switch to RWMutex
-	// TODO: add support for sharing conns based on cert names
-	// (e.g. share conn for googleapis.com and appspot.com)
-	connMu sync.Mutex
-	conns  map[string][]*clientConn // key is host:port
+	// ConnPool optionally specifies an alternate connection pool to use.
+	// If nil, the default is used.
+	ConnPool ClientConnPool
+
+	connPoolOnce  sync.Once
+	connPoolOrDef ClientConnPool // non-nil version of ConnPool
+}
+
+func (t *Transport) connPool() ClientConnPool {
+	t.connPoolOnce.Do(t.initConnPool)
+	return t.connPoolOrDef
+}
+
+func (t *Transport) initConnPool() {
+	if t.ConnPool != nil {
+		t.connPoolOrDef = t.ConnPool
+	} else {
+		t.connPoolOrDef = &clientConnPool{t: t}
+	}
 }
 
-// clientConn is the state of a single HTTP/2 client connection to an
+// ClientConn is the state of a single HTTP/2 client connection to an
 // HTTP/2 server.
-type clientConn struct {
+type ClientConn struct {
 	t        *Transport
-	tconn    net.Conn
-	tlsState *tls.ConnectionState
-	connKey  []string // key(s) this connection is cached in, in t.conns
+	tconn    net.Conn             // usually *tls.Conn, except specialized impls
+	tlsState *tls.ConnectionState // nil only for specialized impls
 
 	// readLoop goroutine fields:
 	readerDone chan struct{} // closed on error
@@ -102,7 +115,7 @@ type clientConn struct {
 // clientStream is the state for a single HTTP/2 stream. One of these
 // is created for each Transport.RoundTrip call.
 type clientStream struct {
-	cc      *clientConn
+	cc      *ClientConn
 	ID      uint32
 	resc    chan resAndError
 	bufPipe pipe // buffered pipe with the flow-controlled response payload
@@ -154,24 +167,28 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 	return t.RoundTripOpt(req, RoundTripOpt{})
 }
 
+// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
+// and returns a host:port. The port 443 is added if needed.
+func authorityAddr(authority string) (addr string) {
+	if _, _, err := net.SplitHostPort(authority); err == nil {
+		return authority
+	}
+	return net.JoinHostPort(authority, "443")
+}
+
 // RoundTripOpt is like RoundTrip, but takes options.
 func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
 	if req.URL.Scheme != "https" {
 		return nil, errors.New("http2: unsupported scheme")
 	}
 
-	host, port, err := net.SplitHostPort(req.URL.Host)
-	if err != nil {
-		host = req.URL.Host
-		port = "443"
-	}
-
+	addr := authorityAddr(req.URL.Host)
 	for {
-		cc, err := t.getClientConn(host, port, opt.OnlyCachedConn)
+		cc, err := t.connPool().GetClientConn(req, addr)
 		if err != nil {
 			return nil, err
 		}
-		res, err := cc.roundTrip(req)
+		res, err := cc.RoundTrip(req)
 		if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
 			continue
 		}
@@ -186,12 +203,8 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
 // connected from previous requests but are now sitting idle.
 // It does not interrupt any connections currently in use.
 func (t *Transport) CloseIdleConnections() {
-	t.connMu.Lock()
-	defer t.connMu.Unlock()
-	for _, vv := range t.conns {
-		for _, cc := range vv {
-			cc.closeIfIdle()
-		}
+	if cp, ok := t.connPool().(*clientConnPool); ok {
+		cp.closeIdleConnections()
 	}
 }
 
@@ -202,100 +215,16 @@ func shouldRetryRequest(err error) bool {
 	return err == errClientConnClosed
 }
 
-func (t *Transport) removeClientConn(cc *clientConn) {
-	t.connMu.Lock()
-	defer t.connMu.Unlock()
-	for _, key := range cc.connKey {
-		vv, ok := t.conns[key]
-		if !ok {
-			continue
-		}
-		newList := filterOutClientConn(vv, cc)
-		if len(newList) > 0 {
-			t.conns[key] = newList
-		} else {
-			delete(t.conns, key)
-		}
-	}
-}
-
-func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
-	out := in[:0]
-	for _, v := range in {
-		if v != exclude {
-			out = append(out, v)
-		}
-	}
-	// If we filtered it out, zero out the last item to prevent
-	// the GC from seeing it.
-	if len(in) != len(out) {
-		in[len(in)-1] = nil
-	}
-	return out
-}
-
-// AddIdleConn adds c as an idle conn for Transport.
-// It assumes that c has not yet exchanged SETTINGS frames.
-// The addr maybe be either "host" or "host:port".
-func (t *Transport) AddIdleConn(addr string, c *tls.Conn) error {
-	var key string
-	_, _, err := net.SplitHostPort(addr)
-	if err == nil {
-		key = addr
-	} else {
-		key = addr + ":443"
-	}
-	cc, err := t.newClientConn(key, c)
-	if err != nil {
-		return err
-	}
-
-	t.addConn(key, cc)
-	return nil
-}
-
-func (t *Transport) addConn(key string, cc *clientConn) {
-	t.connMu.Lock()
-	defer t.connMu.Unlock()
-	if t.conns == nil {
-		t.conns = make(map[string][]*clientConn)
-	}
-	t.conns[key] = append(t.conns[key], cc)
-}
-
-func (t *Transport) getClientConn(host, port string, onlyCached bool) (*clientConn, error) {
-	key := net.JoinHostPort(host, port)
-
-	t.connMu.Lock()
-	for _, cc := range t.conns[key] {
-		if cc.canTakeNewRequest() {
-			t.connMu.Unlock()
-			return cc, nil
-		}
-	}
-	t.connMu.Unlock()
-	if onlyCached {
-		return nil, ErrNoCachedConn
-	}
-
-	// TODO(bradfitz): use a singleflight.Group to only lock once per 'key'.
-	// Probably need to vendor it in as github.com/golang/sync/singleflight
-	// though, since the net package already uses it? Also lines up with
-	// sameer, bcmills, et al wanting to open source some sync stuff.
-	cc, err := t.dialClientConn(host, port, key)
+func (t *Transport) dialClientConn(addr string) (*ClientConn, error) {
+	host, _, err := net.SplitHostPort(addr)
 	if err != nil {
 		return nil, err
 	}
-	t.addConn(key, cc)
-	return cc, nil
-}
-
-func (t *Transport) dialClientConn(host, port, key string) (*clientConn, error) {
-	tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host))
+	tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
 	if err != nil {
 		return nil, err
 	}
-	return t.newClientConn(key, tconn)
+	return t.NewClientConn(tconn)
 }
 
 func (t *Transport) newTLSConfig(host string) *tls.Config {
@@ -338,15 +267,14 @@ func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.C
 	return cn, nil
 }
 
-func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, error) {
-	if _, err := tconn.Write(clientPreface); err != nil {
+func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
+	if _, err := c.Write(clientPreface); err != nil {
 		return nil, err
 	}
 
-	cc := &clientConn{
+	cc := &ClientConn{
 		t:                    t,
-		tconn:                tconn,
-		connKey:              []string{key}, // TODO: cert's validated hostnames too
+		tconn:                c,
 		readerDone:           make(chan struct{}),
 		nextStreamID:         1,
 		maxFrameSize:         16 << 10, // spec default
@@ -359,15 +287,15 @@ func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, erro
 
 	// TODO: adjust this writer size to account for frame size +
 	// MTU + crypto/tls record padding.
-	cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
-	cc.br = bufio.NewReader(tconn)
+	cc.bw = bufio.NewWriter(stickyErrWriter{c, &cc.werr})
+	cc.br = bufio.NewReader(c)
 	cc.fr = NewFramer(cc.bw, cc.br)
 	cc.henc = hpack.NewEncoder(&cc.hbuf)
 
 	type connectionStater interface {
 		ConnectionState() tls.ConnectionState
 	}
-	if cs, ok := tconn.(connectionStater); ok {
+	if cs, ok := c.(connectionStater); ok {
 		state := cs.ConnectionState()
 		cc.tlsState = &state
 	}
@@ -414,13 +342,13 @@ func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, erro
 	return cc, nil
 }
 
-func (cc *clientConn) setGoAway(f *GoAwayFrame) {
+func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	cc.goAway = f
 }
 
-func (cc *clientConn) canTakeNewRequest() bool {
+func (cc *ClientConn) CanTakeNewRequest() bool {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	return cc.goAway == nil &&
@@ -428,7 +356,7 @@ func (cc *clientConn) canTakeNewRequest() bool {
 		cc.nextStreamID < 2147483647
 }
 
-func (cc *clientConn) closeIfIdle() {
+func (cc *ClientConn) closeIfIdle() {
 	cc.mu.Lock()
 	if len(cc.streams) > 0 {
 		cc.mu.Unlock()
@@ -447,7 +375,7 @@ const maxAllocFrameSize = 512 << 10
 // They're capped at the min of the peer's max frame size or 512KB
 // (kinda arbitrarily), but definitely capped so we don't allocate 4GB
 // bufers.
-func (cc *clientConn) frameScratchBuffer() []byte {
+func (cc *ClientConn) frameScratchBuffer() []byte {
 	cc.mu.Lock()
 	size := cc.maxFrameSize
 	if size > maxAllocFrameSize {
@@ -464,7 +392,7 @@ func (cc *clientConn) frameScratchBuffer() []byte {
 	return make([]byte, size)
 }
 
-func (cc *clientConn) putFrameScratchBuffer(buf []byte) {
+func (cc *ClientConn) putFrameScratchBuffer(buf []byte) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
@@ -481,7 +409,7 @@ func (cc *clientConn) putFrameScratchBuffer(buf []byte) {
 	// forget about it.
 }
 
-func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
+func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	cc.mu.Lock()
 
 	if cc.closed {
@@ -649,7 +577,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int32) (taken int32, err error
 }
 
 // requires cc.mu be held.
-func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
+func (cc *ClientConn) encodeHeaders(req *http.Request) []byte {
 	cc.hbuf.Reset()
 
 	// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
@@ -680,7 +608,7 @@ func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
 	return cc.hbuf.Bytes()
 }
 
-func (cc *clientConn) writeHeader(name, value string) {
+func (cc *ClientConn) writeHeader(name, value string) {
 	cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
 }
 
@@ -690,7 +618,7 @@ type resAndError struct {
 }
 
 // requires cc.mu be held.
-func (cc *clientConn) newStream() *clientStream {
+func (cc *ClientConn) newStream() *clientStream {
 	cs := &clientStream{
 		cc:        cc,
 		ID:        cc.nextStreamID,
@@ -706,7 +634,7 @@ func (cc *clientConn) newStream() *clientStream {
 	return cs
 }
 
-func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
+func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	cs := cc.streams[id]
@@ -718,7 +646,7 @@ func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
 
 // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
 type clientConnReadLoop struct {
-	cc        *clientConn
+	cc        *ClientConn
 	activeRes map[uint32]*clientStream // keyed by streamID
 
 	// continueStreamID is the stream ID we're waiting for
@@ -734,7 +662,7 @@ type clientConnReadLoop struct {
 }
 
 // readLoop runs in its own goroutine and reads and dispatches frames.
-func (cc *clientConn) readLoop() {
+func (cc *ClientConn) readLoop() {
 	rl := &clientConnReadLoop{
 		cc:        cc,
 		activeRes: make(map[uint32]*clientStream),
@@ -754,7 +682,7 @@ func (cc *clientConn) readLoop() {
 func (rl *clientConnReadLoop) cleanup() {
 	cc := rl.cc
 	defer cc.tconn.Close()
-	defer cc.t.removeClientConn(cc)
+	defer cc.t.connPool().MarkDead(cc)
 	defer close(cc.readerDone)
 
 	// Close any response bodies if the server closes prematurely.
@@ -978,7 +906,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
 
 func (rl *clientConnReadLoop) processGoAway(f *GoAwayFrame) error {
 	cc := rl.cc
-	cc.t.removeClientConn(cc)
+	cc.t.connPool().MarkDead(cc)
 	if f.ErrCode != 0 {
 		// TODO: deal with GOAWAY more. particularly the error code
 		cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
@@ -1066,7 +994,7 @@ func (rl *clientConnReadLoop) processPushPromise(f *PushPromiseFrame) error {
 	return ConnectionError(ErrCodeProtocol)
 }
 
-func (cc *clientConn) writeStreamReset(streamID uint32, code ErrCode, err error) {
+func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, err error) {
 	// TODO: do something with err? send it as a debug frame to the peer?
 	// But that's only in GOAWAY. Invent a new frame type? Is there one already?
 	cc.wmu.Lock()
@@ -1108,11 +1036,11 @@ func (rl *clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) {
 	}
 }
 
-func (cc *clientConn) logf(format string, args ...interface{}) {
+func (cc *ClientConn) logf(format string, args ...interface{}) {
 	cc.t.logf(format, args...)
 }
 
-func (cc *clientConn) vlogf(format string, args ...interface{}) {
+func (cc *ClientConn) vlogf(format string, args ...interface{}) {
 	cc.t.vlogf(format, args...)
 }