소스 검색

Fix range checks in readCollectionSize

When using version greater than protoVersion2 the range check in
readCollectionSize only checked for two bytes instead of four bytes.
This commit moves the range check inside the function so that it is
near the logic that accesses the buffer.

This is the panic that the missing range check caused:

panic: runtime error: index out of range

goroutine 78088437 [running]:
github.com/gocql/gocql.readCollectionSize(...)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/marshal.go:1405
github.com/gocql/gocql.unmarshalList(0xfb2ac0, 0xc003a01900, 0xc008a012d5, 0x2, 0x2f, 0xcf7860, 0xc00632a748, 0x918641, 0x4)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/marshal.go:1443 +0xd39
github.com/gocql/gocql.Unmarshal(0xfb2ac0, 0xc003a01900, 0xc008a012d5, 0x2, 0x2f, 0xcf7860, 0xc00632a748, 0xc0073f7130, 0x2)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/marshal.go:152 +0xa3b
github.com/gocql/gocql.scanColumn(0xc008a012d5, 0x2, 0x2f, 0xc001845617, 0x7, 0xc001845630, 0x7, 0xc001845640, 0xc, 0xfb2ac0, ...)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/session.go:1295 +0x291
github.com/gocql/gocql.(*Iter).Scan(0xc00250a360, 0xc00aa406c0, 0x23, 0x23, 0x23)
    /go/pkg/mod/github.com/scylladb/gocql@v1.0.0/session.go:1395 +0x2e6
github.com/scylladb/gocqlx.(*Iterx).StructScan(0xc00183f8c0, 0xdf7280, 0xc00632a600, 0xe78060)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/iterx.go:242 +0x1bb
github.com/scylladb/gocqlx.(*Iterx).scanAny(0xc00183f8c0, 0xdf7280, 0xc00632a600, 0xc00183f800, 0xc00250a360)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/iterx.go:105 +0x232
github.com/scylladb/gocqlx.(*Iterx).Get(0xc00183f8c0, 0xdf7280, 0xc00632a600, 0xc001714c74, 0xc001714c74)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/iterx.go:66 +0x48
github.com/scylladb/gocqlx.(*Queryx).Get(0xc00081b5c0, 0xdf7280, 0xc00632a600, 0xc00081b5c0, 0x40750b)
    /go/pkg/mod/github.com/scylladb/gocqlx@v1.2.1/queryx.go:212 +0x75
Martin Sucha 6 년 전
부모
커밋
ecab4857ac
1개의 변경된 파일23개의 추가작업 그리고 14개의 파일을 삭제
  1. 23 14
      marshal.go

+ 23 - 14
marshal.go

@@ -1400,11 +1400,17 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 
-func readCollectionSize(info CollectionType, data []byte) (size, read int) {
+func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) {
 	if info.proto > protoVersion2 {
+		if len(data) < 4 {
+			return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
+		}
 		size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
 		read = 4
 	} else {
+		if len(data) < 2 {
+			return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
+		}
 		size = int(data[0])<<8 | int(data[1])
 		read = 2
 	}
@@ -1437,10 +1443,10 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
 			rv.Set(reflect.Zero(t))
 			return nil
 		}
-		if len(data) < 2 {
-			return unmarshalErrorf("unmarshal list: unexpected eof")
+		n, p, err := readCollectionSize(listInfo, data)
+		if err != nil {
+			return err
 		}
-		n, p := readCollectionSize(listInfo, data)
 		data = data[p:]
 		if k == reflect.Array {
 			if rv.Len() != n {
@@ -1450,10 +1456,10 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
 			rv.Set(reflect.MakeSlice(t, n, n))
 		}
 		for i := 0; i < n; i++ {
-			if len(data) < 2 {
-				return unmarshalErrorf("unmarshal list: unexpected eof")
+			m, p, err := readCollectionSize(listInfo, data)
+			if err != nil {
+				return err
 			}
-			m, p := readCollectionSize(listInfo, data)
 			data = data[p:]
 			if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
 				return err
@@ -1538,16 +1544,16 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
 		return nil
 	}
 	rv.Set(reflect.MakeMap(t))
-	if len(data) < 2 {
-		return unmarshalErrorf("unmarshal map: unexpected eof")
+	n, p, err := readCollectionSize(mapInfo, data)
+	if err != nil {
+		return err
 	}
-	n, p := readCollectionSize(mapInfo, data)
 	data = data[p:]
 	for i := 0; i < n; i++ {
-		if len(data) < 2 {
-			return unmarshalErrorf("unmarshal list: unexpected eof")
+		m, p, err := readCollectionSize(mapInfo, data)
+		if err != nil {
+			return err
 		}
-		m, p := readCollectionSize(mapInfo, data)
 		data = data[p:]
 		key := reflect.New(t.Key())
 		if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil {
@@ -1555,7 +1561,10 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
 		}
 		data = data[m:]
 
-		m, p = readCollectionSize(mapInfo, data)
+		m, p, err = readCollectionSize(mapInfo, data)
+		if err != nil {
+			return err
+		}
 		data = data[p:]
 		val := reflect.New(t.Elem())
 		if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil {