Browse Source

http2/hpack: push down max string length checking further, improve docs

Change-Id: I875835875f8f97158f2dc88e508a075929af931e
Reviewed-on: https://go-review.googlesource.com/15827
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Brad Fitzpatrick 10 years ago
parent
commit
21c3935a8f
3 changed files with 62 additions and 21 deletions
  1. 28 16
      http2/hpack/hpack.go
  2. 25 2
      http2/hpack/hpack_test.go
  3. 9 3
      http2/hpack/huffman.go

+ 28 - 16
http2/hpack/hpack.go

@@ -95,8 +95,8 @@ func NewDecoder(maxDynamicTableSize uint32, emitFunc func(f HeaderField)) *Decod
 var ErrStringLength = errors.New("hpack: string too long")
 
 // SetMaxStringLength sets the maximum size of a HeaderField name or
-// value string, after compression. If a string exceeds this length,
-// Write will return ErrStringLength.
+// value string. If a string exceeds this length (even after any
+// decompression), Write will return ErrStringLength.
 // A value of 0 means unlimited and is the default from NewDecoder.
 func (d *Decoder) SetMaxStringLength(n int) {
 	d.maxStrLen = n
@@ -281,16 +281,20 @@ func (d *Decoder) Write(p []byte) (n int, err error) {
 
 	for len(d.buf) > 0 {
 		err = d.parseHeaderFieldRepr()
-		if err != nil {
-			if err == errNeedMore {
-				err = nil
-				const varIntOverhead = 8 // conservative
-				if d.maxStrLen != 0 &&
-					int64(len(d.buf))+int64(d.saveBuf.Len()) > 2*(int64(d.maxStrLen)+varIntOverhead) {
-					return 0, ErrStringLength
-				}
-				d.saveBuf.Write(d.buf)
+		if err == errNeedMore {
+			// Extra paranoia, making sure saveBuf won't
+			// get too large.  All the varint and string
+			// reading code earlier should already catch
+			// overlong things and return ErrStringLength,
+			// but keep this as a last resort.
+			const varIntOverhead = 8 // conservative
+			if d.maxStrLen != 0 && int64(len(d.buf)) > 2*(int64(d.maxStrLen)+varIntOverhead) {
+				return 0, ErrStringLength
 			}
+			d.saveBuf.Write(d.buf)
+			return len(p), nil
+		}
+		if err != nil {
 			break
 		}
 	}
@@ -382,12 +386,12 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
 		}
 		hf.Name = ihf.Name
 	} else {
-		hf.Name, buf, err = readString(buf, wantStr)
+		hf.Name, buf, err = d.readString(buf, wantStr)
 		if err != nil {
 			return err
 		}
 	}
-	hf.Value, buf, err = readString(buf, wantStr)
+	hf.Value, buf, err = d.readString(buf, wantStr)
 	if err != nil {
 		return err
 	}
@@ -477,7 +481,7 @@ func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
 // strings past the MAX_HEADER_LIST_SIZE are ignored, but the server
 // is returning an error anyway, and because they're not indexed, the error
 // won't affect the decoding state.
-func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
+func (d *Decoder) readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
 	if len(p) == 0 {
 		return "", p, errNeedMore
 	}
@@ -486,6 +490,9 @@ func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
 	if err != nil {
 		return "", p, err
 	}
+	if d.maxStrLen != 0 && strLen > uint64(d.maxStrLen) {
+		return "", nil, ErrStringLength
+	}
 	if uint64(len(p)) < strLen {
 		return "", p, errNeedMore
 	}
@@ -497,10 +504,15 @@ func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
 	}
 
 	if wantStr {
-		s, err = HuffmanDecodeToString(p[:strLen])
-		if err != nil {
+		buf := bufPool.Get().(*bytes.Buffer)
+		buf.Reset() // don't trust others
+		defer bufPool.Put(buf)
+		if err := huffmanDecode(buf, d.maxStrLen, p[:strLen]); err != nil {
+			buf.Reset()
 			return "", nil, err
 		}
+		s = buf.String()
+		buf.Reset() // be nice to GC
 	}
 	return s, p[strLen:], nil
 }

+ 25 - 2
http2/hpack/hpack_test.go

@@ -583,6 +583,29 @@ func TestAppendHuffmanString(t *testing.T) {
 	}
 }
 
+func TestHuffmanMaxStrLen(t *testing.T) {
+	const msg = "Some string"
+	huff := AppendHuffmanString(nil, msg)
+
+	testGood := func(max int) {
+		var out bytes.Buffer
+		if err := huffmanDecode(&out, max, huff); err != nil {
+			t.Errorf("For maxLen=%d, unexpected error: %v", max, err)
+		}
+		if out.String() != msg {
+			t.Errorf("For maxLen=%d, out = %q; want %q", max, out.String(), msg)
+		}
+	}
+	testGood(0)
+	testGood(len(msg))
+	testGood(len(msg) + 1)
+
+	var out bytes.Buffer
+	if err := huffmanDecode(&out, len(msg)-1, huff); err != ErrStringLength {
+		t.Errorf("err = %v; want ErrStringLength", err)
+	}
+}
+
 func TestHuffmanRoundtripStress(t *testing.T) {
 	const Len = 50 // of uncompressed string
 	input := make([]byte, Len)
@@ -604,7 +627,7 @@ func TestHuffmanRoundtripStress(t *testing.T) {
 		huff = AppendHuffmanString(huff[:0], string(input))
 		encSize += int64(len(huff))
 		output.Reset()
-		if err := huffmanDecode(&output, huff); err != nil {
+		if err := huffmanDecode(&output, 0, huff); err != nil {
 			t.Errorf("Failed to decode %q -> %q -> error %v", input, huff, err)
 			continue
 		}
@@ -639,7 +662,7 @@ func TestHuffmanDecodeFuzz(t *testing.T) {
 		}
 
 		buf.Reset()
-		if err := huffmanDecode(&buf, zbuf.Bytes()); err != nil {
+		if err := huffmanDecode(&buf, 0, zbuf.Bytes()); err != nil {
 			if err == ErrInvalidHuffman {
 				numFail++
 				continue

+ 9 - 3
http2/hpack/huffman.go

@@ -22,7 +22,7 @@ func HuffmanDecode(w io.Writer, v []byte) (int, error) {
 	buf := bufPool.Get().(*bytes.Buffer)
 	buf.Reset()
 	defer bufPool.Put(buf)
-	if err := huffmanDecode(buf, v); err != nil {
+	if err := huffmanDecode(buf, 0, v); err != nil {
 		return 0, err
 	}
 	return w.Write(buf.Bytes())
@@ -33,7 +33,7 @@ func HuffmanDecodeToString(v []byte) (string, error) {
 	buf := bufPool.Get().(*bytes.Buffer)
 	buf.Reset()
 	defer bufPool.Put(buf)
-	if err := huffmanDecode(buf, v); err != nil {
+	if err := huffmanDecode(buf, 0, v); err != nil {
 		return "", err
 	}
 	return buf.String(), nil
@@ -43,7 +43,10 @@ func HuffmanDecodeToString(v []byte) (string, error) {
 // Huffman-encoded strings.
 var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
 
-func huffmanDecode(buf *bytes.Buffer, v []byte) error {
+// huffmanDecode decodes v to buf.
+// If maxLen is greater than 0, attempts to write more to buf than
+// maxLen bytes will return ErrStringLength.
+func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
 	n := rootHuffmanNode
 	cur, nbits := uint(0), uint8(0)
 	for _, b := range v {
@@ -56,6 +59,9 @@ func huffmanDecode(buf *bytes.Buffer, v []byte) error {
 				return ErrInvalidHuffman
 			}
 			if n.children == nil {
+				if maxLen != 0 && buf.Len() == maxLen {
+					return ErrStringLength
+				}
 				buf.WriteByte(n.sym)
 				nbits -= n.codeLen
 				n = rootHuffmanNode