Browse Source

Merge pull request #734 from Zariel/tiny-ints

marshall: support small and tiny ints, fix uint
Chris Bannister 9 years ago
parent
commit
405d8024aa
3 changed files with 341 additions and 26 deletions
  1. 4 0
      helpers.go
  2. 250 21
      marshal.go
  3. 87 5
      marshal_test.go

+ 4 - 0
helpers.go

@@ -136,6 +136,10 @@ func getApacheCassandraType(class string) Type {
 		return TypeFloat
 	case "Int32Type":
 		return TypeInt
+	case "ShortType":
+		return TypeSmallInt
+	case "ByteType":
+		return TypeTinyInt
 	case "DateType", "TimestampType":
 		return TypeTimestamp
 	case "UUIDType", "LexicalUUIDType":

+ 250 - 21
marshal.go

@@ -66,6 +66,10 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
 		return marshalVarchar(info, value)
 	case TypeBoolean:
 		return marshalBool(info, value)
+	case TypeTinyInt:
+		return marshalTinyInt(info, value)
+	case TypeSmallInt:
+		return marshalSmallInt(info, value)
 	case TypeInt:
 		return marshalInt(info, value)
 	case TypeBigInt, TypeCounter:
@@ -125,6 +129,10 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
 		return unmarshalBigInt(info, data, value)
 	case TypeVarint:
 		return unmarshalVarint(info, data, value)
+	case TypeSmallInt:
+		return unmarshalSmallInt(info, data, value)
+	case TypeTinyInt:
+		return unmarshalTinyInt(info, data, value)
 	case TypeFloat:
 		return unmarshalFloat(info, data, value)
 	case TypeDouble:
@@ -246,6 +254,164 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
 }
 
+func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case int16:
+		return encShort(v), nil
+	case uint16:
+		return encShort(int16(v)), nil
+	case int8:
+		return encShort(int16(v)), nil
+	case uint8:
+		return encShort(int16(v)), nil
+	case int:
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case int32:
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case int64:
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case uint:
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case uint32:
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case uint64:
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case string:
+		n, err := strconv.ParseInt(v, 10, 16)
+		if err != nil {
+			return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
+		}
+		return encShort(int16(n)), nil
+	}
+
+	if value == nil {
+		return nil, nil
+	}
+
+	switch rv := reflect.ValueOf(value); rv.Type().Kind() {
+	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
+		v := rv.Int()
+		if v > math.MaxInt16 || v < math.MinInt16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
+		v := rv.Uint()
+		if v > math.MaxUint16 {
+			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
+		}
+		return encShort(int16(v)), nil
+	default:
+		if rv.IsNil() {
+			return nil, nil
+		}
+	}
+
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case int8:
+		return []byte{byte(v)}, nil
+	case uint8:
+		return []byte{byte(v)}, nil
+	case int16:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint16:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case int:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case int32:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case int64:
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint32:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case uint64:
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case string:
+		n, err := strconv.ParseInt(v, 10, 8)
+		if err != nil {
+			return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
+		}
+		return []byte{byte(n)}, nil
+	}
+
+	if value == nil {
+		return nil, nil
+	}
+
+	switch rv := reflect.ValueOf(value); rv.Type().Kind() {
+	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
+		v := rv.Int()
+		if v > math.MaxInt8 || v < math.MinInt8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
+		v := rv.Uint()
+		if v > math.MaxUint8 {
+			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
+		}
+		return []byte{byte(v)}, nil
+	default:
+		if rv.IsNil() {
+			return nil, nil
+		}
+	}
+
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
 func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	case Marshaler:
@@ -256,7 +422,7 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return encInt(int32(v)), nil
 	case uint:
-		if v > math.MaxInt32 {
+		if v > math.MaxUint32 {
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
 		}
 		return encInt(int32(v)), nil
@@ -266,16 +432,13 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 		}
 		return encInt(int32(v)), nil
 	case uint64:
-		if v > math.MaxInt32 {
+		if v > math.MaxUint32 {
 			return nil, marshalErrorf("marshal int: value %d out of range", v)
 		}
 		return encInt(int32(v)), nil
 	case int32:
 		return encInt(v), nil
 	case uint32:
-		if v > math.MaxInt32 {
-			return nil, marshalErrorf("marshal int: value %d out of range", v)
-		}
 		return encInt(int32(v)), nil
 	case int16:
 		return encInt(int32(v)), nil
@@ -286,7 +449,7 @@ func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
 	case uint8:
 		return encInt(int32(v)), nil
 	case string:
-		i, err := strconv.ParseInt(value.(string), 10, 32)
+		i, err := strconv.ParseInt(v, 10, 32)
 		if err != nil {
 			return nil, marshalErrorf("can not marshal string to int: %s", err)
 		}
@@ -330,6 +493,27 @@ func decInt(x []byte) int32 {
 	return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3])
 }
 
+func encShort(x int16) []byte {
+	p := make([]byte, 2)
+	p[0] = byte(x >> 8)
+	p[1] = byte(x)
+	return p
+}
+
+func decShort(p []byte) int16 {
+	if len(p) != 2 {
+		return 0
+	}
+	return int16(p[0])<<8 | int16(p[1])
+}
+
+func decTiny(p []byte) int8 {
+	if len(p) != 1 {
+		return 0
+	}
+	return int8(p[0])
+}
+
 func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
 	switch v := value.(type) {
 	case Marshaler:
@@ -344,9 +528,6 @@ func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
 	case int64:
 		return encBigInt(v), nil
 	case uint64:
-		if v > math.MaxInt64 {
-			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
-		}
 		return encBigInt(int64(v)), nil
 	case int32:
 		return encBigInt(int64(v)), nil
@@ -416,6 +597,14 @@ func unmarshalInt(info TypeInfo, data []byte, value interface{}) error {
 	return unmarshalIntlike(info, int64(decInt(data)), data, value)
 }
 
+func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error {
+	return unmarshalIntlike(info, int64(decShort(data)), data, value)
+}
+
+func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error {
+	return unmarshalIntlike(info, int64(decTiny(data)), data, value)
+}
+
 func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
 	switch v := value.(type) {
 	case *big.Int:
@@ -496,19 +685,35 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int(int64Val)
 		return nil
 	case *uint:
-		if int64Val < 0 || (^uint(0) == math.MaxUint32 && int64Val > math.MaxUint32) {
-			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
+		unitVal := uint64(int64Val)
+		if ^uint(0) == math.MaxUint32 && unitVal > math.MaxUint32 {
+			return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v)
+		}
+		switch info.Type() {
+		case TypeInt:
+			*v = uint(unitVal) & 0xFFFFFFFF
+		case TypeSmallInt:
+			*v = uint(unitVal) & 0xFFFF
+		case TypeTinyInt:
+			*v = uint(unitVal) & 0xFF
+		default:
+			*v = uint(unitVal)
 		}
-		*v = uint(int64Val)
 		return nil
 	case *int64:
 		*v = int64Val
 		return nil
 	case *uint64:
-		if int64Val < 0 {
-			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
+		switch info.Type() {
+		case TypeInt:
+			*v = uint64(int64Val) & 0xFFFFFFFF
+		case TypeSmallInt:
+			*v = uint64(int64Val) & 0xFFFF
+		case TypeTinyInt:
+			*v = uint64(int64Val) & 0xFF
+		default:
+			*v = uint64(int64Val)
 		}
-		*v = uint64(int64Val)
 		return nil
 	case *int32:
 		if int64Val < math.MinInt32 || int64Val > math.MaxInt32 {
@@ -517,10 +722,17 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int32(int64Val)
 		return nil
 	case *uint32:
-		if int64Val < 0 || int64Val > math.MaxUint32 {
+		if int64Val > math.MaxUint32 {
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 		}
-		*v = uint32(int64Val)
+		switch info.Type() {
+		case TypeSmallInt:
+			*v = uint32(int64Val) & 0xFFFF
+		case TypeTinyInt:
+			*v = uint32(int64Val) & 0xFF
+		default:
+			*v = uint32(int64Val) & 0xFFFFFFFF
+		}
 		return nil
 	case *int16:
 		if int64Val < math.MinInt16 || int64Val > math.MaxInt16 {
@@ -529,10 +741,15 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int16(int64Val)
 		return nil
 	case *uint16:
-		if int64Val < 0 || int64Val > math.MaxUint16 {
+		if int64Val > math.MaxUint16 {
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 		}
-		*v = uint16(int64Val)
+		switch info.Type() {
+		case TypeTinyInt:
+			*v = uint16(int64Val) & 0xFF
+		default:
+			*v = uint16(int64Val) & 0xFFFF
+		}
 		return nil
 	case *int8:
 		if int64Val < math.MinInt8 || int64Val > math.MaxInt8 {
@@ -541,10 +758,10 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
 		*v = int8(int64Val)
 		return nil
 	case *uint8:
-		if int64Val < 0 || int64Val > math.MaxUint8 {
+		if int64Val > math.MaxUint8 {
 			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
 		}
-		*v = uint8(int64Val)
+		*v = uint8(int64Val) & 0xFF
 		return nil
 	case *big.Int:
 		decBigInt2C(data, v)
@@ -1682,6 +1899,10 @@ const (
 	TypeVarint    Type = 0x000E
 	TypeTimeUUID  Type = 0x000F
 	TypeInet      Type = 0x0010
+	TypeDate      Type = 0x0011
+	TypeTime      Type = 0x0012
+	TypeSmallInt  Type = 0x0013
+	TypeTinyInt   Type = 0x0014
 	TypeList      Type = 0x0020
 	TypeMap       Type = 0x0021
 	TypeSet       Type = 0x0022
@@ -1724,6 +1945,14 @@ func (t Type) String() string {
 		return "timeuuid"
 	case TypeInet:
 		return "inet"
+	case TypeDate:
+		return "date"
+	case TypeTime:
+		return "time"
+	case TypeSmallInt:
+		return "smallint"
+	case TypeTinyInt:
+		return "tinyint"
 	case TypeList:
 		return "list"
 	case TypeMap:

+ 87 - 5
marshal_test.go

@@ -552,6 +552,86 @@ var marshalTests = []struct {
 		[]byte(nil),
 		(*CustomString)(nil),
 	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x7f\xff"),
+		32767, // math.MaxInt16
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x7f\xff"),
+		"32767", // math.MaxInt16
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\x00\x01"),
+		int16(1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		int16(-1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeSmallInt},
+		[]byte("\xff\xff"),
+		uint16(65535),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\x7f"),
+		127, // math.MaxInt8
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\x7f"),
+		"127", // math.MaxInt8
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\x01"),
+		int16(1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		int16(-1),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint8(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint64(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint32(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint16(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeTinyInt},
+		[]byte("\xff"),
+		uint(255),
+	},
+	{
+		NativeType{proto: 2, typ: TypeBigInt},
+		[]byte("\xff\xff\xff\xff\xff\xff\xff\xff"),
+		uint64(math.MaxUint64),
+	},
+	{
+		NativeType{proto: 2, typ: TypeInt},
+		[]byte("\xff\xff\xff\xff"),
+		uint32(math.MaxUint32),
+	},
 }
 
 func decimalize(s string) *inf.Dec {
@@ -564,7 +644,7 @@ func bigintize(s string) *big.Int {
 	return i
 }
 
-func TestMarshal(t *testing.T) {
+func TestMarshal_Encode(t *testing.T) {
 	for i, test := range marshalTests {
 		data, err := Marshal(test.Info, test.Value)
 		if err != nil {
@@ -577,21 +657,21 @@ func TestMarshal(t *testing.T) {
 	}
 }
 
-func TestUnmarshal(t *testing.T) {
+func TestMarshal_Decode(t *testing.T) {
 	for i, test := range marshalTests {
 		if test.Value != nil {
 			v := reflect.New(reflect.TypeOf(test.Value))
 			err := Unmarshal(test.Info, test.Data, v.Interface())
 			if err != nil {
-				t.Errorf("unmarshalTest[%d]: %v", i, err)
+				t.Errorf("unmarshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err)
 				continue
 			}
 			if !reflect.DeepEqual(v.Elem().Interface(), test.Value) {
-				t.Errorf("unmarshalTest[%d]: expected %#v, got %#v.", i, test.Value, v.Elem().Interface())
+				t.Errorf("unmarshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface())
 			}
 		} else {
 			if err := Unmarshal(test.Info, test.Data, test.Value); nil == err {
-				t.Errorf("unmarshalTest[%d]: %#v not return error.", i, test.Value)
+				t.Errorf("unmarshalTest[%d] (%v=>%t): %#v not return error.", i, test.Info, test.Value, test.Value)
 			}
 		}
 	}
@@ -831,6 +911,8 @@ var typeLookupTest = []struct {
 	{"ListType", TypeList},
 	{"SetType", TypeSet},
 	{"unknown", TypeCustom},
+	{"ShortType", TypeSmallInt},
+	{"ByteType", TypeTinyInt},
 }
 
 func testType(t *testing.T, cassType string, expectedType Type) {