فهرست منبع

marshall: support small and tiny ints, fix uint

Handle cql small int and tiny int, these are 16 bit and 8 bit signed
ints.

Fix unmarshalling uints to correctly set the max value of the int when
the sign bit is extended to a larger type. This fixes it so that it
correctly unmarshalls a tiny int of value 0xFF into any uint 0xFF
instead of the max value of the unit type.
Chris Bannister 10 سال پیش
والد
کامیت
e1ed1fb456
3فایلهای تغییر یافته به همراه341 افزوده شده و 26 حذف شده
  1. 4 0
      helpers.go
  2. 250 21
      marshal.go
  3. 87 5
      marshal_test.go

+ 4 - 0
helpers.go

@@ -136,6 +136,10 @@ func getApacheCassandraType(class string) Type {
 		return TypeFloat
 	case "Int32Type":
 		return TypeInt
+	case "ShortType":
+		return TypeSmallInt
+	case "ByteType":
+		return TypeTinyInt
 	case "DateType", "TimestampType":
 		return TypeTimestamp
 	case "UUIDType", "LexicalUUIDType":

+ 250 - 21
marshal.go

@@ -66,6 +66,10 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
 		return marshalVarchar(info, value)
 	case TypeBoolean:
 		return marshalBool(info, value)
+	case TypeTinyInt:
+		return marshalTinyInt(info, value)
+	case TypeSmallInt:
+		return marshalSmallInt(info, value)
 	case TypeInt:
 		return marshalInt(info, value)
 	case TypeBigInt, TypeCounter:
@@ -125,6 +129,10 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
 		return unmarshalBigInt(info, data, value)
 	case TypeVarint:
 		return unmarshalVarint(info, data, value)
+	case TypeSmallInt:
+		return unmarshalSmallInt(info, data, value)
+	case TypeTinyInt:
+		return unmarshalTinyInt(info, data, value)
 	case TypeFloat:
 		return unmarshalFloat(info, data, value)
 	case TypeDouble:
@@ -246,6 +254,164 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
 
+func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case int16:
+		return encShort(v), nil
+	case uint16:
+		return encShort(int16(v)), nil
+	case int8:
+		return encShort(int16(v)), nil
+	case uint8:
+		return encShort(int16(v)), nil
+	case int:
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case int32:
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case int64:
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case uint:
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case uint32:
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case uint64:
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case string:
+		n, err := strconv.ParseInt(v, 10, 16)
+		if err != nil {
+			return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
+		}
+		return encShort(int16(n)), nil
+	}
+
+	if value == nil {
+		return nil, nil
+	}
+
+	switch rv := reflect.ValueOf(value); rv.Type().Kind() {
+	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
+		v := rv.Int()
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
+		v := rv.Uint()
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	default:
+		if rv.IsNil() {
+			return nil, nil
+		}
+	}
+
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case int8:
+		return []byte{byte(v)}, nil
+	case uint8:
+		return []byte{byte(v)}, nil
+	case int16:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint16:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case int:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case int32:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case int64:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint32:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint64:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case string:
+		n, err := strconv.ParseInt(v, 10, 8)
+		if err != nil {
+			return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
+		}
+		return []byte{byte(n)}, nil
+	}
+
+	if value == nil {
+		return nil, nil
+	}
+
+	switch rv := reflect.ValueOf(value); rv.Type().Kind() {
+	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
+		v := rv.Int()
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
+		v := rv.Uint()
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	default:
+		if rv.IsNil() {
+			return nil, nil
+		}
+	}
+
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
 func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	case Marshaler:
@@ -256,7 +422,7 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return encInt(int32(v)), nil
 	case uint:
-		if v > math.MaxInt32 {
+		if v > math.MaxUint32 {
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
 		}
 		return encInt(int32(v)), nil
@@ -266,16 +432,13 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return encInt(int32(v)), nil
 	case uint64:
-		if v > math.MaxInt32 {
+		if v > math.MaxUint32 {
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
 		}
 		return encInt(int32(v)), nil
 	case int32:
 		return encInt(v), nil
 	case uint32:
-		if v > math.MaxInt32 {
-			return nil, marshalErrorf("marshal int: value %d out of range", v)
-		}
 		return encInt(int32(v)), nil
 	case int16:
 		return encInt(int32(v)), nil
@@ -286,7 +449,7 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 	case uint8:
 		return encInt(int32(v)), nil
 	case string:
-		i, err := strconv.ParseInt(value.(string), 10, 32)
+		i, err := strconv.ParseInt(v, 10, 32)
 		if err != nil {
 			return nil, marshalErrorf("can not marshal string to int: %s", err)
 		}
@@ -330,6 +493,27 @@ func decInt(x []byte) int32 {
 	return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3])
 }
 
+func encShort(x int16) []byte {
+	p := make([]byte, 2)
+	p[0] = byte(x >> 8)
+	p[1] = byte(x)
+	return p
+}
+
+func decShort(p []byte) int16 {
+	if len(p) != 2 {
+		return 0
+	}
+	return int16(p[0])<<8 | int16(p[1])
+}
+
+func decTiny(p []byte) int8 {
+	if len(p) != 1 {
+		return 0
+	}
+	return int8(p[0])
+}
+
 func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	case Marshaler:
@@ -344,9 +528,6 @@ func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
 	case int64:
 		return encBigInt(v), nil
 	case uint64:
-		if v > math.MaxInt64 {
-			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
-		}
 		return encBigInt(int64(v)), nil
 	case int32:
 		return encBigInt(int64(v)), nil
@@ -416,6 +597,14 @@ func unmarshalInt(info TypeInfo, data []byte, value interface{}) error {
 	return unmarshalIntlike(info, int64(decInt(data)), data, value)
 }
 
+func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error {
+	return unmarshalIntlike(info, int64(decShort(data)), data, value)
+}
+
+func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error {
+	return unmarshalIntlike(info, int64(decTiny(data)), data, value)
+}
+
 func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
 	switch v := value.(type) {
 	case *big.Int:
@@ -496,19 +685,35 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int(int64Val)
 		return nil
 	case *uint:
-		if int64Val < 0 || (^uint(0) == math.MaxUint32 && int64Val > math.MaxUint32) {
-			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
+		unitVal := uint64(int64Val)
+		if ^uint(0) == math.MaxUint32 && unitVal > math.MaxUint32 {
+			return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v)
+		}
+		switch info.Type() {
+		case TypeInt:
+			*v = uint(unitVal) & 0xFFFFFFFF
+		case TypeSmallInt:
+			*v = uint(unitVal) & 0xFFFF
+		case TypeTinyInt:
+			*v = uint(unitVal) & 0xFF
+		default:
+			*v = uint(unitVal)
 		}
-		*v = uint(int64Val)
 		return nil
 	case *int64:
 		*v = int64Val
 		return nil
 	case *uint64:
-		if int64Val < 0 {
-			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
+		switch info.Type() {
+		case TypeInt:
+			*v = uint64(int64Val) & 0xFFFFFFFF
+		case TypeSmallInt:
+			*v = uint64(int64Val) & 0xFFFF
+		case TypeTinyInt:
+			*v = uint64(int64Val) & 0xFF
+		default:
+			*v = uint64(int64Val)
 		}
-		*v = uint64(int64Val)
 		return nil
 	case *int32:
 		if int64Val < math.MinInt32 || int64Val > math.MaxInt32 {
@@ -517,10 +722,17 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int32(int64Val)
 		return nil
 	case *uint32:
-		if int64Val < 0 || int64Val > math.MaxUint32 {
+		if int64Val > math.MaxUint32 {
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 		}
-		*v = uint32(int64Val)
+		switch info.Type() {
+		case TypeSmallInt:
+			*v = uint32(int64Val) & 0xFFFF
+		case TypeTinyInt:
+			*v = uint32(int64Val) & 0xFF
+		default:
+			*v = uint32(int64Val) & 0xFFFFFFFF
+		}
 		return nil
 	case *int16:
 		if int64Val < math.MinInt16 || int64Val > math.MaxInt16 {
@@ -529,10 +741,15 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int16(int64Val)
 		return nil
 	case *uint16:
-		if int64Val < 0 || int64Val > math.MaxUint16 {
+		if int64Val > math.MaxUint16 {
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 		}
-		*v = uint16(int64Val)
+		switch info.Type() {
+		case TypeTinyInt:
+			*v = uint16(int64Val) & 0xFF
+		default:
+			*v = uint16(int64Val) & 0xFFFF
+		}
 		return nil
 	case *int8:
 		if int64Val < math.MinInt8 || int64Val > math.MaxInt8 {
@@ -541,10 +758,10 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int8(int64Val)
 		return nil
 	case *uint8:
-		if int64Val < 0 || int64Val > math.MaxUint8 {
+		if int64Val > math.MaxUint8 {
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 		}
-		*v = uint8(int64Val)
+		*v = uint8(int64Val) & 0xFF
 		return nil
 	case *big.Int:
 		decBigInt2C(data, v)
@@ -1682,6 +1899,10 @@ const (
 	TypeVarint    Type = 0x000E
 	TypeTimeUUID  Type = 0x000F
 	TypeInet      Type = 0x0010
+	TypeDate      Type = 0x0011
+	TypeTime      Type = 0x0012
+	TypeSmallInt  Type = 0x0013
+	TypeTinyInt   Type = 0x0014
 	TypeList      Type = 0x0020
 	TypeMap       Type = 0x0021
 	TypeSet       Type = 0x0022
@@ -1724,6 +1945,14 @@ func (t Type) String() string {
 		return "timeuuid"
 	case TypeInet:
 		return "inet"
+	case TypeDate:
+		return "date"
+	case TypeTime:
+		return "time"
+	case TypeSmallInt:
+		return "smallint"
+	case TypeTinyInt:
+		return "tinyint"
 	case TypeList:
 		return "list"
 	case TypeMap:

+ 87 - 5
marshal_test.go

@@ -552,6 +552,86 @@ var marshalTests = []struct {
 		[]byte(nil),
 		(*CustomString)(nil),
 	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x7f\xff"),
+		32767, // math.MaxInt16
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x7f\xff"),
+		"32767", // math.MaxInt16
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x00\x01"),
+		int16(1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		int16(-1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		uint16(65535),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\x7f"),
+		127, // math.MaxInt8
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\x7f"),
+		"127", // math.MaxInt8
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\x01"),
+		int16(1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		int16(-1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint8(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint64(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint32(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint16(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		uint64(math.MaxUint64),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\xff\xff\xff\xff"),
+		uint32(math.MaxUint32),
+	},
 }
 
 func decimalize(s string) *inf.Dec {
@@ -564,7 +644,7 @@ func bigintize(s string) *big.Int {
 	return i
 }
 
-func TestMarshal(t *testing.T) {
+func TestMarshal_Encode(t *testing.T) {
 	for i, test := range marshalTests {
 		data, err := Marshal(test.Info, test.Value)
 		if err != nil {
@@ -577,21 +657,21 @@ func TestMarshal(t *testing.T) {
 	}
 }
 
-func TestUnmarshal(t *testing.T) {
+func TestMarshal_Decode(t *testing.T) {
 	for i, test := range marshalTests {
 		if test.Value != nil {
 			v := reflect.New(reflect.TypeOf(test.Value))
 			err := Unmarshal(test.Info, test.Data, v.Interface())
 			if err != nil {
-				t.Errorf("unmarshalTest[%d]: %v", i, err)
+				t.Errorf("unmarshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err)
 				continue
 			}
 			if !reflect.DeepEqual(v.Elem().Interface(), test.Value) {
-				t.Errorf("unmarshalTest[%d]: expected %#v, got %#v.", i, test.Value, v.Elem().Interface())
+				t.Errorf("unmarshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface())
 			}
 		} else {
 			if err := Unmarshal(test.Info, test.Data, test.Value); nil == err {
-				t.Errorf("unmarshalTest[%d]: %#v not return error.", i, test.Value)
+				t.Errorf("unmarshalTest[%d] (%v=>%t): %#v not return error.", i, test.Info, test.Value, test.Value)
 			}
 		}
 	}
@@ -831,6 +911,8 @@ var typeLookupTest = []struct {
 	{"ListType", TypeList},
 	{"SetType", TypeSet},
 	{"unknown", TypeCustom},
+	{"ShortType", TypeSmallInt},
+	{"ByteType", TypeTinyInt},
 }
 
 func testType(t *testing.T, cassType string, expectedType Type) {