Browse Source

Fix embarassing typo bug in ReadFrameHeader.

Thanks to Qi Zhao for the find.

Added new tests as penance.
Brad Fitzpatrick 11 năm trước cách đây
mục cha
commit
e354193b05
2 tập tin đã thay đổi với 55 bổ sung1 xóa
  1. 7 1
      frame.go
  2. 48 0
      frame_test.go

+ 7 - 1
frame.go

@@ -239,7 +239,7 @@ func readFrameHeader(buf []byte, r io.Reader) (FrameHeader, error) {
 		return FrameHeader{}, err
 	}
 	return FrameHeader{
-		Length:   (uint32(buf[0])<<16 | uint32(buf[1])<<7 | uint32(buf[2])),
+		Length:   (uint32(buf[0])<<16 | uint32(buf[1])<<8 | uint32(buf[2])),
 		Type:     FrameType(buf[3]),
 		Flags:    Flags(buf[4]),
 		StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
@@ -303,10 +303,15 @@ func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) {
 		byte(streamID))
 }
 
+var errFrameTooLarge = errors.New("http2: frame too large")
+
 func (f *Framer) endWrite() error {
 	// Now that we know the final size, fill in the FrameHeader in
 	// the space previously reserved for it. Abuse append.
 	length := len(f.wbuf) - frameHeaderLen
+	if length >= (1 << 24) {
+		return errFrameTooLarge
+	}
 	_ = append(f.wbuf[:0],
 		byte(length>>16),
 		byte(length>>8),
@@ -319,6 +324,7 @@ func (f *Framer) endWrite() error {
 }
 
 func (f *Framer) writeByte(v byte)     { f.wbuf = append(f.wbuf, v) }
+func (f *Framer) writeBytes(v []byte)  { f.wbuf = append(f.wbuf, v...) }
 func (f *Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) }
 func (f *Framer) writeUint32(v uint32) {
 	f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))

+ 48 - 0
frame_test.go

@@ -382,3 +382,51 @@ func TestWriteWindowUpdate(t *testing.T) {
 		t.Errorf("parsed back %#v; want %#v", f, want)
 	}
 }
+
+func TestReadFrameHeader(t *testing.T) {
+	tests := []struct {
+		len      uint32
+		typ      FrameType
+		flags    Flags
+		streamID uint32
+	}{
+		{len: 0, typ: 255, flags: 1, streamID: 0},
+		{len: 0, typ: 255, flags: 1, streamID: 1},
+		{len: 0, typ: 255, flags: 1, streamID: 255},
+		{len: 0, typ: 255, flags: 1, streamID: 256},
+		{len: 0, typ: 255, flags: 1, streamID: 65535},
+		{len: 0, typ: 255, flags: 1, streamID: 65536},
+
+		{len: 0, typ: 1, flags: 255, streamID: 1},
+		{len: 255, typ: 1, flags: 255, streamID: 1},
+		{len: 256, typ: 1, flags: 255, streamID: 1},
+		{len: 65535, typ: 1, flags: 255, streamID: 1},
+		{len: 65536, typ: 1, flags: 255, streamID: 1},
+		{len: 16777215, typ: 1, flags: 255, streamID: 1},
+	}
+	for _, tt := range tests {
+		fr, buf := testFramer()
+		fr.startWrite(tt.typ, tt.flags, tt.streamID)
+		fr.writeBytes(make([]byte, tt.len))
+		fr.endWrite()
+		fh, err := ReadFrameHeader(buf)
+		if err != nil {
+			t.Errorf("ReadFrameHeader(%+v) = %v", tt, err)
+			continue
+		}
+		if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID {
+			t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh)
+		}
+	}
+
+}
+
+func TestWriteTooLargeFrame(t *testing.T) {
+	fr, _ := testFramer()
+	fr.startWrite(0, 1, 1)
+	fr.writeBytes(make([]byte, 1<<24))
+	err := fr.endWrite()
+	if err != errFrameTooLarge {
+		t.Errorf("endWrite = %v; want errFrameTooLarge", err)
+	}
+}