@@ -27,32 +27,36 @@ type Transport struct {
 
 	// TODO: remove this and make more general with a TLS dial hook, like http
 	InsecureTLSDial bool
+
+	connMu sync.Mutex
+	conns  map[string][]*clientConn // key is host:port
 }
 
 type clientConn struct {
+	t        *Transport
 	tconn    *tls.Conn
 	tlsState *tls.ConnectionState
+	connKey  []string // key(s) this connection is cached in, in t.conns
 
 	readerDone chan struct{} // closed on error
 	readerErr  error         // set before readerDone is closed
-
-	hbuf bytes.Buffer // HPACK encoder writes into this
-	henc *hpack.Encoder
-
-	hdec *hpack.Decoder
-
-	nextRes *http.Response
-
-	// Settings from peer:
-	maxFrameSize uint32
+	hdec    *hpack.Decoder
+	nextRes *http.Response
 
 	mu           sync.Mutex
+	goAway       *GoAwayFrame // if non-nil, the GoAwayFrame we received
 	streams      map[uint32]*clientStream
 	nextStreamID uint32
 	bw           *bufio.Writer
 	werr         error // first write error that has occurred
 	br           *bufio.Reader
 	fr           *Framer
+	// Settings from peer:
+	maxFrameSize         uint32
+	maxConcurrentStreams uint32
+	initialWindowSize    uint32
+	hbuf                 bytes.Buffer // HPACK encoder writes into this
+	henc                 *hpack.Encoder
 }
 
 type clientStream struct {
@@ -96,7 +100,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 			return nil, err
 		}
 		res, err := cc.roundTrip(req)
-		if isShutdownError(err) {
+		if isShutdownError(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
 			continue
 		}
 		if err != nil {
@@ -106,13 +110,69 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 }
 
+// CloseIdleConnections closes any connections which were opened by
+// previous requests but are now sitting idle.
+// It does not interrupt any connections currently in use.
+func (t *Transport) CloseIdleConnections() {
+	t.connMu.Lock()
+	defer t.connMu.Unlock()
+	for _, vv := range t.conns {
+		for _, cc := range vv {
+			cc.closeIfIdle()
+		}
+	}
+}
+
 func isShutdownError(err error) bool {
 	// TODO: implement
 	return false
 }
 
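+// removeClientConn removes cc from the Transport's connection pool.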
+func (t *Transport) removeClientConn(cc *clientConn) {
+	t.connMu.Lock()
+	defer t.connMu.Unlock()
+	for _, key := range cc.connKey {
+		vv, ok := t.conns[key]
+		if !ok {
+			continue
+		}
+		t.conns[key] = filterOutClientConn(vv, cc)
+	}
+}
+
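+// filterOutClientConn returns in with exclude removed, filtering in
+// place to reuse the backing array.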
+func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
+	out := in[:0]
+	for _, v := range in {
+		if v != exclude {
+			out = append(out, v)
+		}
+	}
+	return out
+}
+
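+// getClientConn returns a cached clientConn to host:port that can take
+// a new request, or else dials and caches a new one.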
 func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
-	// TODO: cache these
+	t.connMu.Lock()
+	defer t.connMu.Unlock()
+
+	key := net.JoinHostPort(host, port)
+
+	for _, cc := range t.conns[key] {
+		if cc.canTakeNewRequest() {
+			return cc, nil
+		}
+	}
+	if t.conns == nil {
+		t.conns = make(map[string][]*clientConn)
+	}
+	cc, err := t.newClientConn(host, port, key)
+	if err != nil {
+		return nil, err
+	}
+	t.conns[key] = append(t.conns[key], cc)
+	return cc, nil
+}
+
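+// newClientConn dials host:port and sets up a new HTTP/2 client
+// connection. key is the t.conns key the caller caches it under.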
+func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
 	cfg := &tls.Config{
 		ServerName: host,
 		NextProtos: []string{NextProtoTLS},
@@ -143,11 +203,16 @@ func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
 	}
 
 	cc := &clientConn{
-		tconn:        tconn,
-		tlsState:     &state,
-		readerDone:   make(chan struct{}),
-		nextStreamID: 1,
-		streams:      make(map[uint32]*clientStream),
+		t:                    t,
+		tconn:                tconn,
+		connKey:              []string{key}, // TODO: cert's validated hostnames too
+		tlsState:             &state,
+		readerDone:           make(chan struct{}),
+		nextStreamID:         1,
+		maxFrameSize:         16 << 10, // spec default
+		initialWindowSize:    65535,    // spec default
+		maxConcurrentStreams: 1000,     // "infinite", per spec. 1000 seems good enough.
+		streams:              make(map[uint32]*clientStream),
 	}
 	cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
 	cc.br = bufio.NewReader(tconn)
@@ -178,8 +243,12 @@ func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
 		switch s.ID {
 		case SettingMaxFrameSize:
 			cc.maxFrameSize = s.Val
-			// TODO(bradfitz): handle the others
+		case SettingMaxConcurrentStreams:
+			cc.maxConcurrentStreams = s.Val
+		case SettingInitialWindowSize:
+			cc.initialWindowSize = s.Val
 		default:
+			// TODO(bradfitz): handle more
 			log.Printf("Unhandled Setting: %v", s)
 		}
 		return nil
@@ -191,6 +260,26 @@ func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
 	return cc, nil
 }
 
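+// setGoAway records a GOAWAY frame received from the peer; the
+// connection is then no longer used for new requests.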
+func (cc *clientConn) setGoAway(f *GoAwayFrame) {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	cc.goAway = f
+}
+
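+// canTakeNewRequest reports whether the connection can accept a new
+// request: no GOAWAY has been received and the number of concurrent
+// streams is still below the peer's advertised limit.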
+func (cc *clientConn) canTakeNewRequest() bool {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	return cc.goAway == nil && int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams)
+}
+
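+// closeIfIdle closes cc if it currently has no active streams.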
+func (cc *clientConn) closeIfIdle() {
+	cc.mu.Lock()
+	if len(cc.streams) > 0 {
+		cc.mu.Unlock()
+		return
+	}
+	cc.mu.Unlock()
+	// TODO: send a GOAWAY before closing? For now just close.
+	cc.tconn.Close()
+}
+
 func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
 	cc.mu.Lock()
 
@@ -223,6 +312,11 @@ func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
 	werr := cc.werr
 	cc.mu.Unlock()
 
+	if hasBody {
+		// TODO: write data. and it should probably be interleaved:
+		// go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
+	}
+
 	if werr != nil {
 		return nil, werr
 	}
@@ -290,14 +384,19 @@ func (cc *clientConn) newStream() *clientStream {
 	return cs
 }
 
-func (cc *clientConn) streamByID(id uint32) *clientStream {
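+// streamByID returns the stream with the given id, or nil. If andRemove
+// is true, the stream is also removed from cc.streams; callers pass true
+// once the stream's END_STREAM flag has been seen.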
+func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
-	return cc.streams[id]
+	cs := cc.streams[id]
+	if andRemove {
+		delete(cc.streams, id)
+	}
+	return cs
 }
 
 // runs in its own goroutine.
 func (cc *clientConn) readLoop() {
+	defer cc.t.removeClientConn(cc)
 	defer close(cc.readerDone)
 
 	activeRes := map[uint32]*clientStream{} // keyed by streamID
@@ -314,6 +413,8 @@ func (cc *clientConn) readLoop() {
 	}
 	}()
 
+	defer println("Transport readLoop returning")
+
 	// continueStreamID is the stream ID we're waiting for
 	// continuation frames for.
 	var continueStreamID uint32
@@ -331,12 +432,14 @@ func (cc *clientConn) readLoop() {
 		_, isContinue := f.(*ContinuationFrame)
 		if isContinue {
 			if streamID != continueStreamID {
+				log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
 				cc.readerErr = ConnectionError(ErrCodeProtocol)
 				return
 			}
 		} else if continueStreamID != 0 {
 			// Continue frames need to be adjacent in the stream
 			// and we were in the middle of headers.
+			log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
 			cc.readerErr = ConnectionError(ErrCodeProtocol)
 			return
 		}
@@ -346,17 +449,17 @@ func (cc *clientConn) readLoop() {
 			// These always have an even stream id.
 			continue
 		}
-		cs := cc.streamByID(streamID)
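+		// If this frame ends its stream, remove the stream from the
+		// map in the same locked lookup.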
+		streamEnded := false
+		if ff, ok := f.(streamEnder); ok {
+			streamEnded = ff.StreamEnded()
+		}
+
+		cs := cc.streamByID(streamID, streamEnded)
 		if cs == nil {
 			log.Printf("Received frame for untracked stream ID %d", streamID)
 			continue
 		}
 
-		headersEnded := false
-		streamEnded := false
-		if ff, ok := f.(streamEnder); ok {
-			streamEnded = ff.StreamEnded()
-		}
 		switch f := f.(type) {
 		case *HeadersFrame:
 			cc.nextRes = &http.Response{
@@ -366,19 +469,29 @@ func (cc *clientConn) readLoop() {
 			}
 			cs.pr, cs.pw = io.Pipe()
 			cc.hdec.Write(f.HeaderBlockFragment())
-			headersEnded = f.HeadersEnded()
 		case *ContinuationFrame:
 			cc.hdec.Write(f.HeaderBlockFragment())
-			headersEnded = f.HeadersEnded()
 		case *DataFrame:
 			log.Printf("DATA: %q", f.Data())
 			cs.pw.Write(f.Data())
+		case *GoAwayFrame:
+			cc.t.removeClientConn(cc)
+			if f.ErrCode != 0 {
+				// TODO: deal with GOAWAY more. particularly the error code
+				log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
+			}
+			cc.setGoAway(f)
 		default:
+			log.Printf("Transport: unhandled response frame type %T", f)
 		}
-		if headersEnded {
-			continueStreamID = 0
-		} else {
-			continueStreamID = streamID
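+		// Track whether we're mid-header-block: until HeadersEnded,
+		// the next frame must be a CONTINUATION for this stream.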
+		headersEnded := false
+		if he, ok := f.(headersEnder); ok {
+			headersEnded = he.HeadersEnded()
+			if headersEnded {
+				continueStreamID = 0
+			} else {
+				continueStreamID = streamID
+			}
 		}
 
 		if streamEnded {
@@ -389,6 +502,9 @@ func (cc *clientConn) readLoop() {
 			if cs == nil {
 				panic("couldn't find stream") // TODO be graceful
 			}
+			// TODO: set the Body to one which notes the
+			// Close and also sends the server a
+			// RST_STREAM
 			cc.nextRes.Body = cs.pr
 			res := cc.nextRes
 			activeRes[streamID] = cs