|
|
@@ -0,0 +1,297 @@
|
|
|
+// Copyright 2015 The Go Authors.
|
|
|
+// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
|
|
|
+// Licensed under the same terms as Go itself:
|
|
|
+// https://go.googlesource.com/go/+/master/LICENSE
|
|
|
+
|
|
|
+package http2
|
|
|
+
|
|
|
+import (
|
|
|
+ "bufio"
|
|
|
+ "bytes"
|
|
|
+ "crypto/tls"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "log"
|
|
|
+ "net"
|
|
|
+ "net/http"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+
|
|
|
+ "github.com/bradfitz/http2/hpack"
|
|
|
+)
|
|
|
+
|
|
|
+type Transport struct {
|
|
|
+ Fallback http.RoundTripper
|
|
|
+}
|
|
|
+
|
|
|
+type clientConn struct {
|
|
|
+ tconn *tls.Conn
|
|
|
+ bw *bufio.Writer
|
|
|
+ br *bufio.Reader
|
|
|
+ fr *Framer
|
|
|
+
|
|
|
+ readerDone chan struct{} // closed on error
|
|
|
+ readerErr error // set before readerDone is closed
|
|
|
+
|
|
|
+ werr error // first write error that has occurred
|
|
|
+
|
|
|
+ hbuf bytes.Buffer // HPACK encoder writes into this
|
|
|
+ henc *hpack.Encoder
|
|
|
+
|
|
|
+ hdec *hpack.Decoder
|
|
|
+
|
|
|
+ nextRes http.Header
|
|
|
+
|
|
|
+ // Settings from peer:
|
|
|
+ maxFrameSize uint32
|
|
|
+
|
|
|
+ mu sync.Mutex
|
|
|
+ streams map[uint32]*clientStream
|
|
|
+ nextStreamID uint32
|
|
|
+}
|
|
|
+
|
|
|
+type clientStream struct {
|
|
|
+ ID uint32
|
|
|
+ resc chan *http.Response
|
|
|
+ pw *io.PipeWriter
|
|
|
+ pr *io.PipeReader
|
|
|
+}
|
|
|
+
|
|
|
+type stickyErrWriter struct {
|
|
|
+ w io.Writer
|
|
|
+ err *error
|
|
|
+}
|
|
|
+
|
|
|
+func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
|
|
|
+ if *sew.err != nil {
|
|
|
+ return 0, *sew.err
|
|
|
+ }
|
|
|
+ n, err = sew.w.Write(p)
|
|
|
+ *sew.err = err
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
+ if req.URL.Scheme != "https" {
|
|
|
+ if t.Fallback == nil {
|
|
|
+ return nil, errors.New("http2: unsupported scheme and no Fallback")
|
|
|
+ }
|
|
|
+ return t.Fallback.RoundTrip(req)
|
|
|
+ }
|
|
|
+
|
|
|
+ host, port, err := net.SplitHostPort(req.URL.Host)
|
|
|
+ if err != nil {
|
|
|
+ host = req.URL.Host
|
|
|
+ port = "443"
|
|
|
+ }
|
|
|
+ cfg := &tls.Config{
|
|
|
+ ServerName: host,
|
|
|
+ NextProtos: []string{NextProtoTLS},
|
|
|
+ }
|
|
|
+ tconn, err := tls.Dial("tcp", host+":"+port, cfg)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if err := tconn.Handshake(); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ state := tconn.ConnectionState()
|
|
|
+ if p := state.NegotiatedProtocol; p != NextProtoTLS {
|
|
|
+ // TODO(bradfitz): fall back to Fallback
|
|
|
+ return nil, fmt.Errorf("bad protocol: %v", p)
|
|
|
+ }
|
|
|
+ if !state.NegotiatedProtocolIsMutual {
|
|
|
+ return nil, errors.New("could not negotiate protocol mutually")
|
|
|
+ }
|
|
|
+ if _, err := tconn.Write(clientPreface); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ cc := &clientConn{
|
|
|
+ tconn: tconn,
|
|
|
+ readerDone: make(chan struct{}),
|
|
|
+ nextStreamID: 1,
|
|
|
+ streams: make(map[uint32]*clientStream),
|
|
|
+ }
|
|
|
+ cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
|
|
|
+ cc.br = bufio.NewReader(tconn)
|
|
|
+ cc.fr = NewFramer(cc.bw, cc.br)
|
|
|
+ cc.henc = hpack.NewEncoder(&cc.hbuf)
|
|
|
+
|
|
|
+ cc.fr.WriteSettings()
|
|
|
+ cc.bw.Flush()
|
|
|
+ if cc.werr != nil {
|
|
|
+ return nil, cc.werr
|
|
|
+ }
|
|
|
+
|
|
|
+ // Read the obligatory SETTINGS frame
|
|
|
+ f, err := cc.fr.ReadFrame()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ sf, ok := f.(*SettingsFrame)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("expected settings frame, got: %T", f)
|
|
|
+ }
|
|
|
+ cc.fr.WriteSettingsAck()
|
|
|
+ cc.bw.Flush()
|
|
|
+
|
|
|
+ sf.ForeachSetting(func(s Setting) error {
|
|
|
+ switch s.ID {
|
|
|
+ case SettingMaxFrameSize:
|
|
|
+ cc.maxFrameSize = s.Val
|
|
|
+ // TODO(bradfitz): handle the others
|
|
|
+ default:
|
|
|
+ log.Printf("Unhandled Setting: %v", s)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ // TODO: figure out henc size
|
|
|
+ cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
|
|
|
+
|
|
|
+ go cc.readLoop()
|
|
|
+
|
|
|
+ cs := cc.newStream()
|
|
|
+ hasBody := false // TODO
|
|
|
+
|
|
|
+ // we send: HEADERS[+CONTINUATION] + (DATA?)
|
|
|
+ hdrs := cc.encodeHeaders(req)
|
|
|
+ first := true
|
|
|
+ for len(hdrs) > 0 {
|
|
|
+ chunk := hdrs
|
|
|
+ if len(chunk) > int(cc.maxFrameSize) {
|
|
|
+ chunk = chunk[:cc.maxFrameSize]
|
|
|
+ }
|
|
|
+ hdrs = hdrs[len(chunk):]
|
|
|
+ endHeaders := len(hdrs) == 0
|
|
|
+ if first {
|
|
|
+ cc.fr.WriteHeaders(HeadersFrameParam{
|
|
|
+ StreamID: cs.ID,
|
|
|
+ BlockFragment: chunk,
|
|
|
+ EndStream: !hasBody,
|
|
|
+ EndHeaders: endHeaders,
|
|
|
+ })
|
|
|
+ first = false
|
|
|
+ } else {
|
|
|
+ cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ cc.bw.Flush()
|
|
|
+ if cc.werr != nil {
|
|
|
+ return nil, cc.werr
|
|
|
+ }
|
|
|
+
|
|
|
+ return <-cs.resc, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
|
|
|
+ cc.hbuf.Reset()
|
|
|
+
|
|
|
+ // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
|
|
|
+ host := req.Host
|
|
|
+ if host == "" {
|
|
|
+ host = req.URL.Host
|
|
|
+ }
|
|
|
+
|
|
|
+ cc.writeHeader(":method", req.Method)
|
|
|
+ cc.writeHeader(":scheme", "https")
|
|
|
+ cc.writeHeader(":authority", host) // probably not right for all sites
|
|
|
+ cc.writeHeader(":path", req.URL.Path)
|
|
|
+
|
|
|
+ for k, vv := range req.Header {
|
|
|
+ for _, v := range vv {
|
|
|
+ cc.writeHeader(strings.ToLower(k), v)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if _, ok := req.Header["Host"]; !ok {
|
|
|
+ cc.writeHeader("host", host)
|
|
|
+ }
|
|
|
+
|
|
|
+ return cc.hbuf.Bytes()
|
|
|
+}
|
|
|
+
|
|
|
+func (cc *clientConn) writeHeader(name, value string) {
|
|
|
+ log.Printf("sending %q = %q", name, value)
|
|
|
+ cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
|
|
+}
|
|
|
+
|
|
|
+func (cc *clientConn) newStream() *clientStream {
|
|
|
+ cc.mu.Lock()
|
|
|
+ defer cc.mu.Unlock()
|
|
|
+
|
|
|
+ cs := &clientStream{
|
|
|
+ ID: cc.nextStreamID,
|
|
|
+ resc: make(chan *http.Response, 1),
|
|
|
+ }
|
|
|
+ cc.nextStreamID += 2
|
|
|
+ cc.streams[cs.ID] = cs
|
|
|
+
|
|
|
+ return cs
|
|
|
+}
|
|
|
+
|
|
|
+func (cc *clientConn) streamByID(id uint32) *clientStream {
|
|
|
+ cc.mu.Lock()
|
|
|
+ defer cc.mu.Unlock()
|
|
|
+ return cc.streams[id]
|
|
|
+}
|
|
|
+
|
|
|
+// runs in its own goroutine.
|
|
|
+func (cc *clientConn) readLoop() {
|
|
|
+ defer close(cc.readerDone)
|
|
|
+
|
|
|
+ for {
|
|
|
+ f, err := cc.fr.ReadFrame()
|
|
|
+ if err != nil {
|
|
|
+ cc.readerErr = err
|
|
|
+ // TODO: don't log it.
|
|
|
+ log.Printf("ReadFrame: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ cs := cc.streamByID(f.Header().StreamID)
|
|
|
+
|
|
|
+ log.Printf("Read %v: %#v", f.Header(), f)
|
|
|
+ headersEnded := false
|
|
|
+ streamEnded := false
|
|
|
+ if ff, ok := f.(interface {
|
|
|
+ StreamEnded() bool
|
|
|
+ }); ok {
|
|
|
+ streamEnded = ff.StreamEnded()
|
|
|
+ }
|
|
|
+ switch f := f.(type) {
|
|
|
+ case *HeadersFrame:
|
|
|
+ cc.nextRes = make(http.Header)
|
|
|
+ cs.pr, cs.pw = io.Pipe()
|
|
|
+ cc.hdec.Write(f.HeaderBlockFragment())
|
|
|
+ headersEnded = f.HeadersEnded()
|
|
|
+ case *ContinuationFrame:
|
|
|
+ // TODO: verify stream id is the same
|
|
|
+ cc.hdec.Write(f.HeaderBlockFragment())
|
|
|
+ headersEnded = f.HeadersEnded()
|
|
|
+ case *DataFrame:
|
|
|
+ log.Printf("DATA: %q", f.Data())
|
|
|
+ cs.pw.Write(f.Data())
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ if streamEnded {
|
|
|
+ cs.pw.Close()
|
|
|
+ }
|
|
|
+ if headersEnded {
|
|
|
+ if cs == nil {
|
|
|
+ panic("couldn't find stream") // TODO be graceful
|
|
|
+ }
|
|
|
+ cs.resc <- &http.Response{
|
|
|
+ Header: cc.nextRes,
|
|
|
+ Body: cs.pr,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
|
|
|
+ log.Printf("Header field: %+v", f)
|
|
|
+ cc.nextRes.Add(http.CanonicalHeaderKey(f.Name), f.Value)
|
|
|
+}
|