Browse Source

Ensure that frames are not bigger than the max

When writing frames ensure that we dont write a frame which is
invalid as it is too big, ensure that we dont read a frame which
is too large.
Chris Bannister 10 years ago
parent
commit
619641c580
2 changed files with 63 additions and 2 deletions
  1. 21 0
      frame.go
  2. 42 2
      frame_test.go

+ 21 - 0
frame.go

@@ -5,8 +5,10 @@
 package gocql
 
 import (
+	"errors"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"net"
 	"runtime"
 	"sync"
@@ -19,6 +21,8 @@ const (
 	protoVersion1      = 0x01
 	protoVersion2      = 0x02
 	protoVersion3      = 0x03
+
+	maxFrameSize = 256 * 1024 * 1024
 )
 
 type protoVersion byte
@@ -193,6 +197,10 @@ const (
 	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 
+var (
+	ErrFrameTooBig = errors.New("frame length is bigger than the maximum alowed")
+)
+
 func writeInt(p []byte, n int32) {
 	p[0] = byte(n >> 24)
 	p[1] = byte(n >> 16)
@@ -345,6 +353,13 @@ func (f *framer) trace() {
 func (f *framer) readFrame(head *frameHeader) error {
 	if head.length < 0 {
 		return fmt.Errorf("frame body length can not be less than 0: %d", head.length)
+	} else if head.length > maxFrameSize {
+		// need to free up the connection to be used again
+		_, err := io.CopyN(ioutil.Discard, f.r, int64(head.length))
+		if err != nil {
+			return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err)
+		}
+		return ErrFrameTooBig
 	}
 
 	if cap(f.readBuffer) >= head.length {
@@ -521,6 +536,12 @@ func (f *framer) setLength(length int) {
 }
 
 func (f *framer) finishWrite() error {
+	if len(f.wbuf) > maxFrameSize {
+		// huge app frame, lets remove it so it doesnt bloat the heap
+		f.wbuf = make([]byte, defaultBufSize)
+		return ErrFrameTooBig
+	}
+
 	if f.wbuf[1]&flagCompress == flagCompress {
 		if f.compres == nil {
 			panic("compress flag set with no compressor")

+ 42 - 2
frame_test.go

@@ -29,12 +29,12 @@ func TestFuzzBugs(t *testing.T) {
 
 		r := bytes.NewReader(test)
 
-		head, err := readHeader(r, make([]byte, 9))
+		head, err := readHeader(r, make([]byte, 8))
 		if err != nil {
 			continue
 		}
 
-		framer := newFramer(r, &bw, nil, 3)
+		framer := newFramer(r, &bw, nil, 2)
 		err = framer.readFrame(&head)
 		if err != nil {
 			continue
@@ -48,3 +48,43 @@ func TestFuzzBugs(t *testing.T) {
 		t.Errorf("(%d) expected to fail for input %q", i, test)
 	}
 }
+
+func TestFrameWriteTooLong(t *testing.T) {
+	w := &bytes.Buffer{}
+	framer := newFramer(nil, w, nil, 2)
+
+	framer.writeHeader(0, opStartup, 1)
+	framer.writeBytes(make([]byte, maxFrameSize+1))
+	err := framer.finishWrite()
+	if err != ErrFrameTooBig {
+		t.Fatalf("expected to get %v got %v", ErrFrameTooBig, err)
+	}
+}
+
+func TestFrameReadTooLong(t *testing.T) {
+	r := &bytes.Buffer{}
+	r.Write(make([]byte, maxFrameSize+1))
+	// write a new header right after this frame to verify that we can read it
+	r.Write([]byte{0x02, 0x00, 0x00, opReady, 0x00, 0x00, 0x00, 0x00})
+
+	framer := newFramer(r, nil, nil, 2)
+
+	head := frameHeader{
+		version: 2,
+		op:      opReady,
+		length:  r.Len() - 8,
+	}
+
+	err := framer.readFrame(&head)
+	if err != ErrFrameTooBig {
+		t.Fatalf("expected to get %v got %v", ErrFrameTooBig, err)
+	}
+
+	head, err = readHeader(r, make([]byte, 8))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if head.op != opReady {
+		t.Fatalf("expected to get header %v got %v", opReady, head.op)
+	}
+}