transport.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. // Copyright 2015 The Go Authors.
  2. // See https://go.googlesource.com/go/+/master/CONTRIBUTORS
  3. // Licensed under the same terms as Go itself:
  4. // https://go.googlesource.com/go/+/master/LICENSE
  5. package http2
  6. import (
  7. "bufio"
  8. "bytes"
  9. "crypto/tls"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "net"
  15. "net/http"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "github.com/bradfitz/http2/hpack"
  20. )
// Transport sends "https" requests over connections it dials itself (see
// RoundTrip) and keeps those connections in conns for reuse.
type Transport struct {
	// Fallback, if non-nil, is used by RoundTrip for requests whose URL
	// scheme is not "https". When it is nil such requests fail.
	Fallback http.RoundTripper
	// TODO: remove this and make more general with a TLS dial hook, like http
	// InsecureTLSDial, when true, sets InsecureSkipVerify on the TLS config
	// and skips the explicit hostname verification after the handshake.
	InsecureTLSDial bool
	// connMu is held by every method that reads or modifies conns.
	connMu sync.Mutex
	conns map[string][]*clientConn // key is host:port
}
// clientConn holds the state of one TLS connection created by
// Transport.newClientConn and cached in Transport.conns.
type clientConn struct {
	// t is the Transport that created this connection.
	t *Transport
	// tconn is the TLS connection all frames are read from and written to.
	tconn *tls.Conn
	// tlsState is the connection state captured after the handshake; it is
	// assigned to the TLS field of every response.
	tlsState *tls.ConnectionState
	connKey []string // key(s) this connection is cached in, in t.conns
	readerDone chan struct{} // closed on error
	readerErr error // set before readerDone is closed
	// hdec decodes response header blocks and calls onNewHeaderField for
	// each decoded field.
	hdec *hpack.Decoder
	// nextRes is the response readLoop is currently assembling; it is
	// filled in by onNewHeaderField.
	nextRes *http.Response
	// mu is held while accessing closed, goAway, streams and nextStreamID,
	// and by roundTrip while it encodes and writes a request's frames.
	mu sync.Mutex
	// closed is set by closeIfIdle; roundTrip refuses closed connections.
	closed bool
	goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
	// streams maps stream ID to the stream registered by newStream.
	streams map[uint32]*clientStream
	// nextStreamID starts at 1 and is advanced by 2 for every new stream.
	nextStreamID uint32
	// bw buffers writes to tconn through a stickyErrWriter.
	bw *bufio.Writer
	werr error // first write error that has occurred
	// br buffers reads from tconn.
	br *bufio.Reader
	// fr reads frames from br and writes frames to bw.
	fr *Framer
	// Settings from peer:
	// (initialized to defaults in newClientConn, then overwritten from the
	// first SETTINGS frame the server sends)
	maxFrameSize uint32
	maxConcurrentStreams uint32
	initialWindowSize uint32
	hbuf bytes.Buffer // HPACK encoder writes into this
	// henc encodes request header fields into hbuf.
	henc *hpack.Encoder
}
// clientStream is the per-request state of one stream on a clientConn.
type clientStream struct {
	// ID is the stream identifier assigned by newStream.
	ID uint32
	// resc receives the response (or error) for this stream; newStream
	// creates it with a buffer of one and roundTrip receives from it.
	resc chan resAndError
	// pw and pr are the two ends of the pipe readLoop creates when the
	// response HEADERS frame arrives: DATA payloads are written to pw and
	// pr is handed out as the response body.
	pw *io.PipeWriter
	pr *io.PipeReader
}
  59. type stickyErrWriter struct {
  60. w io.Writer
  61. err *error
  62. }
  63. func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
  64. if *sew.err != nil {
  65. return 0, *sew.err
  66. }
  67. n, err = sew.w.Write(p)
  68. *sew.err = err
  69. return
  70. }
  71. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  72. if req.URL.Scheme != "https" {
  73. if t.Fallback == nil {
  74. return nil, errors.New("http2: unsupported scheme and no Fallback")
  75. }
  76. return t.Fallback.RoundTrip(req)
  77. }
  78. host, port, err := net.SplitHostPort(req.URL.Host)
  79. if err != nil {
  80. host = req.URL.Host
  81. port = "443"
  82. }
  83. for {
  84. cc, err := t.getClientConn(host, port)
  85. if err != nil {
  86. return nil, err
  87. }
  88. res, err := cc.roundTrip(req)
  89. if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
  90. continue
  91. }
  92. if err != nil {
  93. return nil, err
  94. }
  95. return res, nil
  96. }
  97. }
  98. // CloseIdleConnections closes any connections which were previously
  99. // connected from previous requests but are now sitting idle.
  100. // It does not interrupt any connections currently in use.
  101. func (t *Transport) CloseIdleConnections() {
  102. t.connMu.Lock()
  103. defer t.connMu.Unlock()
  104. for _, vv := range t.conns {
  105. for _, cc := range vv {
  106. cc.closeIfIdle()
  107. }
  108. }
  109. }
  110. var errClientConnClosed = errors.New("http2: client conn is closed")
  111. func shouldRetryRequest(err error) bool {
  112. // TODO: or GOAWAY graceful shutdown stuff
  113. return err == errClientConnClosed
  114. }
  115. func (t *Transport) removeClientConn(cc *clientConn) {
  116. t.connMu.Lock()
  117. defer t.connMu.Unlock()
  118. for _, key := range cc.connKey {
  119. vv, ok := t.conns[key]
  120. if !ok {
  121. continue
  122. }
  123. newList := filterOutClientConn(vv, cc)
  124. if len(newList) > 0 {
  125. t.conns[key] = newList
  126. } else {
  127. delete(t.conns, key)
  128. }
  129. }
  130. }
  131. func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
  132. out := in[:0]
  133. for _, v := range in {
  134. if v != exclude {
  135. out = append(out, v)
  136. }
  137. }
  138. return out
  139. }
  140. func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
  141. t.connMu.Lock()
  142. defer t.connMu.Unlock()
  143. key := net.JoinHostPort(host, port)
  144. for _, cc := range t.conns[key] {
  145. if cc.canTakeNewRequest() {
  146. return cc, nil
  147. }
  148. }
  149. if t.conns == nil {
  150. t.conns = make(map[string][]*clientConn)
  151. }
  152. cc, err := t.newClientConn(host, port, key)
  153. if err != nil {
  154. return nil, err
  155. }
  156. t.conns[key] = append(t.conns[key], cc)
  157. return cc, nil
  158. }
  159. func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
  160. cfg := &tls.Config{
  161. ServerName: host,
  162. NextProtos: []string{NextProtoTLS},
  163. InsecureSkipVerify: t.InsecureTLSDial,
  164. }
  165. tconn, err := tls.Dial("tcp", host+":"+port, cfg)
  166. if err != nil {
  167. return nil, err
  168. }
  169. if err := tconn.Handshake(); err != nil {
  170. return nil, err
  171. }
  172. if !t.InsecureTLSDial {
  173. if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
  174. return nil, err
  175. }
  176. }
  177. state := tconn.ConnectionState()
  178. if p := state.NegotiatedProtocol; p != NextProtoTLS {
  179. // TODO(bradfitz): fall back to Fallback
  180. return nil, fmt.Errorf("bad protocol: %v", p)
  181. }
  182. if !state.NegotiatedProtocolIsMutual {
  183. return nil, errors.New("could not negotiate protocol mutually")
  184. }
  185. if _, err := tconn.Write(clientPreface); err != nil {
  186. return nil, err
  187. }
  188. cc := &clientConn{
  189. t: t,
  190. tconn: tconn,
  191. connKey: []string{key}, // TODO: cert's validated hostnames too
  192. tlsState: &state,
  193. readerDone: make(chan struct{}),
  194. nextStreamID: 1,
  195. maxFrameSize: 16 << 10, // spec default
  196. initialWindowSize: 65535, // spec default
  197. maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
  198. streams: make(map[uint32]*clientStream),
  199. }
  200. cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
  201. cc.br = bufio.NewReader(tconn)
  202. cc.fr = NewFramer(cc.bw, cc.br)
  203. cc.henc = hpack.NewEncoder(&cc.hbuf)
  204. cc.fr.WriteSettings()
  205. // TODO: re-send more conn-level flow control tokens when server uses all these.
  206. cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
  207. cc.bw.Flush()
  208. if cc.werr != nil {
  209. return nil, cc.werr
  210. }
  211. // Read the obligatory SETTINGS frame
  212. f, err := cc.fr.ReadFrame()
  213. if err != nil {
  214. return nil, err
  215. }
  216. sf, ok := f.(*SettingsFrame)
  217. if !ok {
  218. return nil, fmt.Errorf("expected settings frame, got: %T", f)
  219. }
  220. cc.fr.WriteSettingsAck()
  221. cc.bw.Flush()
  222. sf.ForeachSetting(func(s Setting) error {
  223. switch s.ID {
  224. case SettingMaxFrameSize:
  225. cc.maxFrameSize = s.Val
  226. case SettingMaxConcurrentStreams:
  227. cc.maxConcurrentStreams = s.Val
  228. case SettingInitialWindowSize:
  229. cc.initialWindowSize = s.Val
  230. default:
  231. // TODO(bradfitz): handle more
  232. log.Printf("Unhandled Setting: %v", s)
  233. }
  234. return nil
  235. })
  236. // TODO: figure out henc size
  237. cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
  238. go cc.readLoop()
  239. return cc, nil
  240. }
  241. func (cc *clientConn) setGoAway(f *GoAwayFrame) {
  242. cc.mu.Lock()
  243. defer cc.mu.Unlock()
  244. cc.goAway = f
  245. }
  246. func (cc *clientConn) canTakeNewRequest() bool {
  247. cc.mu.Lock()
  248. defer cc.mu.Unlock()
  249. return cc.goAway == nil &&
  250. int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
  251. cc.nextStreamID < 2147483647
  252. }
  253. func (cc *clientConn) closeIfIdle() {
  254. cc.mu.Lock()
  255. if len(cc.streams) > 0 {
  256. cc.mu.Unlock()
  257. return
  258. }
  259. cc.closed = true
  260. // TODO: do clients send GOAWAY too? maybe? Just Close:
  261. cc.mu.Unlock()
  262. cc.tconn.Close()
  263. }
  264. func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
  265. cc.mu.Lock()
  266. if cc.closed {
  267. cc.mu.Unlock()
  268. return nil, errClientConnClosed
  269. }
  270. cs := cc.newStream()
  271. hasBody := false // TODO
  272. // we send: HEADERS[+CONTINUATION] + (DATA?)
  273. hdrs := cc.encodeHeaders(req)
  274. first := true
  275. for len(hdrs) > 0 {
  276. chunk := hdrs
  277. if len(chunk) > int(cc.maxFrameSize) {
  278. chunk = chunk[:cc.maxFrameSize]
  279. }
  280. hdrs = hdrs[len(chunk):]
  281. endHeaders := len(hdrs) == 0
  282. if first {
  283. cc.fr.WriteHeaders(HeadersFrameParam{
  284. StreamID: cs.ID,
  285. BlockFragment: chunk,
  286. EndStream: !hasBody,
  287. EndHeaders: endHeaders,
  288. })
  289. first = false
  290. } else {
  291. cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
  292. }
  293. }
  294. cc.bw.Flush()
  295. werr := cc.werr
  296. cc.mu.Unlock()
  297. if hasBody {
  298. // TODO: write data. and it should probably be interleaved:
  299. // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
  300. }
  301. if werr != nil {
  302. return nil, werr
  303. }
  304. re := <-cs.resc
  305. if re.err != nil {
  306. return nil, re.err
  307. }
  308. res := re.res
  309. res.Request = req
  310. res.TLS = cc.tlsState
  311. return res, nil
  312. }
  313. // requires cc.mu be held.
  314. func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
  315. cc.hbuf.Reset()
  316. // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
  317. host := req.Host
  318. if host == "" {
  319. host = req.URL.Host
  320. }
  321. path := req.URL.Path
  322. if path == "" {
  323. path = "/"
  324. }
  325. cc.writeHeader(":authority", host) // probably not right for all sites
  326. cc.writeHeader(":method", req.Method)
  327. cc.writeHeader(":path", path)
  328. cc.writeHeader(":scheme", "https")
  329. for k, vv := range req.Header {
  330. lowKey := strings.ToLower(k)
  331. if lowKey == "host" {
  332. continue
  333. }
  334. for _, v := range vv {
  335. cc.writeHeader(lowKey, v)
  336. }
  337. }
  338. return cc.hbuf.Bytes()
  339. }
  340. func (cc *clientConn) writeHeader(name, value string) {
  341. log.Printf("sending %q = %q", name, value)
  342. cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
  343. }
// resAndError is the value sent on a clientStream's resc channel: a
// response, or the error that prevented one.
type resAndError struct {
	res *http.Response
	err error
}
  348. // requires cc.mu be held.
  349. func (cc *clientConn) newStream() *clientStream {
  350. cs := &clientStream{
  351. ID: cc.nextStreamID,
  352. resc: make(chan resAndError, 1),
  353. }
  354. cc.nextStreamID += 2
  355. cc.streams[cs.ID] = cs
  356. return cs
  357. }
  358. func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
  359. cc.mu.Lock()
  360. defer cc.mu.Unlock()
  361. cs := cc.streams[id]
  362. if andRemove {
  363. delete(cc.streams, id)
  364. }
  365. return cs
  366. }
  367. // runs in its own goroutine.
  368. func (cc *clientConn) readLoop() {
  369. defer cc.t.removeClientConn(cc)
  370. defer close(cc.readerDone)
  371. activeRes := map[uint32]*clientStream{} // keyed by streamID
  372. // Close any response bodies if the server closes prematurely.
  373. // TODO: also do this if we've written the headers but not
  374. // gotten a response yet.
  375. defer func() {
  376. err := cc.readerErr
  377. if err == io.EOF {
  378. err = io.ErrUnexpectedEOF
  379. }
  380. for _, cs := range activeRes {
  381. cs.pw.CloseWithError(err)
  382. }
  383. }()
  384. // continueStreamID is the stream ID we're waiting for
  385. // continuation frames for.
  386. var continueStreamID uint32
  387. for {
  388. f, err := cc.fr.ReadFrame()
  389. if err != nil {
  390. cc.readerErr = err
  391. return
  392. }
  393. log.Printf("Transport received %v: %#v", f.Header(), f)
  394. streamID := f.Header().StreamID
  395. _, isContinue := f.(*ContinuationFrame)
  396. if isContinue {
  397. if streamID != continueStreamID {
  398. log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
  399. cc.readerErr = ConnectionError(ErrCodeProtocol)
  400. return
  401. }
  402. } else if continueStreamID != 0 {
  403. // Continue frames need to be adjacent in the stream
  404. // and we were in the middle of headers.
  405. log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
  406. cc.readerErr = ConnectionError(ErrCodeProtocol)
  407. return
  408. }
  409. if streamID%2 == 0 {
  410. // Ignore streams pushed from the server for now.
  411. // These always have an even stream id.
  412. continue
  413. }
  414. streamEnded := false
  415. if ff, ok := f.(streamEnder); ok {
  416. streamEnded = ff.StreamEnded()
  417. }
  418. cs := cc.streamByID(streamID, streamEnded)
  419. if cs == nil {
  420. log.Printf("Received frame for untracked stream ID %d", streamID)
  421. continue
  422. }
  423. switch f := f.(type) {
  424. case *HeadersFrame:
  425. cc.nextRes = &http.Response{
  426. Proto: "HTTP/2.0",
  427. ProtoMajor: 2,
  428. Header: make(http.Header),
  429. }
  430. cs.pr, cs.pw = io.Pipe()
  431. cc.hdec.Write(f.HeaderBlockFragment())
  432. case *ContinuationFrame:
  433. cc.hdec.Write(f.HeaderBlockFragment())
  434. case *DataFrame:
  435. log.Printf("DATA: %q", f.Data())
  436. cs.pw.Write(f.Data())
  437. case *GoAwayFrame:
  438. cc.t.removeClientConn(cc)
  439. if f.ErrCode != 0 {
  440. // TODO: deal with GOAWAY more. particularly the error code
  441. log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
  442. }
  443. cc.setGoAway(f)
  444. default:
  445. log.Printf("Transport: unhandled response frame type %T", f)
  446. }
  447. headersEnded := false
  448. if he, ok := f.(headersEnder); ok {
  449. headersEnded = he.HeadersEnded()
  450. if headersEnded {
  451. continueStreamID = 0
  452. } else {
  453. continueStreamID = streamID
  454. }
  455. }
  456. if streamEnded {
  457. cs.pw.Close()
  458. delete(activeRes, streamID)
  459. }
  460. if headersEnded {
  461. if cs == nil {
  462. panic("couldn't find stream") // TODO be graceful
  463. }
  464. // TODO: set the Body to one which notes the
  465. // Close and also sends the server a
  466. // RST_STREAM
  467. cc.nextRes.Body = cs.pr
  468. res := cc.nextRes
  469. activeRes[streamID] = cs
  470. cs.resc <- resAndError{res: res}
  471. }
  472. }
  473. }
  474. func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
  475. // TODO: verifiy pseudo headers come before non-pseudo headers
  476. // TODO: verifiy the status is set
  477. log.Printf("Header field: %+v", f)
  478. if f.Name == ":status" {
  479. code, err := strconv.Atoi(f.Value)
  480. if err != nil {
  481. panic("TODO: be graceful")
  482. }
  483. cc.nextRes.Status = f.Value + " " + http.StatusText(code)
  484. cc.nextRes.StatusCode = code
  485. return
  486. }
  487. if strings.HasPrefix(f.Name, ":") {
  488. // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
  489. // TODO: treat as invalid?
  490. return
  491. }
  492. cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
  493. }