Browse Source

decode all the representations.

basic tests from the spec, but not enough yet.
Brad Fitzpatrick 11 years ago
parent
commit
271491bd59
2 changed files with 253 additions and 64 deletions
  1. 218 52
      hpack/hpack.go
  2. 35 12
      hpack/hpack_test.go

+ 218 - 52
hpack/hpack.go

@@ -10,9 +10,9 @@
 package hpack
 package hpack
 
 
 import (
 import (
+	"bytes"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"io"
 )
 )
 
 
 // A DecodingError is something the spec defines as a decoding error.
 // A DecodingError is something the spec defines as a decoding error.
@@ -59,12 +59,20 @@ func (hf *HeaderField) size() uint32 {
 type Decoder struct {
 type Decoder struct {
 	dynTab dynamicTable
 	dynTab dynamicTable
 	emit   func(f HeaderField, sensitive bool)
 	emit   func(f HeaderField, sensitive bool)
+
+	// buf is the unparsed buffer. It's only written to
+	// saveBuf if it was truncated in the middle of a header
+	// block. Because it's usually not owned, we can only
+	// process it under Write.
+	buf     []byte // usually not owned
+	saveBuf bytes.Buffer
 }
 }
 
 
 func NewDecoder(maxSize uint32, emitFunc func(f HeaderField, sensitive bool)) *Decoder {
 func NewDecoder(maxSize uint32, emitFunc func(f HeaderField, sensitive bool)) *Decoder {
 	d := &Decoder{
 	d := &Decoder{
 		emit: emitFunc,
 		emit: emitFunc,
 	}
 	}
+	d.dynTab.allowedMaxSize = maxSize
 	d.dynTab.setMaxSize(maxSize)
 	d.dynTab.setMaxSize(maxSize)
 	return d
 	return d
 }
 }
@@ -76,14 +84,22 @@ func (d *Decoder) SetMaxDynamicTableSize(v uint32) {
 	d.dynTab.setMaxSize(v)
 	d.dynTab.setMaxSize(v)
 }
 }
 
 
+// SetAllowedMaxDynamicTableSize sets the upper bound that the encoded
+// stream (via dynamic table size updates) may set the maximum size
+// to.
+func (d *Decoder) SetAllowedMaxDynamicTableSize(v uint32) {
+	d.dynTab.allowedMaxSize = v
+}
+
 type dynamicTable struct {
 type dynamicTable struct {
 	// s is the FIFO described at
 	// s is the FIFO described at
 	// http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2
 	// http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2
 	// The newest (low index) is appended at the end, and items are
 	// The newest (low index) is appended at the end, and items are
 	// evicted from the front.
 	// evicted from the front.
-	ents    []HeaderField
-	size    uint32
-	maxSize uint32
+	ents           []HeaderField
+	size           uint32
+	maxSize        uint32 // current maxSize
+	allowedMaxSize uint32 // maxSize may go up to this, inclusive
 }
 }
 
 
 func (dt *dynamicTable) setMaxSize(v uint32) {
 func (dt *dynamicTable) setMaxSize(v uint32) {
@@ -125,64 +141,194 @@ func (d *Decoder) maxTableIndex() int {
 	return len(d.dynTab.ents) + len(staticTable)
 	return len(d.dynTab.ents) + len(staticTable)
 }
 }
 
 
-func (d *Decoder) at(i int) (hf HeaderField, ok bool) {
+func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
 	if i < 1 {
 	if i < 1 {
 		return
 		return
 	}
 	}
-	if i > d.maxTableIndex() {
+	if i > uint64(d.maxTableIndex()) {
 		return
 		return
 	}
 	}
-	if i <= len(staticTable) {
+	if i <= uint64(len(staticTable)) {
 		return staticTable[i-1], true
 		return staticTable[i-1], true
 	}
 	}
 	dents := d.dynTab.ents
 	dents := d.dynTab.ents
-	return dents[len(dents)-(i-len(staticTable))], true
+	return dents[len(dents)-(int(i)-len(staticTable))], true
 }
 }
 
 
 // Decode decodes an entire block.
 // Decode decodes an entire block.
 //
 //
 // TODO: remove this method and make it incremental later? This is
 // TODO: remove this method and make it incremental later? This is
 // easier for debugging now.
 // easier for debugging now.
-func (d *Decoder) Decode(p []byte) ([]HeaderField, error) {
+func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
 	var hf []HeaderField
 	var hf []HeaderField
-	// TODO: This is trashy. temporary development aid.
 	saveFunc := d.emit
 	saveFunc := d.emit
 	defer func() { d.emit = saveFunc }()
 	defer func() { d.emit = saveFunc }()
 	d.emit = func(f HeaderField, sensitive bool) {
 	d.emit = func(f HeaderField, sensitive bool) {
 		hf = append(hf, f)
 		hf = append(hf, f)
 	}
 	}
+	if _, err := d.Write(p); err != nil {
+		return nil, err
+	}
+	if err := d.Close(); err != nil {
+		return nil, err
+	}
+	return hf, nil
+}
 
 
-	for len(p) > 0 {
-		// Look at first byte to see what we're dealing with.
-		switch {
-		case p[0]&(1<<7) != 0:
-			// Indexed representation.
-			// http://http2.github.io/http2-spec/compression.html#rfc.section.6.1
-			idx, size, err := readVarInt(7, p)
-			if err != nil {
-				return nil, err
-			}
-			if size == 0 {
-				// TODO: will later stop processing
-				// here and wait for more (buffering
-				// what we've got), but this is the
-				// all-at-once Decode debug version.
-				return nil, io.ErrUnexpectedEOF
-			}
-			if idx > uint64(d.maxTableIndex()) {
-				return nil, DecodingError{InvalidIndexError(idx)}
-			}
-			hf, ok := d.at(int(idx))
-			if !ok {
-				return nil, DecodingError{InvalidIndexError(idx)}
+func (d *Decoder) Close() error {
+	if d.saveBuf.Len() > 0 {
+		d.saveBuf.Reset()
+		return DecodingError{errors.New("truncated headers")}
+	}
+	return nil
+}
+
+func (d *Decoder) Write(p []byte) (n int, err error) {
+	if len(p) == 0 {
+		// Prevent state machine CPU attacks (making us redo
+		// work up to the point of finding out we don't have
+		// enough data)
+		return
+	}
+	// Only copy the data if we have to. Optimistically assume
+	// that p will contain a complete header block.
+	if d.saveBuf.Len() == 0 {
+		d.buf = p
+	} else {
+		d.saveBuf.Write(p)
+		d.buf = d.saveBuf.Bytes()
+		d.saveBuf.Reset()
+	}
+
+	for len(d.buf) > 0 {
+		err = d.parseHeaderFieldRepr()
+		if err != nil {
+			if err == errNeedMore {
+				err = nil
+				d.saveBuf.Write(d.buf)
 			}
 			}
-			d.emit(hf, false /* TODO: sensitive ? */)
-			p = p[size:]
-		default:
-			panic("TODO")
+			break
 		}
 		}
 	}
 	}
-	return hf, nil
+
+	return len(p), err
+}
+
+// errNeedMore is an internal sentinel error value that means the
+// buffer is truncated and we need to read more data before we can
+// continue parsing.
+var errNeedMore = errors.New("need more data")
+
+type indexType int
+
+const (
+	indexedTrue indexType = iota
+	indexedFalse
+	indexedNever
+)
+
+func (v indexType) indexed() bool   { return v == indexedTrue }
+func (v indexType) sensitive() bool { return v == indexedNever }
+
+// returns errNeedMore if there isn't enough data available.
+// any other error is fatal.
+// consumes d.buf iff it returns nil.
+// precondition: must be called with len(d.buf) > 0
+func (d *Decoder) parseHeaderFieldRepr() error {
+	b := d.buf[0]
+	switch {
+	case b&128 != 0:
+		// Indexed representation.
+		// High bit set?
+		// http://http2.github.io/http2-spec/compression.html#rfc.section.6.1
+		return d.parseFieldIndexed()
+	case b&192 == 64:
+		// 6.2.1 Literal Header Field with Incremental Indexing
+		// 0b01xxxxxx: top two bits are 01
+		// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.1
+		return d.parseFieldLiteral(6, indexedTrue)
+	case b&240 == 0:
+		// 6.2.2 Literal Header Field without Indexing
+		// 0b0000xxxx: top four bits are 0000
+		// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.2
+		return d.parseFieldLiteral(4, indexedFalse)
+	case b&240 == 16:
+		// 6.2.3 Literal Header Field never Indexed
+		// 0b0001xxxx: top four bits are 0001
+		// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.3
+		return d.parseFieldLiteral(4, indexedNever)
+	case b&224 == 32:
+		// 6.3 Dynamic Table Size Update
+		// Top three bits are '001'.
+		// http://http2.github.io/http2-spec/compression.html#rfc.section.6.3
+		return d.parseDynamicTableSizeUpdate()
+	}
+
+	return DecodingError{errors.New("invalid encoding")}
+}
+
+// (same invariants and behavior as parseHeaderFieldRepr)
+func (d *Decoder) parseFieldIndexed() error {
+	buf := d.buf
+	idx, buf, err := readVarInt(7, buf)
+	if err != nil {
+		return err
+	}
+	hf, ok := d.at(idx)
+	if !ok {
+		return DecodingError{InvalidIndexError(idx)}
+	}
+	d.emit(hf, false)
+	d.buf = buf
+	return nil
+}
+
+// (same invariants and behavior as parseHeaderFieldRepr)
+func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
+	buf := d.buf
+	nameIdx, buf, err := readVarInt(n, buf)
+	if err != nil {
+		return err
+	}
+
+	var hf HeaderField
+	if nameIdx > 0 {
+		ihf, ok := d.at(nameIdx)
+		if !ok {
+			return DecodingError{InvalidIndexError(nameIdx)}
+		}
+		hf.Name = ihf.Name
+	} else {
+		hf.Name, buf, err = readString(buf)
+		if err != nil {
+			return err
+		}
+	}
+	hf.Value, buf, err = readString(buf)
+	if err != nil {
+		return err
+	}
+	d.buf = buf
+	if it.indexed() {
+		d.dynTab.add(hf)
+	}
+	d.emit(hf, it.sensitive())
+	return nil
+}
+
+// (same invariants and behavior as parseHeaderFieldRepr)
+func (d *Decoder) parseDynamicTableSizeUpdate() error {
+	buf := d.buf
+	size, buf, err := readVarInt(5, buf)
+	if err != nil {
+		return err
+	}
+	if size > uint64(d.dynTab.allowedMaxSize) {
+		return DecodingError{errors.New("dynamic table size update too large")}
+	}
+	d.dynTab.setMaxSize(uint32(size))
+	d.buf = buf
+	return nil
 }
 }
 
 
 var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
 var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
@@ -193,41 +339,61 @@ var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
 //
 //
 // n must always be between 1 and 8.
 // n must always be between 1 and 8.
 //
 //
-// The returned consumed parameter is the number of bytes that were
-// consumed from the beginning of p. It is zero if the end of the
-// integer's representation wasn't included in p. (In this case,
-// callers should wait for more data to arrive and try again with a
-// larger p buffer).
-func readVarInt(n byte, p []byte) (i uint64, consumed int, err error) {
+// The returned remain buffer is either a smaller suffix of p, or err != nil.
+// The error is errNeedMore if p doesn't contain a complete integer.
+func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
 	if n < 1 || n > 8 {
 	if n < 1 || n > 8 {
 		panic("bad n")
 		panic("bad n")
 	}
 	}
 	if len(p) == 0 {
 	if len(p) == 0 {
-		return
+		return 0, p, errNeedMore
 	}
 	}
 	i = uint64(p[0])
 	i = uint64(p[0])
 	if n < 8 {
 	if n < 8 {
 		i &= (1 << uint64(n)) - 1
 		i &= (1 << uint64(n)) - 1
 	}
 	}
 	if i < (1<<uint64(n))-1 {
 	if i < (1<<uint64(n))-1 {
-		return i, 1, nil
+		return i, p[1:], nil
 	}
 	}
 
 
+	origP := p
 	p = p[1:]
 	p = p[1:]
-	consumed++
 	var m uint64
 	var m uint64
 	for len(p) > 0 {
 	for len(p) > 0 {
 		b := p[0]
 		b := p[0]
-		consumed++
+		p = p[1:]
 		i += uint64(b&127) << m
 		i += uint64(b&127) << m
 		if b&128 == 0 {
 		if b&128 == 0 {
-			return
+			return i, p, nil
 		}
 		}
-		p = p[1:]
 		m += 7
 		m += 7
 		if m >= 63 { // TODO: proper overflow check. making this up.
 		if m >= 63 { // TODO: proper overflow check. making this up.
-			return 0, 0, errVarintOverflow
+			return 0, origP, errVarintOverflow
 		}
 		}
 	}
 	}
-	return 0, 0, nil
+	return 0, origP, errNeedMore
+}
+
+func readString(p []byte) (s string, remain []byte, err error) {
+	if len(p) == 0 {
+		return "", p, errNeedMore
+	}
+	isHuff := p[0]&128 != 0
+	strLen, p, err := readVarInt(7, p)
+	if err != nil {
+		return "", p, err
+	}
+	if uint64(len(p)) < strLen {
+		return "", p, errNeedMore
+	}
+	if !isHuff {
+		return string(p[:strLen]), p[strLen:], nil
+	}
+
+	// TODO: optimize this garbage:
+	var buf bytes.Buffer
+	if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
+		return "", nil, err
+	}
+	return buf.String(), p[strLen:], nil
 }
 }

+ 35 - 12
hpack/hpack_test.go

@@ -116,7 +116,7 @@ func TestStaticTable(t *testing.T) {
 }
 }
 
 
 func (d *Decoder) mustAt(idx int) HeaderField {
 func (d *Decoder) mustAt(idx int) HeaderField {
-	if hf, ok := d.at(idx); !ok {
+	if hf, ok := d.at(uint64(idx)); !ok {
 		panic(fmt.Sprintf("bogus index %d", idx))
 		panic(fmt.Sprintf("bogus index %d", idx))
 	} else {
 	} else {
 		return hf
 		return hf
@@ -170,18 +170,31 @@ func TestDynamicTableSizeEvict(t *testing.T) {
 }
 }
 
 
 func TestDecoderDecode(t *testing.T) {
 func TestDecoderDecode(t *testing.T) {
+	// TODO: also test state of dynamic table after all these.
 	tests := []struct {
 	tests := []struct {
 		name string
 		name string
 		in   []byte
 		in   []byte
 		want []HeaderField
 		want []HeaderField
 	}{
 	}{
+		// C.2.1 Literal Header Field with Indexing
+		// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.1
+		{"C.2.1", dehex("400a 6375 7374 6f6d 2d6b 6579 0d63 7573 746f 6d2d 6865 6164 6572"),
+			[]HeaderField{{"custom-key", "custom-header"}}},
+
+		{"C.2.2", dehex("040c 2f73 616d 706c 652f 7061 7468"),
+			[]HeaderField{{":path", "/sample/path"}}},
+
+		// TODO: test callback happens with sensitive
+		{"C.2.3", dehex("1008 7061 7373 776f 7264 0673 6563 7265 74"),
+			[]HeaderField{{"password", "secret"}}},
+
 		// Indexed Header Field
 		// Indexed Header Field
 		// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.4
 		// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.4
 		{"C.2.4", []byte("\x82"), []HeaderField{{":method", "GET"}}},
 		{"C.2.4", []byte("\x82"), []HeaderField{{":method", "GET"}}},
 	}
 	}
 	for _, tt := range tests {
 	for _, tt := range tests {
 		d := NewDecoder(4096, nil)
 		d := NewDecoder(4096, nil)
-		hf, err := d.Decode(tt.in)
+		hf, err := d.DecodeFull(tt.in)
 		if err != nil {
 		if err != nil {
 			t.Errorf("%s: %v", tt.name, err)
 			t.Errorf("%s: %v", tt.name, err)
 			continue
 			continue
@@ -247,14 +260,14 @@ func TestReadVarInt(t *testing.T) {
 		{8, []byte{254}, res{254, 1, nil}},
 		{8, []byte{254}, res{254, 1, nil}},
 
 
 		// Doesn't fit in a byte:
 		// Doesn't fit in a byte:
-		{1, []byte{1}, res{0, 0, nil}},
-		{2, []byte{3}, res{0, 0, nil}},
-		{3, []byte{7}, res{0, 0, nil}},
-		{4, []byte{15}, res{0, 0, nil}},
-		{5, []byte{31}, res{0, 0, nil}},
-		{6, []byte{63}, res{0, 0, nil}},
-		{7, []byte{127}, res{0, 0, nil}},
-		{8, []byte{255}, res{0, 0, nil}},
+		{1, []byte{1}, res{0, 0, errNeedMore}},
+		{2, []byte{3}, res{0, 0, errNeedMore}},
+		{3, []byte{7}, res{0, 0, errNeedMore}},
+		{4, []byte{15}, res{0, 0, errNeedMore}},
+		{5, []byte{31}, res{0, 0, errNeedMore}},
+		{6, []byte{63}, res{0, 0, errNeedMore}},
+		{7, []byte{127}, res{0, 0, errNeedMore}},
+		{8, []byte{255}, res{0, 0, errNeedMore}},
 
 
 		// Ignoring top bits:
 		// Ignoring top bits:
 		{5, []byte{255, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 111
 		{5, []byte{255, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 111
@@ -265,16 +278,26 @@ func TestReadVarInt(t *testing.T) {
 		{5, []byte{191, 154, 10, 2}, res{1337, 3, nil}}, // extra byte
 		{5, []byte{191, 154, 10, 2}, res{1337, 3, nil}}, // extra byte
 
 
 		// Short a byte:
 		// Short a byte:
-		{5, []byte{191, 154}, res{0, 0, nil}},
+		{5, []byte{191, 154}, res{0, 0, errNeedMore}},
 
 
 		// integer overflow:
 		// integer overflow:
 		{1, []byte{255, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}, res{0, 0, errVarintOverflow}},
 		{1, []byte{255, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}, res{0, 0, errVarintOverflow}},
 	}
 	}
 	for _, tt := range tests {
 	for _, tt := range tests {
-		i, consumed, err := readVarInt(tt.n, tt.p)
+		i, remain, err := readVarInt(tt.n, tt.p)
+		consumed := len(tt.p) - len(remain)
 		got := res{i, consumed, err}
 		got := res{i, consumed, err}
 		if got != tt.want {
 		if got != tt.want {
 			t.Errorf("readVarInt(%d, %v ~ %x) = %+v; want %+v", tt.n, tt.p, tt.p, got, tt.want)
 			t.Errorf("readVarInt(%d, %v ~ %x) = %+v; want %+v", tt.n, tt.p, tt.p, got, tt.want)
 		}
 		}
 	}
 	}
 }
 }
+
+func dehex(s string) []byte {
+	s = strings.Replace(s, " ", "", -1)
+	b, err := hex.DecodeString(s)
+	if err != nil {
+		panic(err)
+	}
+	return b
+}