transport.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. // Copyright 2015 The Go Authors.
  2. // See https://go.googlesource.com/go/+/master/CONTRIBUTORS
  3. // Licensed under the same terms as Go itself:
  4. // https://go.googlesource.com/go/+/master/LICENSE
  5. package http2
  6. import (
  7. "bufio"
  8. "bytes"
  9. "crypto/tls"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "net"
  15. "net/http"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "github.com/bradfitz/http2/hpack"
  20. )
  21. type Transport struct {
  22. Fallback http.RoundTripper
  23. // TODO: remove this and make more general with a TLS dial hook, like http
  24. InsecureTLSDial bool
  25. connMu sync.Mutex
  26. conns map[string][]*clientConn // key is host:port
  27. }
  28. type clientConn struct {
  29. t *Transport
  30. tconn *tls.Conn
  31. tlsState *tls.ConnectionState
  32. connKey []string // key(s) this connection is cached in, in t.conns
  33. readerDone chan struct{} // closed on error
  34. readerErr error // set before readerDone is closed
  35. hdec *hpack.Decoder
  36. nextRes *http.Response
  37. mu sync.Mutex
  38. goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
  39. streams map[uint32]*clientStream
  40. nextStreamID uint32
  41. bw *bufio.Writer
  42. werr error // first write error that has occurred
  43. br *bufio.Reader
  44. fr *Framer
  45. // Settings from peer:
  46. maxFrameSize uint32
  47. maxConcurrentStreams uint32
  48. initialWindowSize uint32
  49. hbuf bytes.Buffer // HPACK encoder writes into this
  50. henc *hpack.Encoder
  51. }
  52. type clientStream struct {
  53. ID uint32
  54. resc chan resAndError
  55. pw *io.PipeWriter
  56. pr *io.PipeReader
  57. }
  58. type stickyErrWriter struct {
  59. w io.Writer
  60. err *error
  61. }
  62. func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
  63. if *sew.err != nil {
  64. return 0, *sew.err
  65. }
  66. n, err = sew.w.Write(p)
  67. *sew.err = err
  68. return
  69. }
  70. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  71. if req.URL.Scheme != "https" {
  72. if t.Fallback == nil {
  73. return nil, errors.New("http2: unsupported scheme and no Fallback")
  74. }
  75. return t.Fallback.RoundTrip(req)
  76. }
  77. host, port, err := net.SplitHostPort(req.URL.Host)
  78. if err != nil {
  79. host = req.URL.Host
  80. port = "443"
  81. }
  82. for {
  83. cc, err := t.getClientConn(host, port)
  84. if err != nil {
  85. return nil, err
  86. }
  87. res, err := cc.roundTrip(req)
  88. if isShutdownError(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
  89. continue
  90. }
  91. if err != nil {
  92. return nil, err
  93. }
  94. return res, nil
  95. }
  96. }
  97. // CloseIdleConnections closes any connections which were previously
  98. // connected from previous requests but are now sitting idle.
  99. // It does not interrupt any connections currently in use.
  100. func (t *Transport) CloseIdleConnections() {
  101. t.connMu.Lock()
  102. defer t.connMu.Unlock()
  103. for _, vv := range t.conns {
  104. for _, cc := range vv {
  105. cc.closeIfIdle()
  106. }
  107. }
  108. }
  109. func isShutdownError(err error) bool {
  110. // TODO: implement
  111. return false
  112. }
  113. func (t *Transport) removeClientConn(cc *clientConn) {
  114. t.connMu.Lock()
  115. defer t.connMu.Unlock()
  116. for _, key := range cc.connKey {
  117. vv, ok := t.conns[key]
  118. if !ok {
  119. continue
  120. }
  121. t.conns[key] = filterOutClientConn(vv, cc)
  122. }
  123. }
  124. func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
  125. out := in[:0]
  126. for _, v := range in {
  127. if v != exclude {
  128. out = append(out, v)
  129. }
  130. }
  131. return out
  132. }
  133. func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
  134. t.connMu.Lock()
  135. defer t.connMu.Unlock()
  136. key := net.JoinHostPort(host, port)
  137. for _, cc := range t.conns[key] {
  138. if cc.canTakeNewRequest() {
  139. return cc, nil
  140. }
  141. }
  142. if t.conns == nil {
  143. t.conns = make(map[string][]*clientConn)
  144. }
  145. cc, err := t.newClientConn(host, port, key)
  146. if err != nil {
  147. return nil, err
  148. }
  149. t.conns[key] = append(t.conns[key], cc)
  150. return cc, nil
  151. }
  152. func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
  153. cfg := &tls.Config{
  154. ServerName: host,
  155. NextProtos: []string{NextProtoTLS},
  156. InsecureSkipVerify: t.InsecureTLSDial,
  157. }
  158. tconn, err := tls.Dial("tcp", host+":"+port, cfg)
  159. if err != nil {
  160. return nil, err
  161. }
  162. if err := tconn.Handshake(); err != nil {
  163. return nil, err
  164. }
  165. if !t.InsecureTLSDial {
  166. if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
  167. return nil, err
  168. }
  169. }
  170. state := tconn.ConnectionState()
  171. if p := state.NegotiatedProtocol; p != NextProtoTLS {
  172. // TODO(bradfitz): fall back to Fallback
  173. return nil, fmt.Errorf("bad protocol: %v", p)
  174. }
  175. if !state.NegotiatedProtocolIsMutual {
  176. return nil, errors.New("could not negotiate protocol mutually")
  177. }
  178. if _, err := tconn.Write(clientPreface); err != nil {
  179. return nil, err
  180. }
  181. cc := &clientConn{
  182. t: t,
  183. tconn: tconn,
  184. connKey: []string{key}, // TODO: cert's validated hostnames too
  185. tlsState: &state,
  186. readerDone: make(chan struct{}),
  187. nextStreamID: 1,
  188. maxFrameSize: 16 << 10, // spec default
  189. initialWindowSize: 65535, // spec default
  190. maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
  191. streams: make(map[uint32]*clientStream),
  192. }
  193. cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
  194. cc.br = bufio.NewReader(tconn)
  195. cc.fr = NewFramer(cc.bw, cc.br)
  196. cc.henc = hpack.NewEncoder(&cc.hbuf)
  197. cc.fr.WriteSettings()
  198. // TODO: re-send more conn-level flow control tokens when server uses all these.
  199. cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
  200. cc.bw.Flush()
  201. if cc.werr != nil {
  202. return nil, cc.werr
  203. }
  204. // Read the obligatory SETTINGS frame
  205. f, err := cc.fr.ReadFrame()
  206. if err != nil {
  207. return nil, err
  208. }
  209. sf, ok := f.(*SettingsFrame)
  210. if !ok {
  211. return nil, fmt.Errorf("expected settings frame, got: %T", f)
  212. }
  213. cc.fr.WriteSettingsAck()
  214. cc.bw.Flush()
  215. sf.ForeachSetting(func(s Setting) error {
  216. switch s.ID {
  217. case SettingMaxFrameSize:
  218. cc.maxFrameSize = s.Val
  219. case SettingMaxConcurrentStreams:
  220. cc.maxConcurrentStreams = s.Val
  221. case SettingInitialWindowSize:
  222. cc.initialWindowSize = s.Val
  223. default:
  224. // TODO(bradfitz): handle more
  225. log.Printf("Unhandled Setting: %v", s)
  226. }
  227. return nil
  228. })
  229. // TODO: figure out henc size
  230. cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
  231. go cc.readLoop()
  232. return cc, nil
  233. }
  234. func (cc *clientConn) setGoAway(f *GoAwayFrame) {
  235. cc.mu.Lock()
  236. defer cc.mu.Unlock()
  237. cc.goAway = f
  238. }
  239. func (cc *clientConn) canTakeNewRequest() bool {
  240. cc.mu.Lock()
  241. defer cc.mu.Unlock()
  242. return cc.goAway == nil && int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams)
  243. }
  244. func (cc *clientConn) closeIfIdle() {
  245. cc.mu.Lock()
  246. defer cc.mu.Unlock()
  247. if len(cc.streams) > 0 {
  248. return
  249. }
  250. }
  251. func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
  252. cc.mu.Lock()
  253. cs := cc.newStream()
  254. hasBody := false // TODO
  255. // we send: HEADERS[+CONTINUATION] + (DATA?)
  256. hdrs := cc.encodeHeaders(req)
  257. first := true
  258. for len(hdrs) > 0 {
  259. chunk := hdrs
  260. if len(chunk) > int(cc.maxFrameSize) {
  261. chunk = chunk[:cc.maxFrameSize]
  262. }
  263. hdrs = hdrs[len(chunk):]
  264. endHeaders := len(hdrs) == 0
  265. if first {
  266. cc.fr.WriteHeaders(HeadersFrameParam{
  267. StreamID: cs.ID,
  268. BlockFragment: chunk,
  269. EndStream: !hasBody,
  270. EndHeaders: endHeaders,
  271. })
  272. first = false
  273. } else {
  274. cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
  275. }
  276. }
  277. cc.bw.Flush()
  278. werr := cc.werr
  279. cc.mu.Unlock()
  280. if hasBody {
  281. // TODO: write data. and it should probably be interleaved:
  282. // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
  283. }
  284. if werr != nil {
  285. return nil, werr
  286. }
  287. re := <-cs.resc
  288. if re.err != nil {
  289. return nil, re.err
  290. }
  291. res := re.res
  292. res.Request = req
  293. res.TLS = cc.tlsState
  294. return res, nil
  295. }
  296. // requires cc.mu be held.
  297. func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
  298. cc.hbuf.Reset()
  299. // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
  300. host := req.Host
  301. if host == "" {
  302. host = req.URL.Host
  303. }
  304. path := req.URL.Path
  305. if path == "" {
  306. path = "/"
  307. }
  308. cc.writeHeader(":authority", host) // probably not right for all sites
  309. cc.writeHeader(":method", req.Method)
  310. cc.writeHeader(":path", path)
  311. cc.writeHeader(":scheme", "https")
  312. for k, vv := range req.Header {
  313. lowKey := strings.ToLower(k)
  314. if lowKey == "host" {
  315. continue
  316. }
  317. for _, v := range vv {
  318. cc.writeHeader(lowKey, v)
  319. }
  320. }
  321. return cc.hbuf.Bytes()
  322. }
  323. func (cc *clientConn) writeHeader(name, value string) {
  324. log.Printf("sending %q = %q", name, value)
  325. cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
  326. }
  327. type resAndError struct {
  328. res *http.Response
  329. err error
  330. }
  331. // requires cc.mu be held.
  332. func (cc *clientConn) newStream() *clientStream {
  333. cs := &clientStream{
  334. ID: cc.nextStreamID,
  335. resc: make(chan resAndError, 1),
  336. }
  337. cc.nextStreamID += 2
  338. cc.streams[cs.ID] = cs
  339. return cs
  340. }
  341. func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
  342. cc.mu.Lock()
  343. defer cc.mu.Unlock()
  344. cs := cc.streams[id]
  345. if andRemove {
  346. delete(cc.streams, id)
  347. }
  348. return cs
  349. }
  350. // runs in its own goroutine.
  351. func (cc *clientConn) readLoop() {
  352. defer cc.t.removeClientConn(cc)
  353. defer close(cc.readerDone)
  354. activeRes := map[uint32]*clientStream{} // keyed by streamID
  355. // Close any response bodies if the server closes prematurely.
  356. // TODO: also do this if we've written the headers but not
  357. // gotten a response yet.
  358. defer func() {
  359. err := cc.readerErr
  360. if err == io.EOF {
  361. err = io.ErrUnexpectedEOF
  362. }
  363. for _, cs := range activeRes {
  364. cs.pw.CloseWithError(err)
  365. }
  366. }()
  367. defer println("Transport readLoop returning")
  368. // continueStreamID is the stream ID we're waiting for
  369. // continuation frames for.
  370. var continueStreamID uint32
  371. for {
  372. f, err := cc.fr.ReadFrame()
  373. if err != nil {
  374. cc.readerErr = err
  375. return
  376. }
  377. log.Printf("Transport received %v: %#v", f.Header(), f)
  378. streamID := f.Header().StreamID
  379. _, isContinue := f.(*ContinuationFrame)
  380. if isContinue {
  381. if streamID != continueStreamID {
  382. log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
  383. cc.readerErr = ConnectionError(ErrCodeProtocol)
  384. return
  385. }
  386. } else if continueStreamID != 0 {
  387. // Continue frames need to be adjacent in the stream
  388. // and we were in the middle of headers.
  389. log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
  390. cc.readerErr = ConnectionError(ErrCodeProtocol)
  391. return
  392. }
  393. if streamID%2 == 0 {
  394. // Ignore streams pushed from the server for now.
  395. // These always have an even stream id.
  396. continue
  397. }
  398. streamEnded := false
  399. if ff, ok := f.(streamEnder); ok {
  400. streamEnded = ff.StreamEnded()
  401. }
  402. cs := cc.streamByID(streamID, streamEnded)
  403. if cs == nil {
  404. log.Printf("Received frame for untracked stream ID %d", streamID)
  405. continue
  406. }
  407. switch f := f.(type) {
  408. case *HeadersFrame:
  409. cc.nextRes = &http.Response{
  410. Proto: "HTTP/2.0",
  411. ProtoMajor: 2,
  412. Header: make(http.Header),
  413. }
  414. cs.pr, cs.pw = io.Pipe()
  415. cc.hdec.Write(f.HeaderBlockFragment())
  416. case *ContinuationFrame:
  417. cc.hdec.Write(f.HeaderBlockFragment())
  418. case *DataFrame:
  419. log.Printf("DATA: %q", f.Data())
  420. cs.pw.Write(f.Data())
  421. case *GoAwayFrame:
  422. cc.t.removeClientConn(cc)
  423. if f.ErrCode != 0 {
  424. // TODO: deal with GOAWAY more. particularly the error code
  425. log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
  426. }
  427. cc.setGoAway(f)
  428. default:
  429. log.Printf("Transport: unhandled response frame type %T", f)
  430. }
  431. headersEnded := false
  432. if he, ok := f.(headersEnder); ok {
  433. headersEnded = he.HeadersEnded()
  434. if headersEnded {
  435. continueStreamID = 0
  436. } else {
  437. continueStreamID = streamID
  438. }
  439. }
  440. if streamEnded {
  441. cs.pw.Close()
  442. delete(activeRes, streamID)
  443. }
  444. if headersEnded {
  445. if cs == nil {
  446. panic("couldn't find stream") // TODO be graceful
  447. }
  448. // TODO: set the Body to one which notes the
  449. // Close and also sends the server a
  450. // RST_STREAM
  451. cc.nextRes.Body = cs.pr
  452. res := cc.nextRes
  453. activeRes[streamID] = cs
  454. cs.resc <- resAndError{res: res}
  455. }
  456. }
  457. }
  458. func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
  459. // TODO: verifiy pseudo headers come before non-pseudo headers
  460. // TODO: verifiy the status is set
  461. log.Printf("Header field: %+v", f)
  462. if f.Name == ":status" {
  463. code, err := strconv.Atoi(f.Value)
  464. if err != nil {
  465. panic("TODO: be graceful")
  466. }
  467. cc.nextRes.Status = f.Value + " " + http.StatusText(code)
  468. cc.nextRes.StatusCode = code
  469. return
  470. }
  471. if strings.HasPrefix(f.Name, ":") {
  472. // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
  473. // TODO: treat as invalid?
  474. return
  475. }
  476. cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
  477. }