| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535 |
- // Copyright 2015 The Go Authors.
- // See https://go.googlesource.com/go/+/master/CONTRIBUTORS
- // Licensed under the same terms as Go itself:
- // https://go.googlesource.com/go/+/master/LICENSE
- package http2
- import (
- "bufio"
- "bytes"
- "crypto/tls"
- "errors"
- "fmt"
- "io"
- "log"
- "net"
- "net/http"
- "strconv"
- "strings"
- "sync"
- "github.com/bradfitz/http2/hpack"
- )
// Transport is an http.RoundTripper that sends "https" requests over
// HTTP/2, keeping the resulting connections in a pool keyed by host:port
// for reuse. Requests with any other URL scheme are handed to Fallback.
type Transport struct {
	// Fallback, if non-nil, handles requests whose URL scheme is not
	// "https". When it is nil such requests fail with an error.
	Fallback http.RoundTripper

	// InsecureTLSDial disables certificate verification and the explicit
	// hostname check when dialing servers.
	// TODO: remove this and make more general with a TLS dial hook, like http
	InsecureTLSDial bool

	// connMu guards conns.
	connMu sync.Mutex
	conns  map[string][]*clientConn // key is host:port
}
// clientConn is one HTTP/2 connection from a Transport to a server.
// Requests are written by roundTrip; frames from the server are read and
// dispatched by the readLoop goroutine that newClientConn starts.
type clientConn struct {
	t          *Transport           // Transport whose pool (t.conns) holds this conn
	tconn      *tls.Conn            // underlying TLS connection
	tlsState   *tls.ConnectionState // handshake state, copied into each Response.TLS
	connKey    []string             // key(s) this connection is cached in, in t.conns
	readerDone chan struct{}        // closed on error (when readLoop exits)
	readerErr  error                // set before readerDone is closed

	// hdec decodes response header blocks, passing each field to
	// onNewHeaderField, which fills in nextRes: the response whose
	// headers are currently being decoded. Both are used from readLoop.
	hdec    *hpack.Decoder
	nextRes *http.Response

	// mu guards goAway, streams and nextStreamID, and is held while a
	// request's frames are encoded and written.
	mu           sync.Mutex
	goAway       *GoAwayFrame             // if non-nil, the GoAwayFrame we received
	streams      map[uint32]*clientStream // open streams, keyed by stream ID
	nextStreamID uint32                   // ID newStream hands out next
	bw           *bufio.Writer
	werr         error // first write error that has occurred
	br           *bufio.Reader
	fr           *Framer

	// Settings from peer:
	maxFrameSize         uint32
	maxConcurrentStreams uint32
	initialWindowSize    uint32

	hbuf bytes.Buffer // HPACK encoder writes into this
	henc *hpack.Encoder
}
// clientStream is the client side of one HTTP/2 request/response stream.
type clientStream struct {
	ID   uint32           // stream identifier, allocated by newStream
	resc chan resAndError // buffered (cap 1); readLoop sends the response here once its headers are complete

	// pw/pr carry the response body: readLoop creates the pipe when the
	// HEADERS frame arrives and writes DATA payloads into pw; pr becomes
	// the Response.Body.
	pw *io.PipeWriter
	pr *io.PipeReader
}
// stickyErrWriter wraps w and records the outcome of each write in *err.
// Once *err is non-nil, further writes return it immediately without
// reaching w, so a caller can issue many writes and check a single error
// afterwards. newClientConn points err at clientConn.werr.
type stickyErrWriter struct {
	w   io.Writer
	err *error // shared error slot; the first write error sticks here
}
- func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
- if *sew.err != nil {
- return 0, *sew.err
- }
- n, err = sew.w.Write(p)
- *sew.err = err
- return
- }
- func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
- if req.URL.Scheme != "https" {
- if t.Fallback == nil {
- return nil, errors.New("http2: unsupported scheme and no Fallback")
- }
- return t.Fallback.RoundTrip(req)
- }
- host, port, err := net.SplitHostPort(req.URL.Host)
- if err != nil {
- host = req.URL.Host
- port = "443"
- }
- for {
- cc, err := t.getClientConn(host, port)
- if err != nil {
- return nil, err
- }
- res, err := cc.roundTrip(req)
- if isShutdownError(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
- continue
- }
- if err != nil {
- return nil, err
- }
- return res, nil
- }
- }
- // CloseIdleConnections closes any connections which were previously
- // connected from previous requests but are now sitting idle.
- // It does not interrupt any connections currently in use.
- func (t *Transport) CloseIdleConnections() {
- t.connMu.Lock()
- defer t.connMu.Unlock()
- for _, vv := range t.conns {
- for _, cc := range vv {
- cc.closeIfIdle()
- }
- }
- }
// isShutdownError reports whether err means the request should be retried
// on a different connection; RoundTrip loops and retries when it returns
// true.
//
// TODO: implement. Until then no error qualifies, so RoundTrip never
// retries.
func isShutdownError(err error) bool {
	return false
}
- func (t *Transport) removeClientConn(cc *clientConn) {
- t.connMu.Lock()
- defer t.connMu.Unlock()
- for _, key := range cc.connKey {
- vv, ok := t.conns[key]
- if !ok {
- continue
- }
- t.conns[key] = filterOutClientConn(vv, cc)
- }
- }
- func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
- out := in[:0]
- for _, v := range in {
- if v != exclude {
- out = append(out, v)
- }
- }
- return out
- }
- func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
- t.connMu.Lock()
- defer t.connMu.Unlock()
- key := net.JoinHostPort(host, port)
- for _, cc := range t.conns[key] {
- if cc.canTakeNewRequest() {
- return cc, nil
- }
- }
- if t.conns == nil {
- t.conns = make(map[string][]*clientConn)
- }
- cc, err := t.newClientConn(host, port, key)
- if err != nil {
- return nil, err
- }
- t.conns[key] = append(t.conns[key], cc)
- return cc, nil
- }
- func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
- cfg := &tls.Config{
- ServerName: host,
- NextProtos: []string{NextProtoTLS},
- InsecureSkipVerify: t.InsecureTLSDial,
- }
- tconn, err := tls.Dial("tcp", host+":"+port, cfg)
- if err != nil {
- return nil, err
- }
- if err := tconn.Handshake(); err != nil {
- return nil, err
- }
- if !t.InsecureTLSDial {
- if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
- return nil, err
- }
- }
- state := tconn.ConnectionState()
- if p := state.NegotiatedProtocol; p != NextProtoTLS {
- // TODO(bradfitz): fall back to Fallback
- return nil, fmt.Errorf("bad protocol: %v", p)
- }
- if !state.NegotiatedProtocolIsMutual {
- return nil, errors.New("could not negotiate protocol mutually")
- }
- if _, err := tconn.Write(clientPreface); err != nil {
- return nil, err
- }
- cc := &clientConn{
- t: t,
- tconn: tconn,
- connKey: []string{key}, // TODO: cert's validated hostnames too
- tlsState: &state,
- readerDone: make(chan struct{}),
- nextStreamID: 1,
- maxFrameSize: 16 << 10, // spec default
- initialWindowSize: 65535, // spec default
- maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
- streams: make(map[uint32]*clientStream),
- }
- cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
- cc.br = bufio.NewReader(tconn)
- cc.fr = NewFramer(cc.bw, cc.br)
- cc.henc = hpack.NewEncoder(&cc.hbuf)
- cc.fr.WriteSettings()
- // TODO: re-send more conn-level flow control tokens when server uses all these.
- cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
- cc.bw.Flush()
- if cc.werr != nil {
- return nil, cc.werr
- }
- // Read the obligatory SETTINGS frame
- f, err := cc.fr.ReadFrame()
- if err != nil {
- return nil, err
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- return nil, fmt.Errorf("expected settings frame, got: %T", f)
- }
- cc.fr.WriteSettingsAck()
- cc.bw.Flush()
- sf.ForeachSetting(func(s Setting) error {
- switch s.ID {
- case SettingMaxFrameSize:
- cc.maxFrameSize = s.Val
- case SettingMaxConcurrentStreams:
- cc.maxConcurrentStreams = s.Val
- case SettingInitialWindowSize:
- cc.initialWindowSize = s.Val
- default:
- // TODO(bradfitz): handle more
- log.Printf("Unhandled Setting: %v", s)
- }
- return nil
- })
- // TODO: figure out henc size
- cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
- go cc.readLoop()
- return cc, nil
- }
- func (cc *clientConn) setGoAway(f *GoAwayFrame) {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- cc.goAway = f
- }
- func (cc *clientConn) canTakeNewRequest() bool {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- return cc.goAway == nil && int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams)
- }
- func (cc *clientConn) closeIfIdle() {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- if len(cc.streams) > 0 {
- return
- }
- }
- func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
- cc.mu.Lock()
- cs := cc.newStream()
- hasBody := false // TODO
- // we send: HEADERS[+CONTINUATION] + (DATA?)
- hdrs := cc.encodeHeaders(req)
- first := true
- for len(hdrs) > 0 {
- chunk := hdrs
- if len(chunk) > int(cc.maxFrameSize) {
- chunk = chunk[:cc.maxFrameSize]
- }
- hdrs = hdrs[len(chunk):]
- endHeaders := len(hdrs) == 0
- if first {
- cc.fr.WriteHeaders(HeadersFrameParam{
- StreamID: cs.ID,
- BlockFragment: chunk,
- EndStream: !hasBody,
- EndHeaders: endHeaders,
- })
- first = false
- } else {
- cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
- }
- }
- cc.bw.Flush()
- werr := cc.werr
- cc.mu.Unlock()
- if hasBody {
- // TODO: write data. and it should probably be interleaved:
- // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
- }
- if werr != nil {
- return nil, werr
- }
- re := <-cs.resc
- if re.err != nil {
- return nil, re.err
- }
- res := re.res
- res.Request = req
- res.TLS = cc.tlsState
- return res, nil
- }
- // requires cc.mu be held.
- func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
- cc.hbuf.Reset()
- // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
- host := req.Host
- if host == "" {
- host = req.URL.Host
- }
- path := req.URL.Path
- if path == "" {
- path = "/"
- }
- cc.writeHeader(":authority", host) // probably not right for all sites
- cc.writeHeader(":method", req.Method)
- cc.writeHeader(":path", path)
- cc.writeHeader(":scheme", "https")
- for k, vv := range req.Header {
- lowKey := strings.ToLower(k)
- if lowKey == "host" {
- continue
- }
- for _, v := range vv {
- cc.writeHeader(lowKey, v)
- }
- }
- return cc.hbuf.Bytes()
- }
- func (cc *clientConn) writeHeader(name, value string) {
- log.Printf("sending %q = %q", name, value)
- cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
- }
// resAndError is the outcome of a request, sent by readLoop to the waiting
// roundTrip over clientStream.resc. roundTrip returns err when it is
// non-nil, and otherwise uses res.
type resAndError struct {
	res *http.Response
	err error
}
- // requires cc.mu be held.
- func (cc *clientConn) newStream() *clientStream {
- cs := &clientStream{
- ID: cc.nextStreamID,
- resc: make(chan resAndError, 1),
- }
- cc.nextStreamID += 2
- cc.streams[cs.ID] = cs
- return cs
- }
- func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- cs := cc.streams[id]
- if andRemove {
- delete(cc.streams, id)
- }
- return cs
- }
- // runs in its own goroutine.
- func (cc *clientConn) readLoop() {
- defer cc.t.removeClientConn(cc)
- defer close(cc.readerDone)
- activeRes := map[uint32]*clientStream{} // keyed by streamID
- // Close any response bodies if the server closes prematurely.
- // TODO: also do this if we've written the headers but not
- // gotten a response yet.
- defer func() {
- err := cc.readerErr
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
- }
- for _, cs := range activeRes {
- cs.pw.CloseWithError(err)
- }
- }()
- defer println("Transport readLoop returning")
- // continueStreamID is the stream ID we're waiting for
- // continuation frames for.
- var continueStreamID uint32
- for {
- f, err := cc.fr.ReadFrame()
- if err != nil {
- cc.readerErr = err
- return
- }
- log.Printf("Transport received %v: %#v", f.Header(), f)
- streamID := f.Header().StreamID
- _, isContinue := f.(*ContinuationFrame)
- if isContinue {
- if streamID != continueStreamID {
- log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
- cc.readerErr = ConnectionError(ErrCodeProtocol)
- return
- }
- } else if continueStreamID != 0 {
- // Continue frames need to be adjacent in the stream
- // and we were in the middle of headers.
- log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
- cc.readerErr = ConnectionError(ErrCodeProtocol)
- return
- }
- if streamID%2 == 0 {
- // Ignore streams pushed from the server for now.
- // These always have an even stream id.
- continue
- }
- streamEnded := false
- if ff, ok := f.(streamEnder); ok {
- streamEnded = ff.StreamEnded()
- }
- cs := cc.streamByID(streamID, streamEnded)
- if cs == nil {
- log.Printf("Received frame for untracked stream ID %d", streamID)
- continue
- }
- switch f := f.(type) {
- case *HeadersFrame:
- cc.nextRes = &http.Response{
- Proto: "HTTP/2.0",
- ProtoMajor: 2,
- Header: make(http.Header),
- }
- cs.pr, cs.pw = io.Pipe()
- cc.hdec.Write(f.HeaderBlockFragment())
- case *ContinuationFrame:
- cc.hdec.Write(f.HeaderBlockFragment())
- case *DataFrame:
- log.Printf("DATA: %q", f.Data())
- cs.pw.Write(f.Data())
- case *GoAwayFrame:
- cc.t.removeClientConn(cc)
- if f.ErrCode != 0 {
- // TODO: deal with GOAWAY more. particularly the error code
- log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
- }
- cc.setGoAway(f)
- default:
- log.Printf("Transport: unhandled response frame type %T", f)
- }
- headersEnded := false
- if he, ok := f.(headersEnder); ok {
- headersEnded = he.HeadersEnded()
- if headersEnded {
- continueStreamID = 0
- } else {
- continueStreamID = streamID
- }
- }
- if streamEnded {
- cs.pw.Close()
- delete(activeRes, streamID)
- }
- if headersEnded {
- if cs == nil {
- panic("couldn't find stream") // TODO be graceful
- }
- // TODO: set the Body to one which notes the
- // Close and also sends the server a
- // RST_STREAM
- cc.nextRes.Body = cs.pr
- res := cc.nextRes
- activeRes[streamID] = cs
- cs.resc <- resAndError{res: res}
- }
- }
- }
- func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
- // TODO: verifiy pseudo headers come before non-pseudo headers
- // TODO: verifiy the status is set
- log.Printf("Header field: %+v", f)
- if f.Name == ":status" {
- code, err := strconv.Atoi(f.Value)
- if err != nil {
- panic("TODO: be graceful")
- }
- cc.nextRes.Status = f.Value + " " + http.StatusText(code)
- cc.nextRes.StatusCode = code
- return
- }
- if strings.HasPrefix(f.Name, ":") {
- // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
- // TODO: treat as invalid?
- return
- }
- cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
- }
|