Browse Source

Start of parsing window update and headers frames.

Brad Fitzpatrick 11 years ago
parent
commit
7338cb2841
3 changed files with 175 additions and 4 deletions
  1. 8 0
      errors.go
  2. 156 2
      frame.go
  3. 11 2
      http2.go

+ 8 - 0
errors.go

@@ -56,3 +56,11 @@ var _ Error = ConnectionError(0)
 func (e ConnectionError) IsStreamError() bool     { return false }
 func (e ConnectionError) IsConnectionError() bool { return true }
 func (e ConnectionError) Error() string           { return fmt.Sprintf("connection error: %s", ErrCode(e)) }
+
+type StreamError uint32
+
+var _ Error = StreamError(0)
+
+func (e StreamError) IsStreamError() bool     { return true }
+func (e StreamError) IsConnectionError() bool { return false }
+func (e StreamError) Error() string           { return fmt.Sprintf("stream error: stream ID = %d", uint32(e)) }

+ 156 - 2
frame.go

@@ -2,6 +2,7 @@ package http2
 
 import (
 	"encoding/binary"
+	"fmt"
 	"io"
 	"io/ioutil"
 	"log"
@@ -24,8 +25,39 @@ const (
 	FrameGoAway       FrameType = 0x7
 	FrameWindowUpdate FrameType = 0x8
 	FrameContinuation FrameType = 0x9
+)
+
+var frameName = map[FrameType]string{
+	FrameData:         "DATA",
+	FrameHeaders:      "HEADERS",
+	FramePriority:     "PRIORITY",
+	FrameRSTStream:    "RST_STREAM",
+	FrameSettings:     "SETTINGS",
+	FramePushPromise:  "PUSH_PROMISE",
+	FramePing:         "PING",
+	FrameGoAway:       "GOAWAY",
+	FrameWindowUpdate: "WINDOW_UPDATE",
+	FrameContinuation: "CONTINUATION",
+}
 
+func (t FrameType) String() string {
+	if s, ok := frameName[t]; ok {
+		return s
+	}
+	return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t))
+}
+
+// Frame-specific FrameHeader flag bits.
+const (
+	// Settings Frame
 	FlagSettingsAck Flags = 0x1
+
+	// Headers Frame
+	FlagHeadersEndStream  Flags = 0x1
+	FlagHeadersEndSegment Flags = 0x2
+	FlagHeadersEndHeaders Flags = 0x4
+	FlagHeadersPadded     Flags = 0x8
+	FlagHeadersPriority   Flags = 0x20
 )
 
 type SettingID uint16
@@ -43,10 +75,15 @@ func knownSetting(id SettingID) bool {
 	return id >= 1 && id <= 4
 }
 
+// a frameParser parses a frame. The parser can assume that the Reader will
+// not read past the length of a frame (e.g. it acts like an io.LimitReader
+// bounded by the FrameHeader.Length)
 type frameParser func(FrameHeader, io.Reader) (Frame, error)
 
 var FrameParsers = map[FrameType]frameParser{
-	FrameSettings: parseFrameSettings,
+	FrameSettings:     parseSettingsFrame,
+	FrameWindowUpdate: parseWindowUpdateFrame,
+	FrameHeaders:      parseHeadersFrame,
 }
 
 func typeFrameParser(t FrameType) frameParser {
@@ -107,7 +144,7 @@ type SettingsFrame struct {
 	Settings map[SettingID]uint32
 }
 
-func parseFrameSettings(fh FrameHeader, r io.Reader) (Frame, error) {
+func parseSettingsFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 	if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 {
 		// When this (ACK 0x1) bit is set, the payload of the
 		// SETTINGS frame MUST be empty.  Receipt of a
@@ -144,6 +181,12 @@ func parseFrameSettings(fh FrameHeader, r io.Reader) (Frame, error) {
 			return nil, err
 		}
 		value := binary.BigEndian.Uint32(buf[:4])
+		if settingID == SettingInitialWindowSize && value > (1<<31)-1 {
+			// Values above the maximum flow control window size of 2^31 - 1 MUST
+			// be treated as a connection error (Section 5.4.1) of type
+			// FLOW_CONTROL_ERROR.
+			return nil, ConnectionError(ErrCodeFlowControl)
+		}
 		if knownSetting(settingID) {
 			s[settingID] = value
 		}
@@ -163,3 +206,114 @@ func parseUnknownFrame(fh FrameHeader, r io.Reader) (Frame, error) {
 	_, err := io.CopyN(ioutil.Discard, r, int64(fh.Length))
 	return UnknownFrame{fh}, err
 }
+
+type WindowUpdateFrame struct {
+	FrameHeader
+	Increment uint32
+}
+
+func parseWindowUpdateFrame(fh FrameHeader, r io.Reader) (Frame, error) {
+	if fh.Length < 4 {
+		// Too short.
+		return nil, ConnectionError(ErrCodeProtocol)
+	}
+	f := WindowUpdateFrame{
+		FrameHeader: fh,
+	}
+	var err error
+	f.Increment, err = readUint32(r)
+	if err != nil {
+		return nil, err
+	}
+	f.Increment &= 0x7fffffff // mask off high reserved bit
+
+	// Future-proof: ignore any extra length in the frame. The spec doesn't
+	// say what to do if Length is too large.
+	if fh.Length > 4 {
+		if _, err := io.CopyN(ioutil.Discard, r, int64(fh.Length-4)); err != nil {
+			return nil, err
+		}
+	}
+	return f, nil
+}
+
+type HeaderFrame struct {
+	FrameHeader
+
+	// If FlagHeadersPriority:
+	ExclusiveDep bool
+	StreamDep    uint32
+
+	// Weight is [0,255]. Only valid if FrameHeader.Flags has the
+	// FlagHeadersPriority bit set, in which case the caller must
+	// also add 1 to get to spec-defined [1,256] range.
+	Weight uint8
+
+	HeaderFragBuf []byte
+}
+
+func parseHeadersFrame(fh FrameHeader, r io.Reader) (_ Frame, err error) {
+	hf := HeaderFrame{
+		FrameHeader: fh,
+	}
+	if fh.StreamID == 0 {
+		// HEADERS frames MUST be associated with a stream.  If a HEADERS frame
+		// is received whose stream identifier field is 0x0, the recipient MUST
+		// respond with a connection error (Section 5.4.1) of type
+		// PROTOCOL_ERROR.
+		return nil, ConnectionError(ErrCodeProtocol)
+	}
+	var padLength uint8
+	var notHeaders int // Header Block Fragment length = fh.Length - notHeaders
+	if fh.Flags.Has(FlagHeadersPadded) {
+		notHeaders += 1
+		if padLength, err = readByte(r); err != nil {
+			return
+		}
+	}
+	if fh.Flags.Has(FlagHeadersPriority) {
+		notHeaders += 4
+		v, err := readUint32(r)
+		if err != nil {
+			return nil, err
+		}
+		hf.StreamDep = v & 0x7fffffff
+		hf.ExclusiveDep = (v != hf.StreamDep) // high bit was set
+	}
+	if fh.Flags.Has(FlagHeadersPriority) {
+		notHeaders += 1
+		hf.Weight, err = readByte(r)
+		if err != nil {
+			return
+		}
+	}
+	headerFragLen := int(fh.Length) - notHeaders
+	if headerFragLen <= 0 {
+		return nil, StreamError(fh.StreamID)
+	}
+	buf := make([]byte, headerFragLen)
+	if _, err := io.ReadFull(r, buf); err != nil {
+		return nil, err
+	}
+	if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil {
+		return nil, err
+	}
+	hf.HeaderFragBuf = buf
+	return hf, nil
+}
+
+func readByte(r io.Reader) (uint8, error) {
+	// TODO: optimize, reuse buffers
+	var buf [1]byte
+	_, err := io.ReadFull(r, buf[:1])
+	return buf[0], err
+}
+
+func readUint32(r io.Reader) (uint32, error) {
+	// TODO: optimize, reuse buffers
+	var buf [4]byte
+	if _, err := io.ReadFull(r, buf[:4]); err != nil {
+		return 0, err
+	}
+	return binary.BigEndian.Uint32(buf[:4]), nil
+}

+ 11 - 2
http2.go

@@ -7,6 +7,7 @@ import (
 	"bytes"
 	"crypto/tls"
 	"io"
+	"io/ioutil"
 	"log"
 	"net/http"
 	"sync"
@@ -66,6 +67,9 @@ func (cc *clientConn) serve() {
 		return
 	}
 	log.Printf("client %v said hello", cc.c.RemoteAddr())
+	var frameReader = io.LimitedReader{
+		R: cc.c,
+	}
 	for {
 		fh, err := ReadFrameHeader(cc.c)
 		if err != nil {
@@ -74,7 +78,8 @@ func (cc *clientConn) serve() {
 			}
 			return
 		}
-		f, err := typeFrameParser(fh.Type)(fh, cc.c)
+		frameReader.N = int64(fh.Length)
+		f, err := typeFrameParser(fh.Type)(fh, &frameReader)
 		if h2e, ok := err.(Error); ok {
 			if h2e.IsConnectionError() {
 				log.Printf("Disconnection; connection error: %v", err)
@@ -86,7 +91,11 @@ func (cc *clientConn) serve() {
 			log.Printf("Disconnection to other error: %v", err)
 			return
 		}
-		log.Printf("read frame: %#v", f)
+		if n, _ := io.Copy(ioutil.Discard, &frameReader); n > 0 {
+			log.Printf("Frame reader for %s failed to read %d bytes", fh.Type, n)
+			return
+		}
+		log.Printf("got frame: %#v", f)
 	}
 }