Browse Source

Fix range checks in readCollectionSize

When using version greater than protoVersion2 the range check in
readCollectionSize only checked for two bytes instead of four bytes.
This commit moves the range check inside the function so that it is
near the logic that accesses the buffer.

This is the panic that the missing range check caused:

panic: runtime error: index out of range

goroutine 78088437 [running]:
github.com/gocql/gocql.readCollectionSize(...)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/marshal.go:1405
github.com/gocql/gocql.unmarshalList(0xfb2ac0, 0xc003a01900, 0xc008a012d5, 0x2, 0x2f, 0xcf7860, 0xc00632a748, 0x918641, 0x4)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/marshal.go:1443 +0xd39
github.com/gocql/gocql.Unmarshal(0xfb2ac0, 0xc003a01900, 0xc008a012d5, 0x2, 0x2f, 0xcf7860, 0xc00632a748, 0xc0073f7130, 0x2)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/marshal.go:152 +0xa3b
github.com/gocql/gocql.scanColumn(0xc008a012d5, 0x2, 0x2f, 0xc001845617, 0x7, 0xc001845630, 0x7, 0xc001845640, 0xc, 0xfb2ac0, ...)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/session.go:1295 +0x291
github.com/gocql/gocql.(*Iter).Scan(0xc00250a360, 0xc00aa406c0, 0x23, 0x23, 0x23)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/session.go:1395 +0x2e6
github.com/scylladb/gocqlx.(*Iterx).StructScan(0xc00183f8c0, 0xdf7280, 0xc00632a600, 0xe78060)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/iterx.go:242 +0x1bb
github.com/scylladb/gocqlx.(*Iterx).scanAny(0xc00183f8c0, 0xdf7280, 0xc00632a600, 0xc00183f800, 0xc00250a360)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/iterx.go:105 +0x232
github.com/scylladb/gocqlx.(*Iterx).Get(0xc00183f8c0, 0xdf7280, 0xc00632a600, 0xc001714c74, 0xc001714c74)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/iterx.go:66 +0x48
github.com/scylladb/gocqlx.(*Queryx).Get(0xc00081b5c0, 0xdf7280, 0xc00632a600, 0xc00081b5c0, 0x40750b)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/queryx.go:212 +0x75
Martin Sucha 6 years ago
parent
commit
ecab4857ac
1 changed files with 23 additions and 14 deletions
  1. 23 14
      marshal.go

+ 23 - 14
marshal.go

@@ -1400,11 +1400,17 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 
-func readCollectionSize(info CollectionType, data []byte) (size, read int) {
+func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) {
 	if info.proto > protoVersion2 {
+		if len(data) < 4 {
+			return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
+		}
 		size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
 		read = 4
 	} else {
+		if len(data) < 2 {
+			return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
+		}
 		size = int(data[0])<<8 | int(data[1])
 		read = 2
 	}
@@ -1437,10 +1443,10 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
 			rv.Set(reflect.Zero(t))
 			return nil
 		}
-		if len(data) < 2 {
-			return unmarshalErrorf("unmarshal list: unexpected eof")
+		n, p, err := readCollectionSize(listInfo, data)
+		if err != nil {
+			return err
 		}
-		n, p := readCollectionSize(listInfo, data)
 		data = data[p:]
 		if k == reflect.Array {
 			if rv.Len() != n {
@@ -1450,10 +1456,10 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
 			rv.Set(reflect.MakeSlice(t, n, n))
 		}
 		for i := 0; i < n; i++ {
-			if len(data) < 2 {
-				return unmarshalErrorf("unmarshal list: unexpected eof")
+			m, p, err := readCollectionSize(listInfo, data)
+			if err != nil {
+				return err
 			}
-			m, p := readCollectionSize(listInfo, data)
 			data = data[p:]
 			if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 				return err
@@ -1538,16 +1544,16 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
 		return nil
 	}
 	rv.Set(reflect.MakeMap(t))
-	if len(data) < 2 {
-		return unmarshalErrorf("unmarshal map: unexpected eof")
+	n, p, err := readCollectionSize(mapInfo, data)
+	if err != nil {
+		return err
 	}
-	n, p := readCollectionSize(mapInfo, data)
 	data = data[p:]
 	for i := 0; i < n; i++ {
-		if len(data) < 2 {
-			return unmarshalErrorf("unmarshal list: unexpected eof")
+		m, p, err := readCollectionSize(mapInfo, data)
+		if err != nil {
+			return err
 		}
-		m, p := readCollectionSize(mapInfo, data)
 		data = data[p:]
 		key := reflect.New(t.Key())
 		if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil {
@@ -1555,7 +1561,10 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
 		}
 		data = data[m:]
 
-		m, p = readCollectionSize(mapInfo, data)
+		m, p, err = readCollectionSize(mapInfo, data)
+		if err != nil {
+			return err
+		}
 		data = data[p:]
 		val := reflect.New(t.Elem())
 		if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil {