Explorar o código

Add Framer type.

Invalidate frames after subsequent frames are read.
Change frame parser signature to start with the entire payload buffer.
Brad Fitzpatrick %!s(int64=11) %!d(string=hai) anos
pai
achega
465880975f
Modificáronse 2 ficheiros con 115 adicións e 84 borrados
  1. 109 64
      frame.go
  2. 6 20
      http2.go

+ 109 - 64
frame.go

@@ -9,7 +9,6 @@ import (
 	"encoding/binary"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"log"
 	"sync"
 )
@@ -83,7 +82,7 @@ func knownSetting(id SettingID) bool {
 // a frameParser parses a frame. The parser can assume that the Reader will
 // not read past the length of a frame (e.g. it acts like an io.LimitReader
 // bounded by the FrameHeader.Length)
-type frameParser func(FrameHeader, io.Reader) (Frame, error)
+type frameParser func(FrameHeader, []byte) (Frame, error)
 
 var FrameParsers = map[FrameType]frameParser{
 	FrameSettings:     parseSettingsFrame,
@@ -108,6 +107,8 @@ func (f Flags) Has(v Flags) bool {
 //
 // See http://http2.github.io/http2-spec/#FrameHeader
 type FrameHeader struct {
+	valid bool // caller can access []byte fields in the Frame
+
 	Type     FrameType
 	Flags    Flags
 	Length   uint32 // actually a uint24 max; default is uint16 max
@@ -116,6 +117,19 @@ type FrameHeader struct {
 
 func (h FrameHeader) Header() FrameHeader { return h }
 
+func (h FrameHeader) String() string {
+	return fmt.Sprintf("[FrameHeader type=%v flags=%v stream=%v len=%v]",
+		h.Type, h.Flags, h.StreamID, h.Length)
+}
+
+func (h *FrameHeader) checkValid() {
+	if !h.valid {
+		panic("Frame accessor called on non-owned Frame")
+	}
+}
+
+func (h *FrameHeader) invalidate() { h.valid = false }
+
 // frame header bytes
 var fhBytes = sync.Pool{
 	New: func() interface{} {
@@ -127,8 +141,11 @@ var fhBytes = sync.Pool{
 func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
 	bufp := fhBytes.Get().(*[]byte)
 	defer fhBytes.Put(bufp)
-	buf := *bufp
-	_, err := io.ReadFull(r, buf)
+	return readFrameHeader(*bufp, r)
+}
+
+func readFrameHeader(buf []byte, r io.Reader) (FrameHeader, error) {
+	_, err := io.ReadFull(r, buf[:frameHeaderLen])
 	if err != nil {
 		return FrameHeader{}, err
 	}
@@ -137,11 +154,57 @@ func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
 		Type:     FrameType(buf[3]),
 		Flags:    Flags(buf[4]),
 		StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
+		valid:    true,
 	}, nil
 }
 
 type Frame interface {
 	Header() FrameHeader
+	invalidate()
+}
+
+// A Framer reads and writes Frames.
+type Framer struct {
+	r         io.Reader
+	lr        io.LimitedReader
+	lastFrame Frame
+	readBuf   []byte
+
+	w io.Writer
+}
+
+// NewFramer returns a Framer that writes frames to w and reads them from r.
+func NewFramer(w io.Writer, r io.Reader) *Framer {
+	return &Framer{
+		w:       w,
+		r:       r,
+		readBuf: make([]byte, 1<<10),
+	}
+}
+
+// ReadFrame reads a single frame. The returned Frame is only valid
+// until the next call to ReadFrame.
+func (fr *Framer) ReadFrame() (Frame, error) {
+	if fr.lastFrame != nil {
+		fr.lastFrame.invalidate()
+	}
+	fh, err := readFrameHeader(fr.readBuf, fr.r)
+	if err != nil {
+		return nil, err
+	}
+	if uint32(len(fr.readBuf)) < fh.Length {
+		fr.readBuf = make([]byte, fh.Length)
+	}
+	payload := fr.readBuf[:fh.Length]
+	if _, err := io.ReadFull(fr.r, payload); err != nil {
+		return nil, err
+	}
+	f, err := typeFrameParser(fh.Type)(fh, payload)
+	if err != nil {
+		return nil, err
+	}
+	fr.lastFrame = f
+	return f, nil
 }
 
 type SettingsFrame struct {
@@ -149,7 +212,7 @@ type SettingsFrame struct {
 	Settings map[SettingID]uint32
 }
 
-func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
+func parseSettingsFrame(fh FrameHeader, p []byte) (Frame, error) {
 	if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 {
 		// When this (ACK 0x1) bit is set, the payload of the
 		// SETTINGS frame MUST be empty.  Receipt of a
@@ -176,16 +239,10 @@ func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 	}
 	s := make(map[SettingID]uint32)
 	nSettings := int(fh.Length / 6)
-	var buf [4]byte
 	for i := 0; i < nSettings; i++ {
-		if _, err := io.ReadFull(r, buf[:2]); err != nil {
-			return nil, err
-		}
-		settingID := SettingID(binary.BigEndian.Uint16(buf[:2]))
-		if _, err := io.ReadFull(r, buf[:4]); err != nil {
-			return nil, err
-		}
-		value := binary.BigEndian.Uint32(buf[:4])
+		sbuf := p[i*6:]
+		settingID := SettingID(binary.BigEndian.Uint16(sbuf[:2]))
+		value := binary.BigEndian.Uint32(sbuf[2:4])
 		if settingID == SettingInitialWindowSize && value > (1<<31)-1 {
 			// Values above the maximum flow control window size of 2^31 - 1 MUST
 			// be treated as a connection error (Section 5.4.1) of type
@@ -205,11 +262,19 @@ func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 
 type UnknownFrame struct {
 	FrameHeader
+	p []byte
+}
+
+// Payload returns the frame's payload (after the header).
+// It is not valid to call this method after a subsequent
+// call to Framer.ReadFrame.
+func (f *UnknownFrame) Payload() []byte {
+	f.checkValid()
+	return f.p
 }
 
-func parseUnknownFrame(fh FrameHeader, r io.Reader) (Frame, error) {
-	_, err := io.CopyN(ioutil.Discard, r, int64(fh.Length))
-	return UnknownFrame{fh}, err
+func parseUnknownFrame(fh FrameHeader, p []byte) (Frame, error) {
+	return &UnknownFrame{fh, p}, nil
 }
 
 type WindowUpdateFrame struct {
@@ -217,27 +282,14 @@ type WindowUpdateFrame struct {
 	Increment uint32
 }
 
-func parseWindowUpdateFrame(fh FrameHeader, r io.Reader) (Frame, error) {
-	if fh.Length < 4 {
+func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) {
+	if len(p) < 4 {
 		// Too short.
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
-	f := WindowUpdateFrame{
+	f := &WindowUpdateFrame{
 		FrameHeader: fh,
-	}
-	var err error
-	f.Increment, err = readUint32(r)
-	if err != nil {
-		return nil, err
-	}
-	f.Increment &= 0x7fffffff // mask off high reserved bit
-
-	// Future-proof: ignore any extra length in the frame. The spec doesn't
-	// say what to do if Length is too large.
-	if fh.Length > 4 {
-		if _, err := io.CopyN(ioutil.Discard, r, int64(fh.Length-4)); err != nil {
-			return nil, err
-		}
+		Increment:   binary.BigEndian.Uint32(p[:4]) & 0x7fffffff, // mask off high reserved bit
 	}
 	return f, nil
 }
@@ -254,11 +306,16 @@ type HeaderFrame struct {
 	// also add 1 to get to spec-defined [1,256] range.
 	Weight uint8
 
-	HeaderFragBuf []byte
+	headerFragBuf []byte // not owned
+}
+
+func (f *HeaderFrame) HeaderBlockFragment() []byte {
+	f.checkValid()
+	return f.headerFragBuf
 }
 
-func parseHeadersFrame(fh FrameHeader, r io.Reader) (_ Frame, err error) {
-	hf := HeaderFrame{
+func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
+	hf := &HeaderFrame{
 		FrameHeader: fh,
 	}
 	if fh.StreamID == 0 {
@@ -269,53 +326,41 @@ func parseHeadersFrame(fh FrameHeader, r io.Reader) (_ Frame, err error) {
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
 	var padLength uint8
-	var notHeaders int // Header Block Fragment length = fh.Length - notHeaders
 	if fh.Flags.Has(FlagHeadersPadded) {
-		notHeaders += 1
-		if padLength, err = readByte(r); err != nil {
+		if p, padLength, err = readByte(p); err != nil {
 			return
 		}
 	}
 	if fh.Flags.Has(FlagHeadersPriority) {
-		notHeaders += 5
-		v, err := readUint32(r)
+		var v uint32
+		p, v, err = readUint32(p)
 		if err != nil {
 			return nil, err
 		}
 		hf.StreamDep = v & 0x7fffffff
 		hf.ExclusiveDep = (v != hf.StreamDep) // high bit was set
-		hf.Weight, err = readByte(r)
+		p, hf.Weight, err = readByte(p)
 		if err != nil {
 			return nil, err
 		}
 	}
-	headerFragLen := int(fh.Length) - notHeaders
-	if headerFragLen <= 0 {
+	if len(p)-int(padLength) <= 0 {
 		return nil, StreamError(fh.StreamID)
 	}
-	buf := make([]byte, headerFragLen)
-	if _, err := io.ReadFull(r, buf); err != nil {
-		return nil, err
-	}
-	if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil {
-		return nil, err
-	}
-	hf.HeaderFragBuf = buf
+	hf.headerFragBuf = p[:len(p)-int(padLength)]
 	return hf, nil
 }
 
-func readByte(r io.Reader) (uint8, error) {
-	// TODO: optimize, reuse buffers
-	var buf [1]byte
-	_, err := io.ReadFull(r, buf[:1])
-	return buf[0], err
+func readByte(p []byte) (remain []byte, b byte, err error) {
+	if len(p) == 0 {
+		return nil, 0, io.ErrUnexpectedEOF
+	}
+	return p[1:], p[0], nil
 }
 
-func readUint32(r io.Reader) (uint32, error) {
-	// TODO: optimize, reuse buffers
-	var buf [4]byte
-	if _, err := io.ReadFull(r, buf[:4]); err != nil {
-		return 0, err
+func readUint32(p []byte) (remain []byte, v uint32, err error) {
+	if len(p) < 4 {
+		return nil, 0, io.ErrUnexpectedEOF
 	}
-	return binary.BigEndian.Uint32(buf[:4]), nil
+	return p[4:], binary.BigEndian.Uint32(p[:4]), nil
 }

+ 6 - 20
http2.go

@@ -14,7 +14,6 @@ import (
 	"bytes"
 	"crypto/tls"
 	"io"
-	"io/ioutil"
 	"log"
 	"net/http"
 	"sync"
@@ -41,7 +40,7 @@ type Server struct {
 }
 
 func (srv *Server) handleClientConn(hs *http.Server, c *tls.Conn, h http.Handler) {
-	cc := &clientConn{hs, c, h}
+	cc := &clientConn{hs, c, h, NewFramer(c, c)}
 	cc.serve()
 }
 
@@ -49,6 +48,7 @@ type clientConn struct {
 	hs *http.Server
 	c  *tls.Conn
 	h  http.Handler
+	fr *Framer
 }
 
 func (cc *clientConn) logf(format string, args ...interface{}) {
@@ -74,19 +74,9 @@ func (cc *clientConn) serve() {
 		return
 	}
 	log.Printf("client %v said hello", cc.c.RemoteAddr())
-	var frameReader = io.LimitedReader{
-		R: cc.c,
-	}
 	for {
-		fh, err := ReadFrameHeader(cc.c)
-		if err != nil {
-			if err != io.EOF {
-				cc.logf("error reading frame: %v", err)
-			}
-			return
-		}
-		frameReader.N = int64(fh.Length)
-		f, err := typeFrameParser(fh.Type)(fh, &frameReader)
+
+		f, err := cc.fr.ReadFrame()
 		if h2e, ok := err.(Error); ok {
 			if h2e.IsConnectionError() {
 				log.Printf("Disconnection; connection error: %v", err)
@@ -95,14 +85,10 @@ func (cc *clientConn) serve() {
 			// TODO: stream errors, etc
 		}
 		if err != nil {
-			log.Printf("Disconnection to other error: %v", err)
-			return
-		}
-		if n, _ := io.Copy(ioutil.Discard, &frameReader); n > 0 {
-			log.Printf("Frame reader for %s failed to read %d bytes", fh.Type, n)
+			log.Printf("Disconnection due to other error: %v", err)
 			return
 		}
-		log.Printf("got frame: %#v", f)
+		log.Printf("got %v: %#v", f.Header(), f)
 	}
 }