transport.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. // Copyright 2015 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package http2
  5. import (
  6. "bufio"
  7. "bytes"
  8. "crypto/tls"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "log"
  13. "net"
  14. "net/http"
  15. "strconv"
  16. "strings"
  17. "sync"
  18. "golang.org/x/net/http2/hpack"
  19. )
  20. type Transport struct {
  21. Fallback http.RoundTripper
  22. // TODO: remove this and make more general with a TLS dial hook, like http
  23. InsecureTLSDial bool
  24. connMu sync.Mutex
  25. conns map[string][]*clientConn // key is host:port
  26. }
  27. type clientConn struct {
  28. t *Transport
  29. tconn *tls.Conn
  30. tlsState *tls.ConnectionState
  31. connKey []string // key(s) this connection is cached in, in t.conns
  32. readerDone chan struct{} // closed on error
  33. readerErr error // set before readerDone is closed
  34. hdec *hpack.Decoder
  35. nextRes *http.Response
  36. mu sync.Mutex
  37. closed bool
  38. goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
  39. streams map[uint32]*clientStream
  40. nextStreamID uint32
  41. bw *bufio.Writer
  42. werr error // first write error that has occurred
  43. br *bufio.Reader
  44. fr *Framer
  45. // Settings from peer:
  46. maxFrameSize uint32
  47. maxConcurrentStreams uint32
  48. initialWindowSize uint32
  49. hbuf bytes.Buffer // HPACK encoder writes into this
  50. henc *hpack.Encoder
  51. }
  52. type clientStream struct {
  53. ID uint32
  54. resc chan resAndError
  55. pw *io.PipeWriter
  56. pr *io.PipeReader
  57. }
  // stickyErrWriter forwards writes to w until the first error, which it
// stores in *err. From then on every Write fails immediately with that
// stored error and w is not called again.
type stickyErrWriter struct {
	w   io.Writer // destination for writes
	err *error    // slot holding the first write error (nil while none)
}

func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
	if prev := *sew.err; prev != nil {
		return 0, prev
	}
	n, err = sew.w.Write(p)
	*sew.err = err
	return n, err
}
  70. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  71. if req.URL.Scheme != "https" {
  72. if t.Fallback == nil {
  73. return nil, errors.New("http2: unsupported scheme and no Fallback")
  74. }
  75. return t.Fallback.RoundTrip(req)
  76. }
  77. host, port, err := net.SplitHostPort(req.URL.Host)
  78. if err != nil {
  79. host = req.URL.Host
  80. port = "443"
  81. }
  82. for {
  83. cc, err := t.getClientConn(host, port)
  84. if err != nil {
  85. return nil, err
  86. }
  87. res, err := cc.roundTrip(req)
  88. if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
  89. continue
  90. }
  91. if err != nil {
  92. return nil, err
  93. }
  94. return res, nil
  95. }
  96. }
  97. // CloseIdleConnections closes any connections which were previously
  98. // connected from previous requests but are now sitting idle.
  99. // It does not interrupt any connections currently in use.
  100. func (t *Transport) CloseIdleConnections() {
  101. t.connMu.Lock()
  102. defer t.connMu.Unlock()
  103. for _, vv := range t.conns {
  104. for _, cc := range vv {
  105. cc.closeIfIdle()
  106. }
  107. }
  108. }
  // errClientConnClosed is returned by clientConn.roundTrip when the
// connection was already closed, before anything was written for the
// request.
var errClientConnClosed = errors.New("http2: client conn is closed")

// shouldRetryRequest reports whether RoundTrip may retry a request on
// another connection after err. Only errClientConnClosed qualifies.
func shouldRetryRequest(err error) bool {
	// TODO: or GOAWAY graceful shutdown stuff
	switch err {
	case errClientConnClosed:
		return true
	default:
		return false
	}
}
  114. func (t *Transport) removeClientConn(cc *clientConn) {
  115. t.connMu.Lock()
  116. defer t.connMu.Unlock()
  117. for _, key := range cc.connKey {
  118. vv, ok := t.conns[key]
  119. if !ok {
  120. continue
  121. }
  122. newList := filterOutClientConn(vv, cc)
  123. if len(newList) > 0 {
  124. t.conns[key] = newList
  125. } else {
  126. delete(t.conns, key)
  127. }
  128. }
  129. }
  130. func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
  131. out := in[:0]
  132. for _, v := range in {
  133. if v != exclude {
  134. out = append(out, v)
  135. }
  136. }
  137. return out
  138. }
  139. func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
  140. t.connMu.Lock()
  141. defer t.connMu.Unlock()
  142. key := net.JoinHostPort(host, port)
  143. for _, cc := range t.conns[key] {
  144. if cc.canTakeNewRequest() {
  145. return cc, nil
  146. }
  147. }
  148. if t.conns == nil {
  149. t.conns = make(map[string][]*clientConn)
  150. }
  151. cc, err := t.newClientConn(host, port, key)
  152. if err != nil {
  153. return nil, err
  154. }
  155. t.conns[key] = append(t.conns[key], cc)
  156. return cc, nil
  157. }
  158. func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
  159. cfg := &tls.Config{
  160. ServerName: host,
  161. NextProtos: []string{NextProtoTLS},
  162. InsecureSkipVerify: t.InsecureTLSDial,
  163. }
  164. tconn, err := tls.Dial("tcp", net.JoinHostPort(host, port), cfg)
  165. if err != nil {
  166. return nil, err
  167. }
  168. if err := tconn.Handshake(); err != nil {
  169. return nil, err
  170. }
  171. if !t.InsecureTLSDial {
  172. if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
  173. return nil, err
  174. }
  175. }
  176. state := tconn.ConnectionState()
  177. if p := state.NegotiatedProtocol; p != NextProtoTLS {
  178. // TODO(bradfitz): fall back to Fallback
  179. return nil, fmt.Errorf("bad protocol: %v", p)
  180. }
  181. if !state.NegotiatedProtocolIsMutual {
  182. return nil, errors.New("could not negotiate protocol mutually")
  183. }
  184. if _, err := tconn.Write(clientPreface); err != nil {
  185. return nil, err
  186. }
  187. cc := &clientConn{
  188. t: t,
  189. tconn: tconn,
  190. connKey: []string{key}, // TODO: cert's validated hostnames too
  191. tlsState: &state,
  192. readerDone: make(chan struct{}),
  193. nextStreamID: 1,
  194. maxFrameSize: 16 << 10, // spec default
  195. initialWindowSize: 65535, // spec default
  196. maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
  197. streams: make(map[uint32]*clientStream),
  198. }
  199. cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
  200. cc.br = bufio.NewReader(tconn)
  201. cc.fr = NewFramer(cc.bw, cc.br)
  202. cc.henc = hpack.NewEncoder(&cc.hbuf)
  203. cc.fr.WriteSettings()
  204. // TODO: re-send more conn-level flow control tokens when server uses all these.
  205. cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
  206. cc.bw.Flush()
  207. if cc.werr != nil {
  208. return nil, cc.werr
  209. }
  210. // Read the obligatory SETTINGS frame
  211. f, err := cc.fr.ReadFrame()
  212. if err != nil {
  213. return nil, err
  214. }
  215. sf, ok := f.(*SettingsFrame)
  216. if !ok {
  217. return nil, fmt.Errorf("expected settings frame, got: %T", f)
  218. }
  219. cc.fr.WriteSettingsAck()
  220. cc.bw.Flush()
  221. sf.ForeachSetting(func(s Setting) error {
  222. switch s.ID {
  223. case SettingMaxFrameSize:
  224. cc.maxFrameSize = s.Val
  225. case SettingMaxConcurrentStreams:
  226. cc.maxConcurrentStreams = s.Val
  227. case SettingInitialWindowSize:
  228. cc.initialWindowSize = s.Val
  229. default:
  230. // TODO(bradfitz): handle more
  231. log.Printf("Unhandled Setting: %v", s)
  232. }
  233. return nil
  234. })
  235. // TODO: figure out henc size
  236. cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
  237. go cc.readLoop()
  238. return cc, nil
  239. }
  240. func (cc *clientConn) setGoAway(f *GoAwayFrame) {
  241. cc.mu.Lock()
  242. defer cc.mu.Unlock()
  243. cc.goAway = f
  244. }
  245. func (cc *clientConn) canTakeNewRequest() bool {
  246. cc.mu.Lock()
  247. defer cc.mu.Unlock()
  248. return cc.goAway == nil &&
  249. int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
  250. cc.nextStreamID < 2147483647
  251. }
  252. func (cc *clientConn) closeIfIdle() {
  253. cc.mu.Lock()
  254. if len(cc.streams) > 0 {
  255. cc.mu.Unlock()
  256. return
  257. }
  258. cc.closed = true
  259. // TODO: do clients send GOAWAY too? maybe? Just Close:
  260. cc.mu.Unlock()
  261. cc.tconn.Close()
  262. }
  263. func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
  264. cc.mu.Lock()
  265. if cc.closed {
  266. cc.mu.Unlock()
  267. return nil, errClientConnClosed
  268. }
  269. cs := cc.newStream()
  270. hasBody := false // TODO
  271. // we send: HEADERS[+CONTINUATION] + (DATA?)
  272. hdrs := cc.encodeHeaders(req)
  273. first := true
  274. for len(hdrs) > 0 {
  275. chunk := hdrs
  276. if len(chunk) > int(cc.maxFrameSize) {
  277. chunk = chunk[:cc.maxFrameSize]
  278. }
  279. hdrs = hdrs[len(chunk):]
  280. endHeaders := len(hdrs) == 0
  281. if first {
  282. cc.fr.WriteHeaders(HeadersFrameParam{
  283. StreamID: cs.ID,
  284. BlockFragment: chunk,
  285. EndStream: !hasBody,
  286. EndHeaders: endHeaders,
  287. })
  288. first = false
  289. } else {
  290. cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
  291. }
  292. }
  293. cc.bw.Flush()
  294. werr := cc.werr
  295. cc.mu.Unlock()
  296. if hasBody {
  297. // TODO: write data. and it should probably be interleaved:
  298. // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
  299. }
  300. if werr != nil {
  301. return nil, werr
  302. }
  303. re := <-cs.resc
  304. if re.err != nil {
  305. return nil, re.err
  306. }
  307. res := re.res
  308. res.Request = req
  309. res.TLS = cc.tlsState
  310. return res, nil
  311. }
  312. // requires cc.mu be held.
  313. func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
  314. cc.hbuf.Reset()
  315. // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
  316. host := req.Host
  317. if host == "" {
  318. host = req.URL.Host
  319. }
  320. path := req.URL.Path
  321. if path == "" {
  322. path = "/"
  323. }
  324. cc.writeHeader(":authority", host) // probably not right for all sites
  325. cc.writeHeader(":method", req.Method)
  326. cc.writeHeader(":path", path)
  327. cc.writeHeader(":scheme", "https")
  328. for k, vv := range req.Header {
  329. lowKey := strings.ToLower(k)
  330. if lowKey == "host" {
  331. continue
  332. }
  333. for _, v := range vv {
  334. cc.writeHeader(lowKey, v)
  335. }
  336. }
  337. return cc.hbuf.Bytes()
  338. }
  339. func (cc *clientConn) writeHeader(name, value string) {
  340. log.Printf("sending %q = %q", name, value)
  341. cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
  342. }
  // resAndError carries the outcome of a request from readLoop to the
// roundTrip call waiting on the stream: either the response (headers
// received) or an error.
type resAndError struct {
	res *http.Response // response with headers filled in; nil on error
	err error          // failure instead of a response
}
  347. // requires cc.mu be held.
  348. func (cc *clientConn) newStream() *clientStream {
  349. cs := &clientStream{
  350. ID: cc.nextStreamID,
  351. resc: make(chan resAndError, 1),
  352. }
  353. cc.nextStreamID += 2
  354. cc.streams[cs.ID] = cs
  355. return cs
  356. }
  357. func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
  358. cc.mu.Lock()
  359. defer cc.mu.Unlock()
  360. cs := cc.streams[id]
  361. if andRemove {
  362. delete(cc.streams, id)
  363. }
  364. return cs
  365. }
  366. // runs in its own goroutine.
  367. func (cc *clientConn) readLoop() {
  368. defer cc.t.removeClientConn(cc)
  369. defer close(cc.readerDone)
  370. activeRes := map[uint32]*clientStream{} // keyed by streamID
  371. // Close any response bodies if the server closes prematurely.
  372. // TODO: also do this if we've written the headers but not
  373. // gotten a response yet.
  374. defer func() {
  375. err := cc.readerErr
  376. if err == io.EOF {
  377. err = io.ErrUnexpectedEOF
  378. }
  379. for _, cs := range activeRes {
  380. cs.pw.CloseWithError(err)
  381. }
  382. }()
  383. // continueStreamID is the stream ID we're waiting for
  384. // continuation frames for.
  385. var continueStreamID uint32
  386. for {
  387. f, err := cc.fr.ReadFrame()
  388. if err != nil {
  389. cc.readerErr = err
  390. return
  391. }
  392. log.Printf("Transport received %v: %#v", f.Header(), f)
  393. streamID := f.Header().StreamID
  394. _, isContinue := f.(*ContinuationFrame)
  395. if isContinue {
  396. if streamID != continueStreamID {
  397. log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
  398. cc.readerErr = ConnectionError(ErrCodeProtocol)
  399. return
  400. }
  401. } else if continueStreamID != 0 {
  402. // Continue frames need to be adjacent in the stream
  403. // and we were in the middle of headers.
  404. log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
  405. cc.readerErr = ConnectionError(ErrCodeProtocol)
  406. return
  407. }
  408. if streamID%2 == 0 {
  409. // Ignore streams pushed from the server for now.
  410. // These always have an even stream id.
  411. continue
  412. }
  413. streamEnded := false
  414. if ff, ok := f.(streamEnder); ok {
  415. streamEnded = ff.StreamEnded()
  416. }
  417. cs := cc.streamByID(streamID, streamEnded)
  418. if cs == nil {
  419. log.Printf("Received frame for untracked stream ID %d", streamID)
  420. continue
  421. }
  422. switch f := f.(type) {
  423. case *HeadersFrame:
  424. cc.nextRes = &http.Response{
  425. Proto: "HTTP/2.0",
  426. ProtoMajor: 2,
  427. Header: make(http.Header),
  428. }
  429. cs.pr, cs.pw = io.Pipe()
  430. cc.hdec.Write(f.HeaderBlockFragment())
  431. case *ContinuationFrame:
  432. cc.hdec.Write(f.HeaderBlockFragment())
  433. case *DataFrame:
  434. log.Printf("DATA: %q", f.Data())
  435. cs.pw.Write(f.Data())
  436. case *GoAwayFrame:
  437. cc.t.removeClientConn(cc)
  438. if f.ErrCode != 0 {
  439. // TODO: deal with GOAWAY more. particularly the error code
  440. log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
  441. }
  442. cc.setGoAway(f)
  443. default:
  444. log.Printf("Transport: unhandled response frame type %T", f)
  445. }
  446. headersEnded := false
  447. if he, ok := f.(headersEnder); ok {
  448. headersEnded = he.HeadersEnded()
  449. if headersEnded {
  450. continueStreamID = 0
  451. } else {
  452. continueStreamID = streamID
  453. }
  454. }
  455. if streamEnded {
  456. cs.pw.Close()
  457. delete(activeRes, streamID)
  458. }
  459. if headersEnded {
  460. if cs == nil {
  461. panic("couldn't find stream") // TODO be graceful
  462. }
  463. // TODO: set the Body to one which notes the
  464. // Close and also sends the server a
  465. // RST_STREAM
  466. cc.nextRes.Body = cs.pr
  467. res := cc.nextRes
  468. activeRes[streamID] = cs
  469. cs.resc <- resAndError{res: res}
  470. }
  471. }
  472. }
  473. func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
  474. // TODO: verifiy pseudo headers come before non-pseudo headers
  475. // TODO: verifiy the status is set
  476. log.Printf("Header field: %+v", f)
  477. if f.Name == ":status" {
  478. code, err := strconv.Atoi(f.Value)
  479. if err != nil {
  480. panic("TODO: be graceful")
  481. }
  482. cc.nextRes.Status = f.Value + " " + http.StatusText(code)
  483. cc.nextRes.StatusCode = code
  484. return
  485. }
  486. if strings.HasPrefix(f.Name, ":") {
  487. // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
  488. // TODO: treat as invalid?
  489. return
  490. }
  491. cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
  492. }