Browse Source

transport: cache and re-use client connections

Brad Fitzpatrick 11 years ago
parent
commit
45702eac0f
3 changed files with 190 additions and 35 deletions
  1. 4 0
      frame.go
  2. 148 32
      transport.go
  3. 38 3
      transport_test.go

+ 4 - 0
frame.go

@@ -1107,3 +1107,7 @@ func readUint32(p []byte) (remain []byte, v uint32, err error) {
 type streamEnder interface {
 	StreamEnded() bool
 }
+
+type headersEnder interface {
+	HeadersEnded() bool
+}

+ 148 - 32
transport.go

@@ -27,32 +27,36 @@ type Transport struct {
 
 	// TODO: remove this and make more general with a TLS dial hook, like http
 	InsecureTLSDial bool
+
+	connMu sync.Mutex
+	conns  map[string][]*clientConn // key is host:port
 }
 
 type clientConn struct {
+	t        *Transport
 	tconn    *tls.Conn
 	tlsState *tls.ConnectionState
+	connKey  []string // key(s) this connection is cached in, in t.conns
 
 	readerDone chan struct{} // closed on error
 	readerErr  error         // set before readerDone is closed
-
-	hbuf bytes.Buffer // HPACK encoder writes into this
-	henc *hpack.Encoder
-
-	hdec *hpack.Decoder
-
-	nextRes *http.Response
-
-	// Settings from peer:
-	maxFrameSize uint32
+	hdec       *hpack.Decoder
+	nextRes    *http.Response
 
 	mu           sync.Mutex
+	goAway       *GoAwayFrame // if non-nil, the GoAwayFrame we received
 	streams      map[uint32]*clientStream
 	nextStreamID uint32
 	bw           *bufio.Writer
 	werr         error // first write error that has occurred
 	br           *bufio.Reader
 	fr           *Framer
+	// Settings from peer:
+	maxFrameSize         uint32
+	maxConcurrentStreams uint32
+	initialWindowSize    uint32
+	hbuf                 bytes.Buffer // HPACK encoder writes into this
+	henc                 *hpack.Encoder
 }
 
 type clientStream struct {
@@ -96,7 +100,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 			return nil, err
 		}
 		res, err := cc.roundTrip(req)
-		if isShutdownError(err) {
+		if isShutdownError(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
 			continue
 		}
 		if err != nil {
@@ -106,13 +110,69 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 }
 
+// CloseIdleConnections closes any connections which were previously
+// connected from previous requests but are now sitting idle.
+// It does not interrupt any connections currently in use.
+func (t *Transport) CloseIdleConnections() {
+	t.connMu.Lock()
+	defer t.connMu.Unlock()
+	for _, vv := range t.conns {
+		for _, cc := range vv {
+			cc.closeIfIdle()
+		}
+	}
+}
+
 func isShutdownError(err error) bool {
 	// TODO: implement
 	return false
 }
 
+func (t *Transport) removeClientConn(cc *clientConn) {
+	t.connMu.Lock()
+	defer t.connMu.Unlock()
+	for _, key := range cc.connKey {
+		vv, ok := t.conns[key]
+		if !ok {
+			continue
+		}
+		t.conns[key] = filterOutClientConn(vv, cc)
+	}
+}
+
+func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
+	out := in[:0]
+	for _, v := range in {
+		if v != exclude {
+			out = append(out, v)
+		}
+	}
+	return out
+}
+
 func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
-	// TODO: cache these
+	t.connMu.Lock()
+	defer t.connMu.Unlock()
+
+	key := net.JoinHostPort(host, port)
+
+	for _, cc := range t.conns[key] {
+		if cc.canTakeNewRequest() {
+			return cc, nil
+		}
+	}
+	if t.conns == nil {
+		t.conns = make(map[string][]*clientConn)
+	}
+	cc, err := t.newClientConn(host, port, key)
+	if err != nil {
+		return nil, err
+	}
+	t.conns[key] = append(t.conns[key], cc)
+	return cc, nil
+}
+
+func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
 	cfg := &tls.Config{
 		ServerName:         host,
 		NextProtos:         []string{NextProtoTLS},
@@ -143,11 +203,16 @@ func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
 	}
 
 	cc := &clientConn{
-		tconn:        tconn,
-		tlsState:     &state,
-		readerDone:   make(chan struct{}),
-		nextStreamID: 1,
-		streams:      make(map[uint32]*clientStream),
+		t:                    t,
+		tconn:                tconn,
+		connKey:              []string{key}, // TODO: cert's validated hostnames too
+		tlsState:             &state,
+		readerDone:           make(chan struct{}),
+		nextStreamID:         1,
+		maxFrameSize:         16 << 10, // spec default
+		initialWindowSize:    65535,    // spec default
+		maxConcurrentStreams: 1000,     // "infinite", per spec. 1000 seems good enough.
+		streams:              make(map[uint32]*clientStream),
 	}
 	cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
 	cc.br = bufio.NewReader(tconn)
@@ -178,8 +243,12 @@ func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
 		switch s.ID {
 		case SettingMaxFrameSize:
 			cc.maxFrameSize = s.Val
-		// TODO(bradfitz): handle the others
+		case SettingMaxConcurrentStreams:
+			cc.maxConcurrentStreams = s.Val
+		case SettingInitialWindowSize:
+			cc.initialWindowSize = s.Val
 		default:
+			// TODO(bradfitz): handle more
 			log.Printf("Unhandled Setting: %v", s)
 		}
 		return nil
@@ -191,6 +260,26 @@ func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
 	return cc, nil
 }
 
+func (cc *clientConn) setGoAway(f *GoAwayFrame) {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	cc.goAway = f
+}
+
+func (cc *clientConn) canTakeNewRequest() bool {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	return cc.goAway == nil && int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams)
+}
+
+func (cc *clientConn) closeIfIdle() {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	if len(cc.streams) > 0 {
+		return
+	}
+}
+
 func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
 	cc.mu.Lock()
 
@@ -223,6 +312,11 @@ func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
 	werr := cc.werr
 	cc.mu.Unlock()
 
+	if hasBody {
+		// TODO: write data. and it should probably be interleaved:
+		//   go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
+	}
+
 	if werr != nil {
 		return nil, werr
 	}
@@ -290,14 +384,19 @@ func (cc *clientConn) newStream() *clientStream {
 	return cs
 }
 
-func (cc *clientConn) streamByID(id uint32) *clientStream {
+func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
-	return cc.streams[id]
+	cs := cc.streams[id]
+	if andRemove {
+		delete(cc.streams, id)
+	}
+	return cs
 }
 
 // runs in its own goroutine.
 func (cc *clientConn) readLoop() {
+	defer cc.t.removeClientConn(cc)
 	defer close(cc.readerDone)
 
 	activeRes := map[uint32]*clientStream{} // keyed by streamID
@@ -314,6 +413,8 @@ func (cc *clientConn) readLoop() {
 		}
 	}()
 
+	defer println("Transport readLoop returning")
+
 	// continueStreamID is the stream ID we're waiting for
 	// continuation frames for.
 	var continueStreamID uint32
@@ -331,12 +432,14 @@ func (cc *clientConn) readLoop() {
 		_, isContinue := f.(*ContinuationFrame)
 		if isContinue {
 			if streamID != continueStreamID {
+				log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
 				cc.readerErr = ConnectionError(ErrCodeProtocol)
 				return
 			}
 		} else if continueStreamID != 0 {
 			// Continue frames need to be adjacent in the stream
 			// and we were in the middle of headers.
+			log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
 			cc.readerErr = ConnectionError(ErrCodeProtocol)
 			return
 		}
@@ -346,17 +449,17 @@ func (cc *clientConn) readLoop() {
 			// These always have an even stream id.
 			continue
 		}
-		cs := cc.streamByID(streamID)
+		streamEnded := false
+		if ff, ok := f.(streamEnder); ok {
+			streamEnded = ff.StreamEnded()
+		}
+
+		cs := cc.streamByID(streamID, streamEnded)
 		if cs == nil {
 			log.Printf("Received frame for untracked stream ID %d", streamID)
 			continue
 		}
 
-		headersEnded := false
-		streamEnded := false
-		if ff, ok := f.(streamEnder); ok {
-			streamEnded = ff.StreamEnded()
-		}
 		switch f := f.(type) {
 		case *HeadersFrame:
 			cc.nextRes = &http.Response{
@@ -366,19 +469,29 @@ func (cc *clientConn) readLoop() {
 			}
 			cs.pr, cs.pw = io.Pipe()
 			cc.hdec.Write(f.HeaderBlockFragment())
-			headersEnded = f.HeadersEnded()
 		case *ContinuationFrame:
 			cc.hdec.Write(f.HeaderBlockFragment())
-			headersEnded = f.HeadersEnded()
 		case *DataFrame:
 			log.Printf("DATA: %q", f.Data())
 			cs.pw.Write(f.Data())
+		case *GoAwayFrame:
+			cc.t.removeClientConn(cc)
+			if f.ErrCode != 0 {
+				// TODO: deal with GOAWAY more. particularly the error code
+				log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
+			}
+			cc.setGoAway(f)
 		default:
+			log.Printf("Transport: unhandled response frame type %T", f)
 		}
-		if headersEnded {
-			continueStreamID = 0
-		} else {
-			continueStreamID = streamID
+		headersEnded := false
+		if he, ok := f.(headersEnder); ok {
+			headersEnded = he.HeadersEnded()
+			if headersEnded {
+				continueStreamID = 0
+			} else {
+				continueStreamID = streamID
+			}
 		}
 
 		if streamEnded {
@@ -389,6 +502,9 @@ func (cc *clientConn) readLoop() {
 			if cs == nil {
 				panic("couldn't find stream") // TODO be graceful
 			}
+			// TODO: set the Body to one which notes the
+			// Close and also sends the server a
+			// RST_STREAM
 			cc.nextRes.Body = cs.pr
 			res := cc.nextRes
 			activeRes[streamID] = cs

+ 38 - 3
transport_test.go

@@ -12,6 +12,7 @@ import (
 	"net/http"
 	"os"
 	"reflect"
+	"strings"
 	"testing"
 )
 
@@ -43,9 +44,9 @@ func TestTransport(t *testing.T) {
 	})
 	defer st.Close()
 
-	tr := &Transport{
-		InsecureTLSDial: true,
-	}
+	tr := &Transport{InsecureTLSDial: true}
+	defer tr.CloseIdleConnections()
+
 	req, err := http.NewRequest("GET", st.ts.URL, nil)
 	if err != nil {
 		t.Fatal(err)
@@ -84,3 +85,37 @@ func TestTransport(t *testing.T) {
 	}
 
 }
+
+func TestTransportReusesConns(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, r.RemoteAddr)
+	}, optOnlyServer)
+	defer st.Close()
+	tr := &Transport{InsecureTLSDial: true}
+	defer tr.CloseIdleConnections()
+	get := func() string {
+		req, err := http.NewRequest("GET", st.ts.URL, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		res, err := tr.RoundTrip(req)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer res.Body.Close()
+		slurp, err := ioutil.ReadAll(res.Body)
+		if err != nil {
+			t.Fatalf("Body read: %v", err)
+		}
+		addr := strings.TrimSpace(string(slurp))
+		if addr == "" {
+			t.Fatalf("didn't get an addr in response")
+		}
+		return addr
+	}
+	first := get()
+	second := get()
+	if first != second {
+		t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
+	}
+}