瀏覽代碼

Fix unmarshaling to unsigned integers (#1360)

There were a few inconsistencies with unmarshaling to unsigned integers:

- negative values were not range checked for primitive uint types
- negative values were not allowed at all for named uint types

This commit restricts the range of accepted values so that we marshal
the same value as we unmarshal. We accept negative values for a type
only if the width of cql type is smaller or equal than the width of Go
type. This ensures that unmarshaling is bijection from CQL to values
to Go values, without the check the following situation can happen:

- we unmarshal CQL smallint 0xffff (-1) to Go uint8 0xff
- we unmarshal CQL smallint 0x00ff (255) to Go uint8 0xff
- we marshal uint8 0xff to CQL smallint 0x00ff (255)

Therefore smallint value -1 would be turned into 255 during round-trip.

We also need to apply the same logic consistently regardless of whether
we are unmarshaling into a native or named type.
Martin Sucha 6 年之前
父節點
當前提交
9faa4c08d9
共有 2 個文件被更改,包括 382 次插入26 次删除
  1. 64 25
      marshal.go
  2. 318 1
      marshal_test.go

+ 64 - 25
marshal.go

@@ -713,9 +713,6 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		return nil
 		return nil
 	case *uint:
 	case *uint:
 		unitVal := uint64(int64Val)
 		unitVal := uint64(int64Val)
-		if ^uint(0) == math.MaxUint32 && unitVal > math.MaxUint32 {
-			return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v)
-		}
 		switch info.Type() {
 		switch info.Type() {
 		case TypeInt:
 		case TypeInt:
 			*v = uint(unitVal) & 0xFFFFFFFF
 			*v = uint(unitVal) & 0xFFFFFFFF
@@ -724,6 +721,9 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		case TypeTinyInt:
 		case TypeTinyInt:
 			*v = uint(unitVal) & 0xFF
 			*v = uint(unitVal) & 0xFF
 		default:
 		default:
+			if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) {
+				return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v)
+			}
 			*v = uint(unitVal)
 			*v = uint(unitVal)
 		}
 		}
 		return nil
 		return nil
@@ -749,15 +749,17 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int32(int64Val)
 		*v = int32(int64Val)
 		return nil
 		return nil
 	case *uint32:
 	case *uint32:
-		if int64Val > math.MaxUint32 {
-			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
-		}
 		switch info.Type() {
 		switch info.Type() {
+		case TypeInt:
+			*v = uint32(int64Val) & 0xFFFFFFFF
 		case TypeSmallInt:
 		case TypeSmallInt:
 			*v = uint32(int64Val) & 0xFFFF
 			*v = uint32(int64Val) & 0xFFFF
 		case TypeTinyInt:
 		case TypeTinyInt:
 			*v = uint32(int64Val) & 0xFF
 			*v = uint32(int64Val) & 0xFF
 		default:
 		default:
+			if int64Val < 0 || int64Val > math.MaxUint32 {
+				return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
+			}
 			*v = uint32(int64Val) & 0xFFFFFFFF
 			*v = uint32(int64Val) & 0xFFFFFFFF
 		}
 		}
 		return nil
 		return nil
@@ -768,13 +770,15 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int16(int64Val)
 		*v = int16(int64Val)
 		return nil
 		return nil
 	case *uint16:
 	case *uint16:
-		if int64Val > math.MaxUint16 {
-			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
-		}
 		switch info.Type() {
 		switch info.Type() {
+		case TypeSmallInt:
+			*v = uint16(int64Val) & 0xFFFF
 		case TypeTinyInt:
 		case TypeTinyInt:
 			*v = uint16(int64Val) & 0xFF
 			*v = uint16(int64Val) & 0xFF
 		default:
 		default:
+			if int64Val < 0 || int64Val > math.MaxUint16 {
+				return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
+			}
 			*v = uint16(int64Val) & 0xFFFF
 			*v = uint16(int64Val) & 0xFFFF
 		}
 		}
 		return nil
 		return nil
@@ -785,7 +789,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int8(int64Val)
 		*v = int8(int64Val)
 		return nil
 		return nil
 	case *uint8:
 	case *uint8:
-		if int64Val > math.MaxUint8 {
+		if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) {
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 		}
 		}
 		*v = uint8(int64Val) & 0xFF
 		*v = uint8(int64Val) & 0xFF
@@ -833,34 +837,69 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		rv.SetInt(int64Val)
 		rv.SetInt(int64Val)
 		return nil
 		return nil
 	case reflect.Uint:
 	case reflect.Uint:
-		if int64Val < 0 || (^uint(0) == math.MaxUint32 && int64Val > math.MaxUint32) {
-			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
+		unitVal := uint64(int64Val)
+		switch info.Type() {
+		case TypeInt:
+			rv.SetUint(unitVal & 0xFFFFFFFF)
+		case TypeSmallInt:
+			rv.SetUint(unitVal & 0xFFFF)
+		case TypeTinyInt:
+			rv.SetUint(unitVal & 0xFF)
+		default:
+			if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) {
+				return unmarshalErrorf("unmarshal int: value %d out of range for %s", unitVal, rv.Type())
+			}
+			rv.SetUint(unitVal)
 		}
 		}
-		rv.SetUint(uint64(int64Val))
 		return nil
 		return nil
 	case reflect.Uint64:
 	case reflect.Uint64:
-		if int64Val < 0 {
-			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
+		unitVal := uint64(int64Val)
+		switch info.Type() {
+		case TypeInt:
+			rv.SetUint(unitVal & 0xFFFFFFFF)
+		case TypeSmallInt:
+			rv.SetUint(unitVal & 0xFFFF)
+		case TypeTinyInt:
+			rv.SetUint(unitVal & 0xFF)
+		default:
+			rv.SetUint(unitVal)
 		}
 		}
-		rv.SetUint(uint64(int64Val))
 		return nil
 		return nil
 	case reflect.Uint32:
 	case reflect.Uint32:
-		if int64Val < 0 || int64Val > math.MaxUint32 {
-			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
+		unitVal := uint64(int64Val)
+		switch info.Type() {
+		case TypeInt:
+			rv.SetUint(unitVal & 0xFFFFFFFF)
+		case TypeSmallInt:
+			rv.SetUint(unitVal & 0xFFFF)
+		case TypeTinyInt:
+			rv.SetUint(unitVal & 0xFF)
+		default:
+			if int64Val < 0 || int64Val > math.MaxUint32 {
+				return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
+			}
+			rv.SetUint(unitVal & 0xFFFFFFFF)
 		}
 		}
-		rv.SetUint(uint64(int64Val))
 		return nil
 		return nil
 	case reflect.Uint16:
 	case reflect.Uint16:
-		if int64Val < 0 || int64Val > math.MaxUint16 {
-			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
+		unitVal := uint64(int64Val)
+		switch info.Type() {
+		case TypeSmallInt:
+			rv.SetUint(unitVal & 0xFFFF)
+		case TypeTinyInt:
+			rv.SetUint(unitVal & 0xFF)
+		default:
+			if int64Val < 0 || int64Val > math.MaxUint16 {
+				return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
+			}
+			rv.SetUint(unitVal & 0xFFFF)
 		}
 		}
-		rv.SetUint(uint64(int64Val))
 		return nil
 		return nil
 	case reflect.Uint8:
 	case reflect.Uint8:
-		if int64Val < 0 || int64Val > math.MaxUint8 {
-			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
+		if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) {
+			return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
 		}
 		}
-		rv.SetUint(uint64(int64Val))
+		rv.SetUint(uint64(int64Val) & 0xff)
 		return nil
 		return nil
 	}
 	}
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)

+ 318 - 1
marshal_test.go

@@ -16,6 +16,11 @@ import (
 )
 )
 
 
 type AliasInt int
 type AliasInt int
+type AliasUint uint
+type AliasUint8 uint8
+type AliasUint16 uint16
+type AliasUint32 uint32
+type AliasUint64 uint64
 
 
 var marshalTests = []struct {
 var marshalTests = []struct {
 	Info           TypeInfo
 	Info           TypeInfo
@@ -802,6 +807,13 @@ var marshalTests = []struct {
 		nil,
 		nil,
 		nil,
 		nil,
 	},
 	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x00\xff"),
+		uint8(255),
+		nil,
+		nil,
+	},
 	{
 	{
 		NativeType{proto: 2, typ: TypeSmallInt},
 		NativeType{proto: 2, typ: TypeSmallInt},
 		[]byte("\xff\xff"),
 		[]byte("\xff\xff"),
@@ -809,6 +821,55 @@ var marshalTests = []struct {
 		nil,
 		nil,
 		nil,
 		nil,
 	},
 	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		uint32(65535),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		uint64(65535),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x00\xff"),
+		AliasUint8(255),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		AliasUint16(65535),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		AliasUint32(65535),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		AliasUint64(65535),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		AliasUint(65535),
+		nil,
+		nil,
+	},
 	{
 	{
 		NativeType{proto: 2, typ: TypeTinyInt},
 		NativeType{proto: 2, typ: TypeTinyInt},
 		[]byte("\x7f"),
 		[]byte("\x7f"),
@@ -872,6 +933,62 @@ var marshalTests = []struct {
 		nil,
 		nil,
 		nil,
 		nil,
 	},
 	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		AliasUint8(255),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		AliasUint64(255),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		AliasUint32(255),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		AliasUint16(255),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		AliasUint(255),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\xff"),
+		uint8(math.MaxUint8),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x00\xff\xff"),
+		uint64(math.MaxUint16),
+		nil,
+		nil,
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\xff\xff\xff\xff"),
+		uint64(math.MaxUint32),
+		nil,
+		nil,
+	},
 	{
 	{
 		NativeType{proto: 2, typ: TypeBigInt},
 		NativeType{proto: 2, typ: TypeBigInt},
 		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
 		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
@@ -886,6 +1003,13 @@ var marshalTests = []struct {
 		nil,
 		nil,
 		nil,
 		nil,
 	},
 	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\xff\xff\xff\xff"),
+		uint64(math.MaxUint32),
+		nil,
+		nil,
+	},
 	{
 	{
 		NativeType{proto: 2, typ: TypeBlob},
 		NativeType{proto: 2, typ: TypeBlob},
 		[]byte(nil),
 		[]byte(nil),
@@ -912,6 +1036,182 @@ var marshalTests = []struct {
 	},
 	},
 }
 }
 
 
+var unmarshalTests = []struct {
+	Info           TypeInfo
+	Data           []byte
+	Value          interface{}
+	UnmarshalError error
+}{
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x01\x00"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\xff\xff\xff\xff"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\x00\x00\x01\x00"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\xff\xff\xff\xff"),
+		uint16(0),
+		UnmarshalError("unmarshal int: value -1 out of range for uint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\x00\x01\x00\x00"),
+		uint16(0),
+		UnmarshalError("unmarshal int: value 65536 out of range for uint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x00\x01\x00"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x00\x01\x00"),
+		uint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for uint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		uint16(0),
+		UnmarshalError("unmarshal int: value -1 out of range for uint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x01\x00\x00"),
+		uint16(0),
+		UnmarshalError("unmarshal int: value 65536 out of range for uint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		uint32(0),
+		UnmarshalError("unmarshal int: value -1 out of range for uint32"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x01\x00\x00\x00\x00"),
+		uint32(0),
+		UnmarshalError("unmarshal int: value 4294967296 out of range for uint32"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x01\x00"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\xff\xff\xff\xff"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\x00\x00\x01\x00"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\xff\xff\xff\xff"),
+		AliasUint16(0),
+		UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\x00\x01\x00\x00"),
+		AliasUint16(0),
+		UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x00\x01\x00"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x00\x01\x00"),
+		AliasUint8(0),
+		UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		AliasUint16(0),
+		UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x01\x00\x00"),
+		AliasUint16(0),
+		UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		AliasUint32(0),
+		UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint32"),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\x00\x00\x00\x01\x00\x00\x00\x00"),
+		AliasUint32(0),
+		UnmarshalError("unmarshal int: value 4294967296 out of range for gocql.AliasUint32"),
+	},
+}
+
 func decimalize(s string) *inf.Dec {
 func decimalize(s string) *inf.Dec {
 	i, _ := new(inf.Dec).SetString(s)
 	i, _ := new(inf.Dec).SetString(s)
 	return i
 	return i
@@ -955,7 +1255,24 @@ func TestMarshal_Decode(t *testing.T) {
 			}
 			}
 		} else {
 		} else {
 			if err := Unmarshal(test.Info, test.Data, test.Value); err != test.UnmarshalError {
 			if err := Unmarshal(test.Info, test.Data, test.Value); err != test.UnmarshalError {
-				t.Errorf("unmarshalTest[%d] (%v=>%t): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError)
+				t.Errorf("unmarshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError)
+			}
+		}
+	}
+	for i, test := range unmarshalTests {
+		v := reflect.New(reflect.TypeOf(test.Value))
+		if test.UnmarshalError == nil {
+			err := Unmarshal(test.Info, test.Data, v.Interface())
+			if err != nil {
+				t.Errorf("unmarshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err)
+				continue
+			}
+			if !reflect.DeepEqual(v.Elem().Interface(), test.Value) {
+				t.Errorf("unmarshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface())
+			}
+		} else {
+			if err := Unmarshal(test.Info, test.Data, v.Interface()); err != test.UnmarshalError {
+				t.Errorf("unmarshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError)
 			}
 			}
 		}
 		}
 	}
 	}