transport.go 6.6 KB


  1. // Copyright 2015 The Go Authors.
  2. // See https://go.googlesource.com/go/+/master/CONTRIBUTORS
  3. // Licensed under the same terms as Go itself:
  4. // https://go.googlesource.com/go/+/master/LICENSE
  5. package http2
  6. import (
  7. "bufio"
  8. "bytes"
  9. "crypto/tls"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "net"
  15. "net/http"
  16. "strings"
  17. "sync"
  18. "github.com/bradfitz/http2/hpack"
  19. )
  // Transport is an http.RoundTripper that speaks HTTP/2 to "https" URLs.
  type Transport struct {
  	// Fallback, when non-nil, receives every request whose URL scheme is
  	// not "https". When it is nil such requests fail with an error.
  	Fallback http.RoundTripper

  	// InsecureTLSDial is copied into tls.Config.InsecureSkipVerify for the
  	// dialed connection and also disables the explicit hostname check.
  	//
  	// TODO: remove this and make more general with a TLS dial hook, like http
  	InsecureTLSDial bool
  }
  25. type clientConn struct {
  26. tconn *tls.Conn
  27. bw *bufio.Writer
  28. br *bufio.Reader
  29. fr *Framer
  30. readerDone chan struct{} // closed on error
  31. readerErr error // set before readerDone is closed
  32. werr error // first write error that has occurred
  33. hbuf bytes.Buffer // HPACK encoder writes into this
  34. henc *hpack.Encoder
  35. hdec *hpack.Decoder
  36. nextRes http.Header
  37. // Settings from peer:
  38. maxFrameSize uint32
  39. mu sync.Mutex
  40. streams map[uint32]*clientStream
  41. nextStreamID uint32
  42. }
  // clientStream is the per-request state of one stream on a clientConn.
  type clientStream struct {
  	// ID is the HTTP/2 stream identifier.
  	ID uint32

  	// resc carries the response from the read loop to RoundTrip.
  	resc chan *http.Response

  	// pw and pr are the two ends of the response-body pipe: the read loop
  	// writes DATA payloads to pw and the response body reads from pr.
  	pw *io.PipeWriter
  	pr *io.PipeReader
  }
  // stickyErrWriter wraps an io.Writer and records the outcome of each write
  // in *err. Once *err is non-nil every later Write returns that error
  // immediately without touching the underlying writer.
  type stickyErrWriter struct {
  	w   io.Writer
  	err *error
  }

  // Write forwards p to the wrapped writer unless an earlier error has been
  // recorded, and stores the result of the write in the shared error slot.
  func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
  	if prev := *sew.err; prev != nil {
  		return 0, prev
  	}
  	written, werr := sew.w.Write(p)
  	*sew.err = werr
  	return written, werr
  }
  61. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  62. if req.URL.Scheme != "https" {
  63. if t.Fallback == nil {
  64. return nil, errors.New("http2: unsupported scheme and no Fallback")
  65. }
  66. return t.Fallback.RoundTrip(req)
  67. }
  68. host, port, err := net.SplitHostPort(req.URL.Host)
  69. if err != nil {
  70. host = req.URL.Host
  71. port = "443"
  72. }
  73. cfg := &tls.Config{
  74. ServerName: host,
  75. NextProtos: []string{NextProtoTLS},
  76. InsecureSkipVerify: t.InsecureTLSDial,
  77. }
  78. tconn, err := tls.Dial("tcp", host+":"+port, cfg)
  79. if err != nil {
  80. return nil, err
  81. }
  82. if err := tconn.Handshake(); err != nil {
  83. return nil, err
  84. }
  85. if !t.InsecureTLSDial {
  86. if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
  87. return nil, err
  88. }
  89. }
  90. state := tconn.ConnectionState()
  91. if p := state.NegotiatedProtocol; p != NextProtoTLS {
  92. // TODO(bradfitz): fall back to Fallback
  93. return nil, fmt.Errorf("bad protocol: %v", p)
  94. }
  95. if !state.NegotiatedProtocolIsMutual {
  96. return nil, errors.New("could not negotiate protocol mutually")
  97. }
  98. if _, err := tconn.Write(clientPreface); err != nil {
  99. return nil, err
  100. }
  101. cc := &clientConn{
  102. tconn: tconn,
  103. readerDone: make(chan struct{}),
  104. nextStreamID: 1,
  105. streams: make(map[uint32]*clientStream),
  106. }
  107. cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
  108. cc.br = bufio.NewReader(tconn)
  109. cc.fr = NewFramer(cc.bw, cc.br)
  110. cc.henc = hpack.NewEncoder(&cc.hbuf)
  111. cc.fr.WriteSettings()
  112. // TODO: re-send more conn-level flow control tokens when server uses all these.
  113. cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
  114. cc.bw.Flush()
  115. if cc.werr != nil {
  116. return nil, cc.werr
  117. }
  118. // Read the obligatory SETTINGS frame
  119. f, err := cc.fr.ReadFrame()
  120. if err != nil {
  121. return nil, err
  122. }
  123. sf, ok := f.(*SettingsFrame)
  124. if !ok {
  125. return nil, fmt.Errorf("expected settings frame, got: %T", f)
  126. }
  127. cc.fr.WriteSettingsAck()
  128. cc.bw.Flush()
  129. sf.ForeachSetting(func(s Setting) error {
  130. switch s.ID {
  131. case SettingMaxFrameSize:
  132. cc.maxFrameSize = s.Val
  133. // TODO(bradfitz): handle the others
  134. default:
  135. log.Printf("Unhandled Setting: %v", s)
  136. }
  137. return nil
  138. })
  139. // TODO: figure out henc size
  140. cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
  141. go cc.readLoop()
  142. cs := cc.newStream()
  143. hasBody := false // TODO
  144. // we send: HEADERS[+CONTINUATION] + (DATA?)
  145. hdrs := cc.encodeHeaders(req)
  146. first := true
  147. for len(hdrs) > 0 {
  148. chunk := hdrs
  149. if len(chunk) > int(cc.maxFrameSize) {
  150. chunk = chunk[:cc.maxFrameSize]
  151. }
  152. hdrs = hdrs[len(chunk):]
  153. endHeaders := len(hdrs) == 0
  154. if first {
  155. cc.fr.WriteHeaders(HeadersFrameParam{
  156. StreamID: cs.ID,
  157. BlockFragment: chunk,
  158. EndStream: !hasBody,
  159. EndHeaders: endHeaders,
  160. })
  161. first = false
  162. } else {
  163. cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
  164. }
  165. }
  166. cc.bw.Flush()
  167. if cc.werr != nil {
  168. return nil, cc.werr
  169. }
  170. return <-cs.resc, nil
  171. }
  172. func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
  173. cc.hbuf.Reset()
  174. // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
  175. host := req.Host
  176. if host == "" {
  177. host = req.URL.Host
  178. }
  179. cc.writeHeader(":authority", host) // probably not right for all sites
  180. cc.writeHeader(":method", req.Method)
  181. cc.writeHeader(":path", req.URL.Path)
  182. cc.writeHeader(":scheme", "https")
  183. for k, vv := range req.Header {
  184. lowKey := strings.ToLower(k)
  185. if lowKey == "host" {
  186. continue
  187. }
  188. for _, v := range vv {
  189. cc.writeHeader(lowKey, v)
  190. }
  191. }
  192. return cc.hbuf.Bytes()
  193. }
  194. func (cc *clientConn) writeHeader(name, value string) {
  195. log.Printf("sending %q = %q", name, value)
  196. cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
  197. }
  198. func (cc *clientConn) newStream() *clientStream {
  199. cc.mu.Lock()
  200. defer cc.mu.Unlock()
  201. cs := &clientStream{
  202. ID: cc.nextStreamID,
  203. resc: make(chan *http.Response, 1),
  204. }
  205. cc.nextStreamID += 2
  206. cc.streams[cs.ID] = cs
  207. return cs
  208. }
  209. func (cc *clientConn) streamByID(id uint32) *clientStream {
  210. cc.mu.Lock()
  211. defer cc.mu.Unlock()
  212. return cc.streams[id]
  213. }
  214. // runs in its own goroutine.
  215. func (cc *clientConn) readLoop() {
  216. defer close(cc.readerDone)
  217. for {
  218. f, err := cc.fr.ReadFrame()
  219. if err != nil {
  220. cc.readerErr = err
  221. // TODO: don't log it.
  222. log.Printf("ReadFrame: %v", err)
  223. return
  224. }
  225. cs := cc.streamByID(f.Header().StreamID)
  226. log.Printf("Read %v: %#v", f.Header(), f)
  227. headersEnded := false
  228. streamEnded := false
  229. if ff, ok := f.(interface {
  230. StreamEnded() bool
  231. }); ok {
  232. streamEnded = ff.StreamEnded()
  233. }
  234. switch f := f.(type) {
  235. case *HeadersFrame:
  236. cc.nextRes = make(http.Header)
  237. cs.pr, cs.pw = io.Pipe()
  238. cc.hdec.Write(f.HeaderBlockFragment())
  239. headersEnded = f.HeadersEnded()
  240. case *ContinuationFrame:
  241. // TODO: verify stream id is the same
  242. cc.hdec.Write(f.HeaderBlockFragment())
  243. headersEnded = f.HeadersEnded()
  244. case *DataFrame:
  245. log.Printf("DATA: %q", f.Data())
  246. cs.pw.Write(f.Data())
  247. default:
  248. }
  249. if streamEnded {
  250. cs.pw.Close()
  251. }
  252. if headersEnded {
  253. if cs == nil {
  254. panic("couldn't find stream") // TODO be graceful
  255. }
  256. cs.resc <- &http.Response{
  257. Header: cc.nextRes,
  258. Body: cs.pr,
  259. }
  260. }
  261. }
  262. }
  263. func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
  264. log.Printf("Header field: %+v", f)
  265. cc.nextRes.Add(http.CanonicalHeaderKey(f.Name), f.Value)
  266. }