transport.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. // Copyright 2015 The Go Authors.
  2. // See https://go.googlesource.com/go/+/master/CONTRIBUTORS
  3. // Licensed under the same terms as Go itself:
  4. // https://go.googlesource.com/go/+/master/LICENSE
  5. package http2
  6. import (
  7. "bufio"
  8. "bytes"
  9. "crypto/tls"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "net"
  15. "net/http"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "github.com/bradfitz/http2/hpack"
  20. )
  // Transport is an http.RoundTripper that speaks HTTP/2 over TLS for
  // "https" URLs and hands requests with any other scheme to Fallback.
  type Transport struct {
  	// Fallback serves requests whose URL scheme is not "https".
  	// When nil, such requests are rejected with an error.
  	Fallback http.RoundTripper

  	// InsecureTLSDial disables certificate and hostname verification
  	// when dialing the server.
  	// TODO: remove this and make more general with a TLS dial hook, like http
  	InsecureTLSDial bool
  }
  26. type clientConn struct {
  27. tconn *tls.Conn
  28. bw *bufio.Writer
  29. br *bufio.Reader
  30. fr *Framer
  31. readerDone chan struct{} // closed on error
  32. readerErr error // set before readerDone is closed
  33. werr error // first write error that has occurred
  34. hbuf bytes.Buffer // HPACK encoder writes into this
  35. henc *hpack.Encoder
  36. hdec *hpack.Decoder
  37. nextRes *http.Response
  38. // Settings from peer:
  39. maxFrameSize uint32
  40. mu sync.Mutex
  41. streams map[uint32]*clientStream
  42. nextStreamID uint32
  43. }
  // clientStream is the client's view of one HTTP/2 stream.
  type clientStream struct {
  	ID uint32 // stream identifier, assigned by newStream

  	// resc receives the response once its header block is decoded.
  	resc chan *http.Response

  	// pw is the write end of the body pipe, fed with DATA payloads by
  	// readLoop; pr is the read end, handed out as the response Body.
  	pw *io.PipeWriter
  	pr *io.PipeReader
  }
  // stickyErrWriter wraps w and remembers the first write error in *err.
  // Once *err is non-nil, later writes are refused with that same error.
  type stickyErrWriter struct {
  	w   io.Writer
  	err *error
  }

  // Write forwards p to the wrapped writer unless an earlier write already
  // failed, in which case it returns the stored error without writing.
  // The wrapped writer's error (nil on success) is stored in *sew.err.
  func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
  	if prev := *sew.err; prev != nil {
  		return 0, prev
  	}
  	n, err = sew.w.Write(p)
  	*sew.err = err
  	return n, err
  }
  62. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  63. if req.URL.Scheme != "https" {
  64. if t.Fallback == nil {
  65. return nil, errors.New("http2: unsupported scheme and no Fallback")
  66. }
  67. return t.Fallback.RoundTrip(req)
  68. }
  69. host, port, err := net.SplitHostPort(req.URL.Host)
  70. if err != nil {
  71. host = req.URL.Host
  72. port = "443"
  73. }
  74. cfg := &tls.Config{
  75. ServerName: host,
  76. NextProtos: []string{NextProtoTLS},
  77. InsecureSkipVerify: t.InsecureTLSDial,
  78. }
  79. tconn, err := tls.Dial("tcp", host+":"+port, cfg)
  80. if err != nil {
  81. return nil, err
  82. }
  83. if err := tconn.Handshake(); err != nil {
  84. return nil, err
  85. }
  86. if !t.InsecureTLSDial {
  87. if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
  88. return nil, err
  89. }
  90. }
  91. state := tconn.ConnectionState()
  92. if p := state.NegotiatedProtocol; p != NextProtoTLS {
  93. // TODO(bradfitz): fall back to Fallback
  94. return nil, fmt.Errorf("bad protocol: %v", p)
  95. }
  96. if !state.NegotiatedProtocolIsMutual {
  97. return nil, errors.New("could not negotiate protocol mutually")
  98. }
  99. if _, err := tconn.Write(clientPreface); err != nil {
  100. return nil, err
  101. }
  102. cc := &clientConn{
  103. tconn: tconn,
  104. readerDone: make(chan struct{}),
  105. nextStreamID: 1,
  106. streams: make(map[uint32]*clientStream),
  107. }
  108. cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
  109. cc.br = bufio.NewReader(tconn)
  110. cc.fr = NewFramer(cc.bw, cc.br)
  111. cc.henc = hpack.NewEncoder(&cc.hbuf)
  112. cc.fr.WriteSettings()
  113. // TODO: re-send more conn-level flow control tokens when server uses all these.
  114. cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
  115. cc.bw.Flush()
  116. if cc.werr != nil {
  117. return nil, cc.werr
  118. }
  119. // Read the obligatory SETTINGS frame
  120. f, err := cc.fr.ReadFrame()
  121. if err != nil {
  122. return nil, err
  123. }
  124. sf, ok := f.(*SettingsFrame)
  125. if !ok {
  126. return nil, fmt.Errorf("expected settings frame, got: %T", f)
  127. }
  128. cc.fr.WriteSettingsAck()
  129. cc.bw.Flush()
  130. sf.ForeachSetting(func(s Setting) error {
  131. switch s.ID {
  132. case SettingMaxFrameSize:
  133. cc.maxFrameSize = s.Val
  134. // TODO(bradfitz): handle the others
  135. default:
  136. log.Printf("Unhandled Setting: %v", s)
  137. }
  138. return nil
  139. })
  140. // TODO: figure out henc size
  141. cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
  142. go cc.readLoop()
  143. cs := cc.newStream()
  144. hasBody := false // TODO
  145. // we send: HEADERS[+CONTINUATION] + (DATA?)
  146. hdrs := cc.encodeHeaders(req)
  147. first := true
  148. for len(hdrs) > 0 {
  149. chunk := hdrs
  150. if len(chunk) > int(cc.maxFrameSize) {
  151. chunk = chunk[:cc.maxFrameSize]
  152. }
  153. hdrs = hdrs[len(chunk):]
  154. endHeaders := len(hdrs) == 0
  155. if first {
  156. cc.fr.WriteHeaders(HeadersFrameParam{
  157. StreamID: cs.ID,
  158. BlockFragment: chunk,
  159. EndStream: !hasBody,
  160. EndHeaders: endHeaders,
  161. })
  162. first = false
  163. } else {
  164. cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
  165. }
  166. }
  167. cc.bw.Flush()
  168. if cc.werr != nil {
  169. return nil, cc.werr
  170. }
  171. return <-cs.resc, nil
  172. }
  173. func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
  174. cc.hbuf.Reset()
  175. // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
  176. host := req.Host
  177. if host == "" {
  178. host = req.URL.Host
  179. }
  180. path := req.URL.Path
  181. if path == "" {
  182. path = "/"
  183. }
  184. cc.writeHeader(":authority", host) // probably not right for all sites
  185. cc.writeHeader(":method", req.Method)
  186. cc.writeHeader(":path", path)
  187. cc.writeHeader(":scheme", "https")
  188. for k, vv := range req.Header {
  189. lowKey := strings.ToLower(k)
  190. if lowKey == "host" {
  191. continue
  192. }
  193. for _, v := range vv {
  194. cc.writeHeader(lowKey, v)
  195. }
  196. }
  197. return cc.hbuf.Bytes()
  198. }
  199. func (cc *clientConn) writeHeader(name, value string) {
  200. log.Printf("sending %q = %q", name, value)
  201. cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
  202. }
  203. func (cc *clientConn) newStream() *clientStream {
  204. cc.mu.Lock()
  205. defer cc.mu.Unlock()
  206. cs := &clientStream{
  207. ID: cc.nextStreamID,
  208. resc: make(chan *http.Response, 1),
  209. }
  210. cc.nextStreamID += 2
  211. cc.streams[cs.ID] = cs
  212. return cs
  213. }
  214. func (cc *clientConn) streamByID(id uint32) *clientStream {
  215. cc.mu.Lock()
  216. defer cc.mu.Unlock()
  217. return cc.streams[id]
  218. }
  219. // runs in its own goroutine.
  220. func (cc *clientConn) readLoop() {
  221. defer close(cc.readerDone)
  222. for {
  223. f, err := cc.fr.ReadFrame()
  224. if err != nil {
  225. cc.readerErr = err
  226. // TODO: don't log it.
  227. log.Printf("ReadFrame: %v", err)
  228. return
  229. }
  230. cs := cc.streamByID(f.Header().StreamID)
  231. log.Printf("Transport received %v: %#v", f.Header(), f)
  232. headersEnded := false
  233. streamEnded := false
  234. if ff, ok := f.(interface {
  235. StreamEnded() bool
  236. }); ok {
  237. streamEnded = ff.StreamEnded()
  238. }
  239. switch f := f.(type) {
  240. case *HeadersFrame:
  241. cc.nextRes = &http.Response{
  242. Proto: "HTTP/2.0",
  243. ProtoMajor: 2,
  244. Header: make(http.Header),
  245. Request: nil, // TODO: set this
  246. TLS: nil, // TODO: set this
  247. }
  248. cs.pr, cs.pw = io.Pipe()
  249. cc.hdec.Write(f.HeaderBlockFragment())
  250. headersEnded = f.HeadersEnded()
  251. case *ContinuationFrame:
  252. // TODO: verify stream id is the same
  253. cc.hdec.Write(f.HeaderBlockFragment())
  254. headersEnded = f.HeadersEnded()
  255. case *DataFrame:
  256. log.Printf("DATA: %q", f.Data())
  257. cs.pw.Write(f.Data())
  258. default:
  259. }
  260. if streamEnded {
  261. cs.pw.Close()
  262. }
  263. if headersEnded {
  264. if cs == nil {
  265. panic("couldn't find stream") // TODO be graceful
  266. }
  267. cc.nextRes.Body = cs.pr
  268. cs.resc <- cc.nextRes
  269. }
  270. }
  271. }
  272. func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
  273. // TODO: verifiy pseudo headers come before non-pseudo headers
  274. // TODO: verifiy the status is set
  275. log.Printf("Header field: %+v", f)
  276. if f.Name == ":status" {
  277. code, err := strconv.Atoi(f.Value)
  278. if err != nil {
  279. panic("TODO: be graceful")
  280. }
  281. cc.nextRes.Status = f.Value + " " + http.StatusText(code)
  282. cc.nextRes.StatusCode = code
  283. return
  284. }
  285. if strings.HasPrefix(f.Name, ":") {
  286. // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
  287. // TODO: treat as invalid?
  288. return
  289. }
  290. cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
  291. }