transport.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. // Copyright 2015 The Go Authors.
  2. // See https://go.googlesource.com/go/+/master/CONTRIBUTORS
  3. // Licensed under the same terms as Go itself:
  4. // https://go.googlesource.com/go/+/master/LICENSE
  5. package http2
  6. import (
  7. "bufio"
  8. "bytes"
  9. "crypto/tls"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "net"
  15. "net/http"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "github.com/bradfitz/http2/hpack"
  20. )
  // Transport is an HTTP/2 client implementation of http.RoundTripper.
// Requests whose URL scheme is "https" are sent over HTTP/2 connections
// negotiated during the TLS handshake; all other requests are handed to
// Fallback.
//
// A Transport currently dials a new connection for every request.
type Transport struct {
	// Fallback, if non-nil, handles requests whose URL scheme is not
	// "https". If nil, such requests fail with an error.
	Fallback http.RoundTripper

	// InsecureTLSDial disables verification of the server's certificate
	// chain and host name. With it set, connections are open to
	// man-in-the-middle attacks.
	//
	// TODO: remove this and make more general with a TLS dial hook, like http
	InsecureTLSDial bool
}
  26. type clientConn struct {
  27. tconn *tls.Conn
  28. tlsState *tls.ConnectionState
  29. readerDone chan struct{} // closed on error
  30. readerErr error // set before readerDone is closed
  31. hbuf bytes.Buffer // HPACK encoder writes into this
  32. henc *hpack.Encoder
  33. hdec *hpack.Decoder
  34. nextRes *http.Response
  35. // Settings from peer:
  36. maxFrameSize uint32
  37. mu sync.Mutex
  38. streams map[uint32]*clientStream
  39. nextStreamID uint32
  40. bw *bufio.Writer
  41. werr error // first write error that has occurred
  42. br *bufio.Reader
  43. fr *Framer
  44. }
  45. type clientStream struct {
  46. ID uint32
  47. resc chan resAndError
  48. pw *io.PipeWriter
  49. pr *io.PipeReader
  50. }
  // stickyErrWriter forwards writes to w and stores the result error of each
// write in *err. Once *err holds an error, every later Write returns that
// same error without calling w. This lets a sequence of buffered writes be
// checked once, at the end.
type stickyErrWriter struct {
	w   io.Writer
	err *error
}

// Write implements io.Writer.
func (sew stickyErrWriter) Write(p []byte) (int, error) {
	if prior := *sew.err; prior != nil {
		return 0, prior
	}
	n, err := sew.w.Write(p)
	*sew.err = err
	return n, err
}
  63. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  64. if req.URL.Scheme != "https" {
  65. if t.Fallback == nil {
  66. return nil, errors.New("http2: unsupported scheme and no Fallback")
  67. }
  68. return t.Fallback.RoundTrip(req)
  69. }
  70. host, port, err := net.SplitHostPort(req.URL.Host)
  71. if err != nil {
  72. host = req.URL.Host
  73. port = "443"
  74. }
  75. for {
  76. cc, err := t.getClientConn(host, port)
  77. if err != nil {
  78. return nil, err
  79. }
  80. res, err := cc.roundTrip(req)
  81. if isShutdownError(err) {
  82. continue
  83. }
  84. if err != nil {
  85. return nil, err
  86. }
  87. return res, nil
  88. }
  89. }
  // isShutdownError is meant to report whether err indicates that the
// connection was shutting down, so that RoundTrip can retry the request
// on a fresh connection.
//
// TODO: implement. For now it reports false for every error, including nil.
func isShutdownError(err error) bool {
	var shutdown bool // stays false until this is implemented
	return shutdown
}
  94. func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
  95. // TODO: cache these
  96. cfg := &tls.Config{
  97. ServerName: host,
  98. NextProtos: []string{NextProtoTLS},
  99. InsecureSkipVerify: t.InsecureTLSDial,
  100. }
  101. tconn, err := tls.Dial("tcp", host+":"+port, cfg)
  102. if err != nil {
  103. return nil, err
  104. }
  105. if err := tconn.Handshake(); err != nil {
  106. return nil, err
  107. }
  108. if !t.InsecureTLSDial {
  109. if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
  110. return nil, err
  111. }
  112. }
  113. state := tconn.ConnectionState()
  114. if p := state.NegotiatedProtocol; p != NextProtoTLS {
  115. // TODO(bradfitz): fall back to Fallback
  116. return nil, fmt.Errorf("bad protocol: %v", p)
  117. }
  118. if !state.NegotiatedProtocolIsMutual {
  119. return nil, errors.New("could not negotiate protocol mutually")
  120. }
  121. if _, err := tconn.Write(clientPreface); err != nil {
  122. return nil, err
  123. }
  124. cc := &clientConn{
  125. tconn: tconn,
  126. tlsState: &state,
  127. readerDone: make(chan struct{}),
  128. nextStreamID: 1,
  129. streams: make(map[uint32]*clientStream),
  130. }
  131. cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
  132. cc.br = bufio.NewReader(tconn)
  133. cc.fr = NewFramer(cc.bw, cc.br)
  134. cc.henc = hpack.NewEncoder(&cc.hbuf)
  135. cc.fr.WriteSettings()
  136. // TODO: re-send more conn-level flow control tokens when server uses all these.
  137. cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
  138. cc.bw.Flush()
  139. if cc.werr != nil {
  140. return nil, cc.werr
  141. }
  142. // Read the obligatory SETTINGS frame
  143. f, err := cc.fr.ReadFrame()
  144. if err != nil {
  145. return nil, err
  146. }
  147. sf, ok := f.(*SettingsFrame)
  148. if !ok {
  149. return nil, fmt.Errorf("expected settings frame, got: %T", f)
  150. }
  151. cc.fr.WriteSettingsAck()
  152. cc.bw.Flush()
  153. sf.ForeachSetting(func(s Setting) error {
  154. switch s.ID {
  155. case SettingMaxFrameSize:
  156. cc.maxFrameSize = s.Val
  157. // TODO(bradfitz): handle the others
  158. default:
  159. log.Printf("Unhandled Setting: %v", s)
  160. }
  161. return nil
  162. })
  163. // TODO: figure out henc size
  164. cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
  165. go cc.readLoop()
  166. return cc, nil
  167. }
  168. func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
  169. cc.mu.Lock()
  170. cs := cc.newStream()
  171. hasBody := false // TODO
  172. // we send: HEADERS[+CONTINUATION] + (DATA?)
  173. hdrs := cc.encodeHeaders(req)
  174. first := true
  175. for len(hdrs) > 0 {
  176. chunk := hdrs
  177. if len(chunk) > int(cc.maxFrameSize) {
  178. chunk = chunk[:cc.maxFrameSize]
  179. }
  180. hdrs = hdrs[len(chunk):]
  181. endHeaders := len(hdrs) == 0
  182. if first {
  183. cc.fr.WriteHeaders(HeadersFrameParam{
  184. StreamID: cs.ID,
  185. BlockFragment: chunk,
  186. EndStream: !hasBody,
  187. EndHeaders: endHeaders,
  188. })
  189. first = false
  190. } else {
  191. cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
  192. }
  193. }
  194. cc.bw.Flush()
  195. werr := cc.werr
  196. cc.mu.Unlock()
  197. if werr != nil {
  198. return nil, werr
  199. }
  200. re := <-cs.resc
  201. if re.err != nil {
  202. return nil, re.err
  203. }
  204. res := re.res
  205. res.Request = req
  206. res.TLS = cc.tlsState
  207. return res, nil
  208. }
  209. // requires cc.mu be held.
  210. func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
  211. cc.hbuf.Reset()
  212. // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
  213. host := req.Host
  214. if host == "" {
  215. host = req.URL.Host
  216. }
  217. path := req.URL.Path
  218. if path == "" {
  219. path = "/"
  220. }
  221. cc.writeHeader(":authority", host) // probably not right for all sites
  222. cc.writeHeader(":method", req.Method)
  223. cc.writeHeader(":path", path)
  224. cc.writeHeader(":scheme", "https")
  225. for k, vv := range req.Header {
  226. lowKey := strings.ToLower(k)
  227. if lowKey == "host" {
  228. continue
  229. }
  230. for _, v := range vv {
  231. cc.writeHeader(lowKey, v)
  232. }
  233. }
  234. return cc.hbuf.Bytes()
  235. }
  236. func (cc *clientConn) writeHeader(name, value string) {
  237. log.Printf("sending %q = %q", name, value)
  238. cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
  239. }
  // resAndError is the value readLoop sends to roundTrip on a stream's resc
// channel. roundTrip returns err when it is non-nil, and res otherwise.
type resAndError struct {
	err error
	res *http.Response
}
  244. // requires cc.mu be held.
  245. func (cc *clientConn) newStream() *clientStream {
  246. cs := &clientStream{
  247. ID: cc.nextStreamID,
  248. resc: make(chan resAndError, 1),
  249. }
  250. cc.nextStreamID += 2
  251. cc.streams[cs.ID] = cs
  252. return cs
  253. }
  254. func (cc *clientConn) streamByID(id uint32) *clientStream {
  255. cc.mu.Lock()
  256. defer cc.mu.Unlock()
  257. return cc.streams[id]
  258. }
  259. // runs in its own goroutine.
  260. func (cc *clientConn) readLoop() {
  261. defer close(cc.readerDone)
  262. activeRes := map[uint32]*clientStream{} // keyed by streamID
  263. // Close any response bodies if the server closes prematurely.
  264. // TODO: also do this if we've written the headers but not
  265. // gotten a response yet.
  266. defer func() {
  267. err := cc.readerErr
  268. if err == io.EOF {
  269. err = io.ErrUnexpectedEOF
  270. }
  271. for _, cs := range activeRes {
  272. cs.pw.CloseWithError(err)
  273. }
  274. }()
  275. // continueStreamID is the stream ID we're waiting for
  276. // continuation frames for.
  277. var continueStreamID uint32
  278. for {
  279. f, err := cc.fr.ReadFrame()
  280. if err != nil {
  281. cc.readerErr = err
  282. return
  283. }
  284. log.Printf("Transport received %v: %#v", f.Header(), f)
  285. streamID := f.Header().StreamID
  286. _, isContinue := f.(*ContinuationFrame)
  287. if isContinue {
  288. if streamID != continueStreamID {
  289. cc.readerErr = ConnectionError(ErrCodeProtocol)
  290. return
  291. }
  292. } else if continueStreamID != 0 {
  293. // Continue frames need to be adjacent in the stream
  294. // and we were in the middle of headers.
  295. cc.readerErr = ConnectionError(ErrCodeProtocol)
  296. return
  297. }
  298. if streamID%2 == 0 {
  299. // Ignore streams pushed from the server for now.
  300. // These always have an even stream id.
  301. continue
  302. }
  303. cs := cc.streamByID(streamID)
  304. if cs == nil {
  305. log.Printf("Received frame for untracked stream ID %d", streamID)
  306. continue
  307. }
  308. headersEnded := false
  309. streamEnded := false
  310. if ff, ok := f.(streamEnder); ok {
  311. streamEnded = ff.StreamEnded()
  312. }
  313. switch f := f.(type) {
  314. case *HeadersFrame:
  315. cc.nextRes = &http.Response{
  316. Proto: "HTTP/2.0",
  317. ProtoMajor: 2,
  318. Header: make(http.Header),
  319. }
  320. cs.pr, cs.pw = io.Pipe()
  321. cc.hdec.Write(f.HeaderBlockFragment())
  322. headersEnded = f.HeadersEnded()
  323. case *ContinuationFrame:
  324. cc.hdec.Write(f.HeaderBlockFragment())
  325. headersEnded = f.HeadersEnded()
  326. case *DataFrame:
  327. log.Printf("DATA: %q", f.Data())
  328. cs.pw.Write(f.Data())
  329. default:
  330. }
  331. if headersEnded {
  332. continueStreamID = 0
  333. } else {
  334. continueStreamID = streamID
  335. }
  336. if streamEnded {
  337. cs.pw.Close()
  338. delete(activeRes, streamID)
  339. }
  340. if headersEnded {
  341. if cs == nil {
  342. panic("couldn't find stream") // TODO be graceful
  343. }
  344. cc.nextRes.Body = cs.pr
  345. res := cc.nextRes
  346. activeRes[streamID] = cs
  347. cs.resc <- resAndError{res: res}
  348. }
  349. }
  350. }
  351. func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
  352. // TODO: verifiy pseudo headers come before non-pseudo headers
  353. // TODO: verifiy the status is set
  354. log.Printf("Header field: %+v", f)
  355. if f.Name == ":status" {
  356. code, err := strconv.Atoi(f.Value)
  357. if err != nil {
  358. panic("TODO: be graceful")
  359. }
  360. cc.nextRes.Status = f.Value + " " + http.StatusText(code)
  361. cc.nextRes.StatusCode = code
  362. return
  363. }
  364. if strings.HasPrefix(f.Name, ":") {
  365. // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
  366. // TODO: treat as invalid?
  367. return
  368. }
  369. cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
  370. }