فهرست منبع

http2/hpack: forbid excess and invalid padding in hpack decoder

This change fixes a few bugs in the HPACK decoder:
 * Excess trailing padding is treated as an error per the HPACK Spec
     section 5.2
 * Non EOS prefix padding is treated as an error
 * Max length is now enforced for all decoded symbols

The idea here is to keep track of the decoded symbol length, rather
than the number of unconsumed bits in cur.  To this end, nbits has
been renamed cbits (cur bits), and sbits (sym bits) has been
introduced.  The main problem with using nbits is that it can easily
be zero, such as when decoding {0xff, 0xff}.  Using a clear moniker
makes it easier to see why checking cbits > 0 at the end of the
function is incorrect.

Fixes golang/go#15614

Change-Id: I1ae868caa9c207fcf9c9dec7f10ee9f400211f99
Reviewed-on: https://go-review.googlesource.com/23067
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Carl Mastrangelo 9 سال پیش
والد
کامیت
6050c11192
2فایلهای تغییر یافته به همراه73 افزوده شده و 10 حذف شده
  1. 41 0
      http2/hpack/hpack_test.go
  2. 32 10
      http2/hpack/huffman.go

+ 41 - 0
http2/hpack/hpack_test.go

@@ -524,6 +524,47 @@ func testDecodeSeries(t *testing.T, size uint32, steps []encAndWant) {
 	}
 	}
 }
 }
 
 
+func TestHuffmanDecodeExcessPadding(t *testing.T) {
+	tests := [][]byte{
+		{0xff},                                   // Padding Exceeds 7 bits
+		{0x1f, 0xff},                             // {"a", 1 byte excess padding}
+		{0x1f, 0xff, 0xff},                       // {"a", 2 byte excess padding}
+		{0x1f, 0xff, 0xff, 0xff},                 // {"a", 3 byte excess padding}
+		{0xff, 0x9f, 0xff, 0xff, 0xff},           // {"a", 29 bit excess padding}
+		{'R', 0xbc, '0', 0xff, 0xff, 0xff, 0xff}, // Padding ends on partial symbol.
+	}
+	for i, in := range tests {
+		var buf bytes.Buffer
+		if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+			t.Errorf("test-%d: decode(%q) = %v; want ErrInvalidHuffman", i, in, err)
+		}
+	}
+}
+
+func TestHuffmanDecodeEOS(t *testing.T) {
+	in := []byte{0xff, 0xff, 0xff, 0xff, 0xfc} // {EOS, "?"}
+	var buf bytes.Buffer
+	if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+		t.Errorf("error = %v; want ErrInvalidHuffman", err)
+	}
+}
+
+func TestHuffmanDecodeMaxLengthOnTrailingByte(t *testing.T) {
+	in := []byte{0x00, 0x01} // {"0", "0", "0"}
+	var buf bytes.Buffer
+	if err := huffmanDecode(&buf, 2, in); err != ErrStringLength {
+		t.Errorf("error = %v; want ErrStringLength", err)
+	}
+}
+
+func TestHuffmanDecodeCorruptPadding(t *testing.T) {
+	in := []byte{0x00}
+	var buf bytes.Buffer
+	if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+		t.Errorf("error = %v; want ErrInvalidHuffman", err)
+	}
+}
+
 func TestHuffmanDecode(t *testing.T) {
 func TestHuffmanDecode(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		inHex, want string
 		inHex, want string

+ 32 - 10
http2/hpack/huffman.go

@@ -48,12 +48,16 @@ var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
 // maxLen bytes will return ErrStringLength.
 // maxLen bytes will return ErrStringLength.
 func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
 func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
 	n := rootHuffmanNode
 	n := rootHuffmanNode
-	cur, nbits := uint(0), uint8(0)
+	// cur is the bit buffer that has not been fed into n.
+	// cbits is the number of low order bits in cur that are valid.
+	// sbits is the number of bits of the symbol prefix being decoded.
+	cur, cbits, sbits := uint(0), uint8(0), uint8(0)
 	for _, b := range v {
 	for _, b := range v {
 		cur = cur<<8 | uint(b)
 		cur = cur<<8 | uint(b)
-		nbits += 8
-		for nbits >= 8 {
-			idx := byte(cur >> (nbits - 8))
+		cbits += 8
+		sbits += 8
+		for cbits >= 8 {
+			idx := byte(cur >> (cbits - 8))
 			n = n.children[idx]
 			n = n.children[idx]
 			if n == nil {
 			if n == nil {
 				return ErrInvalidHuffman
 				return ErrInvalidHuffman
@@ -63,22 +67,40 @@ func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
 					return ErrStringLength
 					return ErrStringLength
 				}
 				}
 				buf.WriteByte(n.sym)
 				buf.WriteByte(n.sym)
-				nbits -= n.codeLen
+				cbits -= n.codeLen
 				n = rootHuffmanNode
 				n = rootHuffmanNode
+				sbits = cbits
 			} else {
 			} else {
-				nbits -= 8
+				cbits -= 8
 			}
 			}
 		}
 		}
 	}
 	}
-	for nbits > 0 {
-		n = n.children[byte(cur<<(8-nbits))]
-		if n.children != nil || n.codeLen > nbits {
+	for cbits > 0 {
+		n = n.children[byte(cur<<(8-cbits))]
+		if n == nil {
+			return ErrInvalidHuffman
+		}
+		if n.children != nil || n.codeLen > cbits {
 			break
 			break
 		}
 		}
+		if maxLen != 0 && buf.Len() == maxLen {
+			return ErrStringLength
+		}
 		buf.WriteByte(n.sym)
 		buf.WriteByte(n.sym)
-		nbits -= n.codeLen
+		cbits -= n.codeLen
 		n = rootHuffmanNode
 		n = rootHuffmanNode
+		sbits = cbits
+	}
+	if sbits > 7 {
+		// Either there was an incomplete symbol, or overlong padding.
+		// Both are decoding errors per RFC 7541 section 5.2.
+		return ErrInvalidHuffman
 	}
 	}
+	if mask := uint(1<<cbits - 1); cur&mask != mask {
+		// Trailing bits must be a prefix of EOS per RFC 7541 section 5.2.
+		return ErrInvalidHuffman
+	}
+
 	return nil
 	return nil
 }
 }