Browse Source

Add Framer type.

Invalidate frames after subsequent frames are read.
Change frame parser signature to start with the entire payload buffer.
Brad Fitzpatrick 11 years ago
parent
commit
465880975f
2 changed files with 115 additions and 84 deletions
  1. 109 64
      frame.go
  2. 6 20
      http2.go

+ 109 - 64
frame.go

@@ -9,7 +9,6 @@ import (
 	"encoding/binary"
 	"encoding/binary"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"io/ioutil"
 	"log"
 	"log"
 	"sync"
 	"sync"
 )
 )
@@ -83,7 +82,7 @@ func knownSetting(id SettingID) bool {
 // a frameParser parses a frame. The parser can assume that the Reader will
 // a frameParser parses a frame. The parser can assume that the Reader will
 // not read past the length of a frame (e.g. it acts like an io.LimitReader
 // not read past the length of a frame (e.g. it acts like an io.LimitReader
 // bounded by the FrameHeader.Length)
 // bounded by the FrameHeader.Length)
-type frameParser func(FrameHeader, io.Reader) (Frame, error)
+type frameParser func(FrameHeader, []byte) (Frame, error)
 
 
 var FrameParsers = map[FrameType]frameParser{
 var FrameParsers = map[FrameType]frameParser{
 	FrameSettings:     parseSettingsFrame,
 	FrameSettings:     parseSettingsFrame,
@@ -108,6 +107,8 @@ func (f Flags) Has(v Flags) bool {
 //
 //
 // See http://http2.github.io/http2-spec/#FrameHeader
 // See http://http2.github.io/http2-spec/#FrameHeader
 type FrameHeader struct {
 type FrameHeader struct {
+	valid bool // caller can access []byte fields in the Frame
+
 	Type     FrameType
 	Type     FrameType
 	Flags    Flags
 	Flags    Flags
 	Length   uint32 // actually a uint24 max; default is uint16 max
 	Length   uint32 // actually a uint24 max; default is uint16 max
@@ -116,6 +117,19 @@ type FrameHeader struct {
 
 
 func (h FrameHeader) Header() FrameHeader { return h }
 func (h FrameHeader) Header() FrameHeader { return h }
 
 
+func (h FrameHeader) String() string {
+	return fmt.Sprintf("[FrameHeader type=%v flags=%v stream=%v len=%v]",
+		h.Type, h.Flags, h.StreamID, h.Length)
+}
+
+func (h *FrameHeader) checkValid() {
+	if !h.valid {
+		panic("Frame accessor called on non-owned Frame")
+	}
+}
+
+func (h *FrameHeader) invalidate() { h.valid = false }
+
 // frame header bytes
 // frame header bytes
 var fhBytes = sync.Pool{
 var fhBytes = sync.Pool{
 	New: func() interface{} {
 	New: func() interface{} {
@@ -127,8 +141,11 @@ var fhBytes = sync.Pool{
 func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
 func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
 	bufp := fhBytes.Get().(*[]byte)
 	bufp := fhBytes.Get().(*[]byte)
 	defer fhBytes.Put(bufp)
 	defer fhBytes.Put(bufp)
-	buf := *bufp
-	_, err := io.ReadFull(r, buf)
+	return readFrameHeader(*bufp, r)
+}
+
+func readFrameHeader(buf []byte, r io.Reader) (FrameHeader, error) {
+	_, err := io.ReadFull(r, buf[:frameHeaderLen])
 	if err != nil {
 	if err != nil {
 		return FrameHeader{}, err
 		return FrameHeader{}, err
 	}
 	}
@@ -137,11 +154,57 @@ func ReadFrameHeader(r io.Reader) (FrameHeader, error) {
 		Type:     FrameType(buf[3]),
 		Type:     FrameType(buf[3]),
 		Flags:    Flags(buf[4]),
 		Flags:    Flags(buf[4]),
 		StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
 		StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
+		valid:    true,
 	}, nil
 	}, nil
 }
 }
 
 
 type Frame interface {
 type Frame interface {
 	Header() FrameHeader
 	Header() FrameHeader
+	invalidate()
+}
+
+// A Framer reads and writes Frames.
+type Framer struct {
+	r         io.Reader
+	lr        io.LimitedReader
+	lastFrame Frame
+	readBuf   []byte
+
+	w io.Writer
+}
+
+// NewFramer returns a Framer that writes frames to w and reads them from r.
+func NewFramer(w io.Writer, r io.Reader) *Framer {
+	return &Framer{
+		w:       w,
+		r:       r,
+		readBuf: make([]byte, 1<<10),
+	}
+}
+
+// ReadFrame reads a single frame. The returned Frame is only valid
+// until the next call to ReadFrame.
+func (fr *Framer) ReadFrame() (Frame, error) {
+	if fr.lastFrame != nil {
+		fr.lastFrame.invalidate()
+	}
+	fh, err := readFrameHeader(fr.readBuf, fr.r)
+	if err != nil {
+		return nil, err
+	}
+	if uint32(len(fr.readBuf)) < fh.Length {
+		fr.readBuf = make([]byte, fh.Length)
+	}
+	payload := fr.readBuf[:fh.Length]
+	if _, err := io.ReadFull(fr.r, payload); err != nil {
+		return nil, err
+	}
+	f, err := typeFrameParser(fh.Type)(fh, payload)
+	if err != nil {
+		return nil, err
+	}
+	fr.lastFrame = f
+	return f, nil
 }
 }
 
 
 type SettingsFrame struct {
 type SettingsFrame struct {
@@ -149,7 +212,7 @@ type SettingsFrame struct {
 	Settings map[SettingID]uint32
 	Settings map[SettingID]uint32
 }
 }
 
 
-func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
+func parseSettingsFrame(fh FrameHeader, p []byte) (Frame, error) {
 	if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 {
 	if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 {
 		// When this (ACK 0x1) bit is set, the payload of the
 		// When this (ACK 0x1) bit is set, the payload of the
 		// SETTINGS frame MUST be empty.  Receipt of a
 		// SETTINGS frame MUST be empty.  Receipt of a
@@ -176,16 +239,10 @@ func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 	}
 	}
 	s := make(map[SettingID]uint32)
 	s := make(map[SettingID]uint32)
 	nSettings := int(fh.Length / 6)
 	nSettings := int(fh.Length / 6)
-	var buf [4]byte
 	for i := 0; i < nSettings; i++ {
 	for i := 0; i < nSettings; i++ {
-		if _, err := io.ReadFull(r, buf[:2]); err != nil {
-			return nil, err
-		}
-		settingID := SettingID(binary.BigEndian.Uint16(buf[:2]))
-		if _, err := io.ReadFull(r, buf[:4]); err != nil {
-			return nil, err
-		}
-		value := binary.BigEndian.Uint32(buf[:4])
+		sbuf := p[i*6:]
+		settingID := SettingID(binary.BigEndian.Uint16(sbuf[:2]))
+		value := binary.BigEndian.Uint32(sbuf[2:4])
 		if settingID == SettingInitialWindowSize && value > (1<<31)-1 {
 		if settingID == SettingInitialWindowSize && value > (1<<31)-1 {
 			// Values above the maximum flow control window size of 2^31 - 1 MUST
 			// Values above the maximum flow control window size of 2^31 - 1 MUST
 			// be treated as a connection error (Section 5.4.1) of type
 			// be treated as a connection error (Section 5.4.1) of type
@@ -205,11 +262,19 @@ func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 
 
 type UnknownFrame struct {
 type UnknownFrame struct {
 	FrameHeader
 	FrameHeader
+	p []byte
+}
+
+// Payload returns the frame's payload (after the header).
+// It is not valid to call this method after a subsequent
+// call to Framer.ReadFrame.
+func (f *UnknownFrame) Payload() []byte {
+	f.checkValid()
+	return f.p
 }
 }
 
 
-func parseUnknownFrame(fh FrameHeader, r io.Reader) (Frame, error) {
-	_, err := io.CopyN(ioutil.Discard, r, int64(fh.Length))
-	return UnknownFrame{fh}, err
+func parseUnknownFrame(fh FrameHeader, p []byte) (Frame, error) {
+	return &UnknownFrame{fh, p}, nil
 }
 }
 
 
 type WindowUpdateFrame struct {
 type WindowUpdateFrame struct {
@@ -217,27 +282,14 @@ type WindowUpdateFrame struct {
 	Increment uint32
 	Increment uint32
 }
 }
 
 
-func parseWindowUpdateFrame(fh FrameHeader, r io.Reader) (Frame, error) {
-	if fh.Length < 4 {
+func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) {
+	if len(p) < 4 {
 		// Too short.
 		// Too short.
 		return nil, ConnectionError(ErrCodeProtocol)
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
 	}
-	f := WindowUpdateFrame{
+	f := &WindowUpdateFrame{
 		FrameHeader: fh,
 		FrameHeader: fh,
-	}
-	var err error
-	f.Increment, err = readUint32(r)
-	if err != nil {
-		return nil, err
-	}
-	f.Increment &= 0x7fffffff // mask off high reserved bit
-
-	// Future-proof: ignore any extra length in the frame. The spec doesn't
-	// say what to do if Length is too large.
-	if fh.Length > 4 {
-		if _, err := io.CopyN(ioutil.Discard, r, int64(fh.Length-4)); err != nil {
-			return nil, err
-		}
+		Increment:   binary.BigEndian.Uint32(p[:4]) & 0x7fffffff, // mask off high reserved bit
 	}
 	}
 	return f, nil
 	return f, nil
 }
 }
@@ -254,11 +306,16 @@ type HeaderFrame struct {
 	// also add 1 to get to spec-defined [1,256] range.
 	// also add 1 to get to spec-defined [1,256] range.
 	Weight uint8
 	Weight uint8
 
 
-	HeaderFragBuf []byte
+	headerFragBuf []byte // not owned
+}
+
+func (f *HeaderFrame) HeaderBlockFragment() []byte {
+	f.checkValid()
+	return f.headerFragBuf
 }
 }
 
 
-func parseHeadersFrame(fh FrameHeader, r io.Reader) (_ Frame, err error) {
-	hf := HeaderFrame{
+func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
+	hf := &HeaderFrame{
 		FrameHeader: fh,
 		FrameHeader: fh,
 	}
 	}
 	if fh.StreamID == 0 {
 	if fh.StreamID == 0 {
@@ -269,53 +326,41 @@ func parseHeadersFrame(fh FrameHeader, r io.Reader) (_ Frame, err error) {
 		return nil, ConnectionError(ErrCodeProtocol)
 		return nil, ConnectionError(ErrCodeProtocol)
 	}
 	}
 	var padLength uint8
 	var padLength uint8
-	var notHeaders int // Header Block Fragment length = fh.Length - notHeaders
 	if fh.Flags.Has(FlagHeadersPadded) {
 	if fh.Flags.Has(FlagHeadersPadded) {
-		notHeaders += 1
-		if padLength, err = readByte(r); err != nil {
+		if p, padLength, err = readByte(p); err != nil {
 			return
 			return
 		}
 		}
 	}
 	}
 	if fh.Flags.Has(FlagHeadersPriority) {
 	if fh.Flags.Has(FlagHeadersPriority) {
-		notHeaders += 5
-		v, err := readUint32(r)
+		var v uint32
+		p, v, err = readUint32(p)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		hf.StreamDep = v & 0x7fffffff
 		hf.StreamDep = v & 0x7fffffff
 		hf.ExclusiveDep = (v != hf.StreamDep) // high bit was set
 		hf.ExclusiveDep = (v != hf.StreamDep) // high bit was set
-		hf.Weight, err = readByte(r)
+		p, hf.Weight, err = readByte(p)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 	}
 	}
-	headerFragLen := int(fh.Length) - notHeaders
-	if headerFragLen <= 0 {
+	if len(p)-int(padLength) <= 0 {
 		return nil, StreamError(fh.StreamID)
 		return nil, StreamError(fh.StreamID)
 	}
 	}
-	buf := make([]byte, headerFragLen)
-	if _, err := io.ReadFull(r, buf); err != nil {
-		return nil, err
-	}
-	if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil {
-		return nil, err
-	}
-	hf.HeaderFragBuf = buf
+	hf.headerFragBuf = p[:len(p)-int(padLength)]
 	return hf, nil
 	return hf, nil
 }
 }
 
 
-func readByte(r io.Reader) (uint8, error) {
-	// TODO: optimize, reuse buffers
-	var buf [1]byte
-	_, err := io.ReadFull(r, buf[:1])
-	return buf[0], err
+func readByte(p []byte) (remain []byte, b byte, err error) {
+	if len(p) == 0 {
+		return nil, 0, io.ErrUnexpectedEOF
+	}
+	return p[1:], p[0], nil
 }
 }
 
 
-func readUint32(r io.Reader) (uint32, error) {
-	// TODO: optimize, reuse buffers
-	var buf [4]byte
-	if _, err := io.ReadFull(r, buf[:4]); err != nil {
-		return 0, err
+func readUint32(p []byte) (remain []byte, v uint32, err error) {
+	if len(p) < 4 {
+		return nil, 0, io.ErrUnexpectedEOF
 	}
 	}
-	return binary.BigEndian.Uint32(buf[:4]), nil
+	return p[4:], binary.BigEndian.Uint32(p[:4]), nil
 }
 }

+ 6 - 20
http2.go

@@ -14,7 +14,6 @@ import (
 	"bytes"
 	"bytes"
 	"crypto/tls"
 	"crypto/tls"
 	"io"
 	"io"
-	"io/ioutil"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
 	"sync"
 	"sync"
@@ -41,7 +40,7 @@ type Server struct {
 }
 }
 
 
 func (srv *Server) handleClientConn(hs *http.Server, c *tls.Conn, h http.Handler) {
 func (srv *Server) handleClientConn(hs *http.Server, c *tls.Conn, h http.Handler) {
-	cc := &clientConn{hs, c, h}
+	cc := &clientConn{hs, c, h, NewFramer(c, c)}
 	cc.serve()
 	cc.serve()
 }
 }
 
 
@@ -49,6 +48,7 @@ type clientConn struct {
 	hs *http.Server
 	hs *http.Server
 	c  *tls.Conn
 	c  *tls.Conn
 	h  http.Handler
 	h  http.Handler
+	fr *Framer
 }
 }
 
 
 func (cc *clientConn) logf(format string, args ...interface{}) {
 func (cc *clientConn) logf(format string, args ...interface{}) {
@@ -74,19 +74,9 @@ func (cc *clientConn) serve() {
 		return
 		return
 	}
 	}
 	log.Printf("client %v said hello", cc.c.RemoteAddr())
 	log.Printf("client %v said hello", cc.c.RemoteAddr())
-	var frameReader = io.LimitedReader{
-		R: cc.c,
-	}
 	for {
 	for {
-		fh, err := ReadFrameHeader(cc.c)
-		if err != nil {
-			if err != io.EOF {
-				cc.logf("error reading frame: %v", err)
-			}
-			return
-		}
-		frameReader.N = int64(fh.Length)
-		f, err := typeFrameParser(fh.Type)(fh, &frameReader)
+
+		f, err := cc.fr.ReadFrame()
 		if h2e, ok := err.(Error); ok {
 		if h2e, ok := err.(Error); ok {
 			if h2e.IsConnectionError() {
 			if h2e.IsConnectionError() {
 				log.Printf("Disconnection; connection error: %v", err)
 				log.Printf("Disconnection; connection error: %v", err)
@@ -95,14 +85,10 @@ func (cc *clientConn) serve() {
 			// TODO: stream errors, etc
 			// TODO: stream errors, etc
 		}
 		}
 		if err != nil {
 		if err != nil {
-			log.Printf("Disconnection to other error: %v", err)
-			return
-		}
-		if n, _ := io.Copy(ioutil.Discard, &frameReader); n > 0 {
-			log.Printf("Frame reader for %s failed to read %d bytes", fh.Type, n)
+			log.Printf("Disconnection due to other error: %v", err)
 			return
 			return
 		}
 		}
-		log.Printf("got frame: %#v", f)
+		log.Printf("got %v: %#v", f.Header(), f)
 	}
 	}
 }
 }