Browse Source

Add Framer type.

Invalidate frames after subsequent frames are read.
Change frame parser signature to start with the entire payload buffer.
Brad Fitzpatrick 11 years ago
parent
commit
465880975f
2 changed files with 115 additions and 84 deletions
  1. 109 64
      frame.go
  2. 6 20
      http2.go

+ 109 - 64
frame.go

@@ -9,7 +9,6 @@ import (
 	"encoding/binary"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"log"
 	"sync"
 )
@@ -83,7 +82,7 @@ func knownSetting(id SettingID) bool {
 // a frameParser parses a frame. The parser can assume that the Reader will
 // not read past the length of a frame (e.g. it acts like an io.LimitReader
 // bounded by the FrameHeader.Length)
-type frameParser func(FrameHeader, io.Reader) (Frame, error)
+type frameParser func(FrameHeader, []byte) (Frame, error)
 
 var FrameParsers = map[FrameType]frameParser{
 	FrameSettings:     parseSettingsFrame,
@@ -108,6 +107,8 @@ func (f Flags) Has(v Flags) bool {
 //
 // See http://http2.github.io/http2-spec/#FrameHeader
 type FrameHeader struct {
+	valid bool // caller can access []byte fields in the Frame
+
 	Type     FrameType
 	Flags    Flags
 	Length   uint32 // actually a uint24 max; default is uint16 max
@@ -116,6 +117,19 @@ type FrameHeader struct {
 
 func (h FrameHeader) Header() FrameHeader { return h }
 
+func (h FrameHeader) String() string {
+	return fmt.Sprintf("[FrameHeader type=%v flags=%v stream=%v len=%v]",
+		h.Type, h.Flags, h.StreamID, h.Length)
+}
+
+func (h *FrameHeader) checkValid() {
+	if !h.valid {
+		panic("Frame accessor called on non-owned Frame")
+	}
+}
+
+func (h *FrameHeader) invalidate() { h.valid = false }
+
 // frame header bytes
 var fhBytes = sync.Pool{
 	New: func() interface{} {
@@ -127,8 +141,11 @@ var fhBytes = sync.Pool{
 func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
 	bufp := fhBytes.Get().(*[]byte)
 	defer fhBytes.Put(bufp)
-	buf := *bufp
-	_, err := io.ReadFull(r, buf)
+	return readFrameHeader(*bufp, r)
+}
+
+func readFrameHeader(buf []byte, r io.Reader) (FrameHeader, error) {
+	_, err := io.ReadFull(r, buf[:frameHeaderLen])
 	if err != nil {
 		return FrameHeader{}, err
 	}
@@ -137,11 +154,57 @@ func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
 		Type:     FrameType(buf[3]),
 		Flags:    Flags(buf[4]),
 		StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
+		valid:    true,
 	}, nil
 }
 
 type Frame interface {
 	Header() FrameHeader
+	invalidate()
+}
+
+// A Framer reads and writes Frames.
+type Framer struct {
+	r         io.Reader
+	lr        io.LimitedReader
+	lastFrame Frame
+	readBuf   []byte
+
+	w io.Writer
+}
+
+// NewFramer returns a Framer that writes frames to w and reads them from r.
+func NewFramer(w io.Writer, r io.Reader) *Framer {
+	return &Framer{
+		w:       w,
+		r:       r,
+		readBuf: make([]byte, 1<<10),
+	}
+}
+
+// ReadFrame reads a single frame. The returned Frame is only valid
+// until the next call to ReadFrame.
+func (fr *Framer) ReadFrame() (Frame, error) {
+	if fr.lastFrame != nil {
+		fr.lastFrame.invalidate()
+	}
+	fh, err := readFrameHeader(fr.readBuf, fr.r)
+	if err != nil {
+		return nil, err
+	}
+	if uint32(len(fr.readBuf)) < fh.Length {
+		fr.readBuf = make([]byte, fh.Length)
+	}
+	payload := fr.readBuf[:fh.Length]
+	if _, err := io.ReadFull(fr.r, payload); err != nil {
+		return nil, err
+	}
+	f, err := typeFrameParser(fh.Type)(fh, payload)
+	if err != nil {
+		return nil, err
+	}
+	fr.lastFrame = f
+	return f, nil
 }
 
 type SettingsFrame struct {
@@ -149,7 +212,7 @@ type SettingsFrame struct {
 	Settings map[SettingID]uint32
 }
 
-func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
+func parseSettingsFrame(fh FrameHeader, p []byte) (Frame, error) {
 	if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 {
 		// When this (ACK 0x1) bit is set, the payload of the
 		// SETTINGS frame MUST be empty.  Receipt of a
@@ -176,16 +239,10 @@ func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 	}
 	s := make(map[SettingID]uint32)
 	nSettings := int(fh.Length / 6)
-	var buf [4]byte
 	for i := 0; i < nSettings; i++ {
-		if _, err := io.ReadFull(r, buf[:2]); err != nil {
-			return nil, err
-		}
-		settingID := SettingID(binary.BigEndian.Uint16(buf[:2]))
-		if _, err := io.ReadFull(r, buf[:4]); err != nil {
-			return nil, err
-		}
-		value := binary.BigEndian.Uint32(buf[:4])
+		sbuf := p[i*6:]
+		settingID := SettingID(binary.BigEndian.Uint16(sbuf[:2]))
+		value := binary.BigEndian.Uint32(sbuf[2:4])
 		if settingID == SettingInitialWindowSize && value > (1<<31)-1 {
 			// Values above the maximum flow control window size of 2^31 - 1 MUST
 			// be treated as a connection error (Section 5.4.1) of type
@@ -205,11 +262,19 @@ func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 
 type UnknownFrame struct {
 	FrameHeader
+	p []byte
+}
+
+// Payload returns the frame's payload (after the header).
+// It is not valid to call this method after a subsequent
+// call to Framer.ReadFrame.
+func (f *UnknownFrame) Payload() []byte {
+	f.checkValid()
+	return f.p
 }
 
-func parseUnknownFrame(fh FrameHeader, r io.Reader) (Frame, error) {
-	_, err := io.CopyN(ioutil.Discard, r, int64(fh.Length))
-	return UnknownFrame{fh}, err
+func parseUnknownFrame(fh FrameHeader, p []byte) (Frame, error) {
+	return &UnknownFrame{fh, p}, nil
 }
 
 type WindowUpdateFrame struct {
@@ -217,27 +282,14 @@ type WindowUpdateFrame struct {
 	Increment uint32
 }
 
-func parseWindowUpdateFrame(fh FrameHeader, r io.Reader) (Frame, error) {
-	if fh.Length < 4 {
+func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) {
+	if len(p) < 4 {
 		// Too short.
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
-	f := WindowUpdateFrame{
+	f := &WindowUpdateFrame{
 		FrameHeader: fh,
-	}
-	var err error
-	f.Increment, err = readUint32(r)
-	if err != nil {
-		return nil, err
-	}
-	f.Increment &= 0x7fffffff // mask off high reserved bit
-
-	// Future-proof: ignore any extra length in the frame. The spec doesn't
-	// say what to do if Length is too large.
-	if fh.Length > 4 {
-		if _, err := io.CopyN(ioutil.Discard, r, int64(fh.Length-4)); err != nil {
-			return nil, err
-		}
+		Increment:   binary.BigEndian.Uint32(p[:4]) & 0x7fffffff, // mask off high reserved bit
 	}
 	return f, nil
 }
@@ -254,11 +306,16 @@ type HeaderFrame struct {
 	// also add 1 to get to spec-defined [1,256] range.
 	Weight uint8
 
-	HeaderFragBuf []byte
+	headerFragBuf []byte // not owned
+}
+
+func (f *HeaderFrame) HeaderBlockFragment() []byte {
+	f.checkValid()
+	return f.headerFragBuf
 }
 
-func parseHeadersFrame(fh FrameHeader, r io.Reader) (_ Frame, err error) {
-	hf := HeaderFrame{
+func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
+	hf := &HeaderFrame{
 		FrameHeader: fh,
 	}
 	if fh.StreamID == 0 {
@@ -269,53 +326,41 @@ func parseHeadersFrame(fh FrameHeader, r io.Reader) (_ Frame, err error) {
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
 	var padLength uint8
-	var notHeaders int // Header Block Fragment length = fh.Length - notHeaders
 	if fh.Flags.Has(FlagHeadersPadded) {
-		notHeaders += 1
-		if padLength, err = readByte(r); err != nil {
+		if p, padLength, err = readByte(p); err != nil {
 			return
 		}
 	}
 	if fh.Flags.Has(FlagHeadersPriority) {
-		notHeaders += 5
-		v, err := readUint32(r)
+		var v uint32
+		p, v, err = readUint32(p)
 		if err != nil {
 			return nil, err
 		}
 		hf.StreamDep = v & 0x7fffffff
 		hf.ExclusiveDep = (v != hf.StreamDep) // high bit was set
-		hf.Weight, err = readByte(r)
+		p, hf.Weight, err = readByte(p)
 		if err != nil {
 			return nil, err
 		}
 	}
-	headerFragLen := int(fh.Length) - notHeaders
-	if headerFragLen <= 0 {
+	if len(p)-int(padLength) <= 0 {
 		return nil, StreamError(fh.StreamID)
 	}
-	buf := make([]byte, headerFragLen)
-	if _, err := io.ReadFull(r, buf); err != nil {
-		return nil, err
-	}
-	if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil {
-		return nil, err
-	}
-	hf.HeaderFragBuf = buf
+	hf.headerFragBuf = p[:len(p)-int(padLength)]
 	return hf, nil
 }
 
-func readByte(r io.Reader) (uint8, error) {
-	// TODO: optimize, reuse buffers
-	var buf [1]byte
-	_, err := io.ReadFull(r, buf[:1])
-	return buf[0], err
+func readByte(p []byte) (remain []byte, b byte, err error) {
+	if len(p) == 0 {
+		return nil, 0, io.ErrUnexpectedEOF
+	}
+	return p[1:], p[0], nil
 }
 
-func readUint32(r io.Reader) (uint32, error) {
-	// TODO: optimize, reuse buffers
-	var buf [4]byte
-	if _, err := io.ReadFull(r, buf[:4]); err != nil {
-		return 0, err
+func readUint32(p []byte) (remain []byte, v uint32, err error) {
+	if len(p) < 4 {
+		return nil, 0, io.ErrUnexpectedEOF
 	}
-	return binary.BigEndian.Uint32(buf[:4]), nil
+	return p[4:], binary.BigEndian.Uint32(p[:4]), nil
 }

+ 6 - 20
http2.go

@@ -14,7 +14,6 @@ import (
 	"bytes"
 	"crypto/tls"
 	"io"
-	"io/ioutil"
 	"log"
 	"net/http"
 	"sync"
@@ -41,7 +40,7 @@ type Server struct {
 }
 
 func (srv *Server) handleClientConn(hs *http.Server, c *tls.Conn, h http.Handler) {
-	cc := &clientConn{hs, c, h}
+	cc := &clientConn{hs, c, h, NewFramer(c, c)}
 	cc.serve()
 }
 
@@ -49,6 +48,7 @@ type clientConn struct {
 	hs *http.Server
 	c  *tls.Conn
 	h  http.Handler
+	fr *Framer
 }
 
 func (cc *clientConn) logf(format string, args ...interface{}) {
@@ -74,19 +74,9 @@ func (cc *clientConn) serve() {
 		return
 	}
 	log.Printf("client %v said hello", cc.c.RemoteAddr())
-	var frameReader = io.LimitedReader{
-		R: cc.c,
-	}
 	for {
-		fh, err := ReadFrameHeader(cc.c)
-		if err != nil {
-			if err != io.EOF {
-				cc.logf("error reading frame: %v", err)
-			}
-			return
-		}
-		frameReader.N = int64(fh.Length)
-		f, err := typeFrameParser(fh.Type)(fh, &frameReader)
+
+		f, err := cc.fr.ReadFrame()
 		if h2e, ok := err.(Error); ok {
 			if h2e.IsConnectionError() {
 				log.Printf("Disconnection; connection error: %v", err)
@@ -95,14 +85,10 @@ func (cc *clientConn) serve() {
 			// TODO: stream errors, etc
 		}
 		if err != nil {
-			log.Printf("Disconnection to other error: %v", err)
-			return
-		}
-		if n, _ := io.Copy(ioutil.Discard, &frameReader); n > 0 {
-			log.Printf("Frame reader for %s failed to read %d bytes", fh.Type, n)
+			log.Printf("Disconnection due to other error: %v", err)
 			return
 		}
-		log.Printf("got frame: %#v", f)
+		log.Printf("got %v: %#v", f.Header(), f)
 	}
 }