Browse Source

Merge pull request #1271 from martin-sucha/fix-range-check-gocql

Fix range checks in readCollectionSize.
Alex Lourie 6 years ago
parent
commit
59996b52a6
3 changed files with 109 additions and 14 deletions
  1. 1 0
      AUTHORS
  2. 23 14
      marshal.go
  3. 85 0
      marshal_test.go

+ 1 - 0
AUTHORS

@@ -110,3 +110,4 @@ Alex Lourie <alex@instaclustr.com>; <djay.il@gmail.com>
 Marco Cadetg <cadetg@gmail.com>
 Marco Cadetg <cadetg@gmail.com>
 Karl Matthias <karl@matthias.org>
 Karl Matthias <karl@matthias.org>
 Thomas Meson <zllak@hycik.org>
 Thomas Meson <zllak@hycik.org>
+Martin Sucha <martin.sucha@kiwi.com>; <git@mm.ms47.eu>

+ 23 - 14
marshal.go

@@ -1400,11 +1400,17 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 }
 
 
-func readCollectionSize(info CollectionType, data []byte) (size, read int) {
+func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) {
 	if info.proto > protoVersion2 {
 	if info.proto > protoVersion2 {
+		if len(data) < 4 {
+			return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
+		}
 		size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
 		size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
 		read = 4
 		read = 4
 	} else {
 	} else {
+		if len(data) < 2 {
+			return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
+		}
 		size = int(data[0])<<8 | int(data[1])
 		size = int(data[0])<<8 | int(data[1])
 		read = 2
 		read = 2
 	}
 	}
@@ -1437,10 +1443,10 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
 			rv.Set(reflect.Zero(t))
 			rv.Set(reflect.Zero(t))
 			return nil
 			return nil
 		}
 		}
-		if len(data) < 2 {
-			return unmarshalErrorf("unmarshal list: unexpected eof")
+		n, p, err := readCollectionSize(listInfo, data)
+		if err != nil {
+			return err
 		}
 		}
-		n, p := readCollectionSize(listInfo, data)
 		data = data[p:]
 		data = data[p:]
 		if k == reflect.Array {
 		if k == reflect.Array {
 			if rv.Len() != n {
 			if rv.Len() != n {
@@ -1450,10 +1456,10 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
 			rv.Set(reflect.MakeSlice(t, n, n))
 			rv.Set(reflect.MakeSlice(t, n, n))
 		}
 		}
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
-			if len(data) < 2 {
-				return unmarshalErrorf("unmarshal list: unexpected eof")
+			m, p, err := readCollectionSize(listInfo, data)
+			if err != nil {
+				return err
 			}
 			}
-			m, p := readCollectionSize(listInfo, data)
 			data = data[p:]
 			data = data[p:]
 			if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 			if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 				return err
 				return err
@@ -1538,16 +1544,16 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
 		return nil
 		return nil
 	}
 	}
 	rv.Set(reflect.MakeMap(t))
 	rv.Set(reflect.MakeMap(t))
-	if len(data) < 2 {
-		return unmarshalErrorf("unmarshal map: unexpected eof")
+	n, p, err := readCollectionSize(mapInfo, data)
+	if err != nil {
+		return err
 	}
 	}
-	n, p := readCollectionSize(mapInfo, data)
 	data = data[p:]
 	data = data[p:]
 	for i := 0; i < n; i++ {
 	for i := 0; i < n; i++ {
-		if len(data) < 2 {
-			return unmarshalErrorf("unmarshal list: unexpected eof")
+		m, p, err := readCollectionSize(mapInfo, data)
+		if err != nil {
+			return err
 		}
 		}
-		m, p := readCollectionSize(mapInfo, data)
 		data = data[p:]
 		data = data[p:]
 		key := reflect.New(t.Key())
 		key := reflect.New(t.Key())
 		if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil {
 		if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil {
@@ -1555,7 +1561,10 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
 		}
 		}
 		data = data[m:]
 		data = data[m:]
 
 
-		m, p = readCollectionSize(mapInfo, data)
+		m, p, err = readCollectionSize(mapInfo, data)
+		if err != nil {
+			return err
+		}
 		data = data[p:]
 		data = data[p:]
 		val := reflect.New(t.Elem())
 		val := reflect.New(t.Elem())
 		if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil {
 		if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil {

+ 85 - 0
marshal_test.go

@@ -1480,3 +1480,88 @@ func TestMarshalDuration(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestReadCollectionSize(t *testing.T) {
+	listV2 := CollectionType{
+		NativeType: NativeType{proto: 2, typ: TypeList},
+		Elem:       NativeType{proto: 2, typ: TypeVarchar},
+	}
+	listV3 := CollectionType{
+		NativeType: NativeType{proto: 3, typ: TypeList},
+		Elem:       NativeType{proto: 3, typ: TypeVarchar},
+	}
+
+	tests := []struct {
+		name string
+		info CollectionType
+		data []byte
+		isError bool
+		expectedSize int
+	}{
+		{
+			name: "short read 0 proto 2",
+			info: listV2,
+			data: []byte{},
+			isError: true,
+		},
+		{
+			name: "short read 1 proto 2",
+			info: listV2,
+			data: []byte{0x01},
+			isError: true,
+		},
+		{
+			name: "good read proto 2",
+			info: listV2,
+			data: []byte{0x01, 0x38},
+			expectedSize: 0x0138,
+		},
+		{
+			name: "short read 0 proto 3",
+			info: listV3,
+			data: []byte{},
+			isError: true,
+		},
+		{
+			name: "short read 1 proto 3",
+			info: listV3,
+			data: []byte{0x01},
+			isError: true,
+		},
+		{
+			name: "short read 2 proto 3",
+			info: listV3,
+			data: []byte{0x01, 0x38},
+			isError: true,
+		},
+		{
+			name: "short read 3 proto 3",
+			info: listV3,
+			data: []byte{0x01, 0x38, 0x42},
+			isError: true,
+		},
+		{
+			name: "good read proto 3",
+			info: listV3,
+			data: []byte{0x01, 0x38, 0x42, 0x22},
+			expectedSize: 0x01384222,
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			size, _, err := readCollectionSize(test.info, test.data)
+			if test.isError {
+				if err == nil {
+					t.Fatal("Expected error, but it was nil")
+				}
+			} else {
+				if err != nil {
+					t.Fatalf("Expected no error, got %v", err)
+				}
+				if size != test.expectedSize {
+					t.Fatalf("Expected size of %d, but got %d", test.expectedSize, size)
+				}
+			}
+		})
+	}
+}