Browse Source

Add support for varints.

Similar to other integer types except you must demarshal varints
greater than 8 bytes using a *big.Int. It is possible to detect
unset varint fields when fetching using a **big.Int double pointer.
Maybe all types should support this as a way to detect unset cells as
opposed to cells with a value of zero?

Varints are signed integers encoded using two's complement. When marshaling,
gocql trims the encoded bytes down to the minimum required to encode the
integer value (e.g. so if you encode a varint value that can fit in one byte,
gocql will actually only send 1 byte).
Muir Manders 11 years ago
parent
commit
7eda240e1d
4 changed files with 340 additions and 13 deletions
  1. 109 6
      cassandra_test.go
  2. 5 1
      helpers.go
  3. 102 4
      marshal.go
  4. 124 2
      marshal_test.go

+ 109 - 6
cassandra_test.go

@@ -7,14 +7,17 @@ package gocql
 import (
 	"bytes"
 	"flag"
+	"math"
+	"math/big"
 	"reflect"
-	"speter.net/go/exp/math/dec/inf"
 	"strconv"
 	"strings"
 	"sync"
 	"testing"
 	"time"
 	"unicode"
+
+	"speter.net/go/exp/math/dec/inf"
 )
 
 var (
@@ -351,16 +354,23 @@ func TestSliceMap(t *testing.T) {
 			testbigint     bigint,
 			testblob       blob,
 			testbool       boolean,
-			testfloat	   float,
-			testdouble	   double,
+			testfloat      float,
+			testdouble     double,
 			testint        int,
 			testdecimal    decimal,
 			testset        set<int>,
-			testmap        map<varchar, varchar>
+			testmap        map<varchar, varchar>,
+			testvarint     varint
 		)`).Exec(); err != nil {
 		t.Fatal("create table:", err)
 	}
 	m := make(map[string]interface{})
+
+	bigInt := new(big.Int)
+	if _, ok := bigInt.SetString("830169365738487321165427203929228", 10); !ok {
+		t.Fatal("Failed setting bigint by string")
+	}
+
 	m["testuuid"] = TimeUUID()
 	m["testvarchar"] = "Test VarChar"
 	m["testbigint"] = time.Now().Unix()
@@ -373,9 +383,10 @@ func TestSliceMap(t *testing.T) {
 	m["testdecimal"] = inf.NewDec(100, 0)
 	m["testset"] = []int{1, 2, 3, 4, 5, 6, 7, 8, 9}
 	m["testmap"] = map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"}
+	m["testvarint"] = bigInt
 	sliceMap := []map[string]interface{}{m}
-	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testset, testmap) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
-		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testset"], m["testmap"]).Exec(); err != nil {
+	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testset, testmap, testvarint) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testset"], m["testmap"], m["testvarint"]).Exec(); err != nil {
 		t.Fatal("insert:", err)
 	}
 	if returned, retErr := session.Query(`SELECT * FROM slice_map_table`).Iter().SliceMap(); retErr != nil {
@@ -421,6 +432,12 @@ func TestSliceMap(t *testing.T) {
 		if !reflect.DeepEqual(sliceMap[0]["testmap"], returned[0]["testmap"]) {
 			t.Fatal("returned testmap did not match")
 		}
+
+		expectedBigInt := sliceMap[0]["testvarint"].(*big.Int)
+		returnedBigInt := returned[0]["testvarint"].(*big.Int)
+		if expectedBigInt.Cmp(returnedBigInt) != 0 {
+			t.Fatal("returned testvarint did not match")
+		}
 	}
 
 	// Test for MapScan()
@@ -925,3 +942,89 @@ func TestMarshalFloat64Ptr(t *testing.T) {
 		t.Fatal("insert float64:", err)
 	}
 }
+
+func TestVarint(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := session.Query("CREATE TABLE varint_test (id varchar, test varint, test2 varint, primary key (id))").Exec(); err != nil {
+		t.Fatal("create table:", err)
+	}
+
+	if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", 0).Exec(); err != nil {
+		t.Fatalf("insert varint: %v", err)
+	}
+
+	var result int
+	if err := session.Query("SELECT test FROM varint_test").Scan(&result); err != nil {
+		t.Fatalf("select from varint_test failed: %v", err)
+	}
+
+	if result != 0 {
+		t.Errorf("Expected 0, was %d", result)
+	}
+
+	if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", -1).Exec(); err != nil {
+		t.Fatalf("insert varint: %v", err)
+	}
+
+	if err := session.Query("SELECT test FROM varint_test").Scan(&result); err != nil {
+		t.Fatalf("select from varint_test failed: %v", err)
+	}
+
+	if result != -1 {
+		t.Errorf("Expected -1, was %d", result)
+	}
+
+	if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", int64(math.MaxInt32)+1).Exec(); err != nil {
+		t.Fatalf("insert varint: %v", err)
+	}
+
+	var result64 int64
+	if err := session.Query("SELECT test FROM varint_test").Scan(&result64); err != nil {
+		t.Fatalf("select from varint_test failed: %v", err)
+	}
+
+	if result64 != int64(math.MaxInt32)+1 {
+		t.Errorf("Expected %d, was %d", int64(math.MaxInt32)+1, result64)
+	}
+
+	biggie := new(big.Int)
+	biggie.SetString("36893488147419103232", 10) // > 2**64
+	if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", biggie).Exec(); err != nil {
+		t.Fatalf("insert varint: %v", err)
+	}
+
+	resultBig := new(big.Int)
+	if err := session.Query("SELECT test FROM varint_test").Scan(resultBig); err != nil {
+		t.Fatalf("select from varint_test failed: %v", err)
+	}
+
+	if resultBig.String() != biggie.String() {
+		t.Errorf("Expected %s, was %s", biggie.String(), resultBig.String())
+	}
+
+	err := session.Query("SELECT test FROM varint_test").Scan(&result64)
+	if err == nil || strings.Index(err.Error(), "out of range") == -1 {
+		t.Errorf("expected our of range error since value is too big for int64")
+	}
+
+	// value not set in cassandra, leave bind variable empty
+	resultBig = new(big.Int)
+	if err := session.Query("SELECT test2 FROM varint_test").Scan(resultBig); err != nil {
+		t.Fatalf("select from varint_test failed: %v", err)
+	}
+
+	if resultBig.Int64() != 0 {
+		t.Errorf("Expected %s, was %s", biggie.String(), resultBig.String())
+	}
+
+	// can use double pointer to explicitly detect value is not set in cassandra
+	if err := session.Query("SELECT test2 FROM varint_test").Scan(&resultBig); err != nil {
+		t.Fatalf("select from varint_test failed: %v", err)
+	}
+
+	if resultBig != nil {
+		t.Errorf("Expected %v, was %v", nil, *resultBig)
+	}
+}

+ 5 - 1
helpers.go

@@ -5,10 +5,12 @@
 package gocql
 
 import (
+	"math/big"
 	"reflect"
-	"speter.net/go/exp/math/dec/inf"
 	"strings"
 	"time"
+
+	"speter.net/go/exp/math/dec/inf"
 )
 
 type RowData struct {
@@ -48,6 +50,8 @@ func goType(t *TypeInfo) reflect.Type {
 		return reflect.SliceOf(goType(t.Elem))
 	case TypeMap:
 		return reflect.MapOf(goType(t.Key), goType(t.Elem))
+	case TypeVarint:
+		return reflect.TypeOf(*new(*big.Int))
 	default:
 		return nil
 	}

+ 102 - 4
marshal.go

@@ -6,12 +6,14 @@ package gocql
 
 import (
 	"bytes"
+	"encoding/binary"
 	"fmt"
 	"math"
 	"math/big"
 	"reflect"
-	"speter.net/go/exp/math/dec/inf"
 	"time"
+
+	"speter.net/go/exp/math/dec/inf"
 )
 
 var (
@@ -59,6 +61,8 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 		return marshalMap(info, value)
 	case TypeUUID, TypeTimeUUID:
 		return marshalUUID(info, value)
+	case TypeVarint:
+		return marshalVarint(info, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
@@ -80,6 +84,8 @@ func Unmarshal(info *TypeInfo, data []byte, value interface{}) error {
 		return unmarshalInt(info, data, value)
 	case TypeBigInt, TypeCounter:
 		return unmarshalBigInt(info, data, value)
+	case TypeVarint:
+		return unmarshalVarint(info, data, value)
 	case TypeFloat:
 		return unmarshalFloat(info, data, value)
 	case TypeDouble:
@@ -266,6 +272,8 @@ func marshalBigInt(info *TypeInfo, value interface{}) ([]byte, error) {
 		return encBigInt(int64(v)), nil
 	case uint8:
 		return encBigInt(int64(v)), nil
+	case *big.Int:
+		return encBigInt2C(v), nil
 	}
 	rv := reflect.ValueOf(value)
 	switch rv.Type().Kind() {
@@ -289,6 +297,13 @@ func encBigInt(x int64) []byte {
 		byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}
 }
 
+func bytesToInt64(data []byte) (ret int64) {
+	for i := range data {
+		ret |= int64(data[i]) << (8 * uint(len(data)-i-1))
+	}
+	return ret
+}
+
 func unmarshalBigInt(info *TypeInfo, data []byte, value interface{}) error {
 	return unmarshalIntlike(info, decBigInt(data), data, value)
 }
@@ -297,6 +312,72 @@ func unmarshalInt(info *TypeInfo, data []byte, value interface{}) error {
 	return unmarshalIntlike(info, int64(decInt(data)), data, value)
 }
 
+func unmarshalVarint(info *TypeInfo, data []byte, value interface{}) error {
+	switch value.(type) {
+	case *big.Int, **big.Int:
+		return unmarshalIntlike(info, 0, data, value)
+	}
+
+	if len(data) > 8 {
+		return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value)
+	}
+
+	int64Val := bytesToInt64(data)
+	if len(data) < 8 && data[0]&0x80 > 0 {
+		int64Val -= (1 << uint(len(data)*8))
+	}
+	return unmarshalIntlike(info, int64Val, data, value)
+}
+
+func marshalVarint(info *TypeInfo, value interface{}) ([]byte, error) {
+	var (
+		retBytes []byte
+		err      error
+	)
+
+	switch v := value.(type) {
+	case uint64:
+		if v > uint64(math.MaxInt64) {
+			retBytes = make([]byte, 9)
+			binary.BigEndian.PutUint64(retBytes[1:], v)
+		} else {
+			retBytes = make([]byte, 8)
+			binary.BigEndian.PutUint64(retBytes, v)
+		}
+	default:
+		retBytes, err = marshalBigInt(info, value)
+	}
+
+	if err == nil {
+		// trim down to most significant byte
+		i := 0
+		for ; i < len(retBytes)-1; i++ {
+			b0 := retBytes[i]
+			if b0 != 0 && b0 != 0xFF {
+				break
+			}
+
+			b1 := retBytes[i+1]
+			if b0 == 0 && b1 != 0 {
+				if b1&0x80 == 0 {
+					i++
+				}
+				break
+			}
+
+			if b0 == 0xFF && b1 != 0xFF {
+				if b1&0x80 > 0 {
+					i++
+				}
+				break
+			}
+		}
+		retBytes = retBytes[i:]
+	}
+
+	return retBytes, err
+}
+
 func unmarshalIntlike(info *TypeInfo, int64Val int64, data []byte, value interface{}) error {
 	switch v := value.(type) {
 	case *int:
@@ -356,12 +437,24 @@ func unmarshalIntlike(info *TypeInfo, int64Val int64, data []byte, value interfa
 		}
 		*v = uint8(int64Val)
 		return nil
+	case *big.Int:
+		decBigInt2C(data, v)
+		return nil
+	case **big.Int:
+		if len(data) == 0 {
+			*v = nil
+		} else {
+			*v = decBigInt2C(data, nil)
+		}
+		return nil
 	}
+
 	rv := reflect.ValueOf(value)
 	if rv.Kind() != reflect.Ptr {
 		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
 	}
 	rv = rv.Elem()
+
 	switch rv.Type().Kind() {
 	case reflect.Int:
 		if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) {
@@ -592,7 +685,7 @@ func unmarshalDecimal(info *TypeInfo, data []byte, value interface{}) error {
 	case **inf.Dec:
 		if len(data) > 4 {
 			scale := decInt(data[0:4])
-			unscaled := decBigInt2C(data[4:])
+			unscaled := decBigInt2C(data[4:], nil)
 			*v = inf.NewDecBig(unscaled, inf.Scale(scale))
 			return nil
 		} else if len(data) == 0 {
@@ -608,8 +701,11 @@ func unmarshalDecimal(info *TypeInfo, data []byte, value interface{}) error {
 // decBigInt2C sets the value of n to the big-endian two's complement
 // value stored in the given data. If data[0]&80 != 0, the number
 // is negative. If data is empty, the result will be 0.
-func decBigInt2C(data []byte) *big.Int {
-	n := new(big.Int).SetBytes(data)
+func decBigInt2C(data []byte, n *big.Int) *big.Int {
+	if n == nil {
+		n = new(big.Int)
+	}
+	n.SetBytes(data)
 	if len(data) > 0 && data[0]&0x80 > 0 {
 		n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8))
 	}
@@ -1044,6 +1140,8 @@ func (t Type) String() string {
 		return "map"
 	case TypeSet:
 		return "set"
+	case TypeVarint:
+		return "varint"
 	default:
 		return "unknown"
 	}

+ 124 - 2
marshal_test.go

@@ -3,11 +3,13 @@ package gocql
 import (
 	"bytes"
 	"math"
+	"math/big"
 	"reflect"
-	"speter.net/go/exp/math/dec/inf"
 	"strings"
 	"testing"
 	"time"
+
+	"speter.net/go/exp/math/dec/inf"
 )
 
 var marshalTests = []struct {
@@ -66,7 +68,7 @@ var marshalTests = []struct {
 	{
 		&TypeInfo{Type: TypeInt},
 		[]byte("\x01\x02\x03\x04"),
-		16909060,
+		int(16909060),
 	},
 	{
 		&TypeInfo{Type: TypeInt},
@@ -240,6 +242,36 @@ var marshalTests = []struct {
 			strings.Repeat("X", 65535): strings.Repeat("Y", 65535),
 		},
 	},
+	{
+		&TypeInfo{Type: TypeVarint},
+		[]byte("\x00"),
+		0,
+	},
+	{
+		&TypeInfo{Type: TypeVarint},
+		[]byte("\x37\xE2\x3C\xEC"),
+		int32(937573612),
+	},
+	{
+		&TypeInfo{Type: TypeVarint},
+		[]byte("\x37\xE2\x3C\xEC"),
+		big.NewInt(937573612),
+	},
+	{
+		&TypeInfo{Type: TypeVarint},
+		[]byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["),
+		bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite
+	},
+	{
+		&TypeInfo{Type: TypeVarint},
+		[]byte("\xC9v\x8D:\x86"),
+		big.NewInt(-234234234234), // From the iconara/cql-rb test suite
+	},
+	{
+		&TypeInfo{Type: TypeVarint},
+		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
+		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
+	},
 }
 
 func decimalize(s string) *inf.Dec {
@@ -247,6 +279,11 @@ func decimalize(s string) *inf.Dec {
 	return i
 }
 
+func bigintize(s string) *big.Int {
+	i, _ := new(big.Int).SetString(s, 10)
+	return i
+}
+
 func TestMarshal(t *testing.T) {
 	for i, test := range marshalTests {
 		data, err := Marshal(test.Info, test.Value)
@@ -274,6 +311,91 @@ func TestUnmarshal(t *testing.T) {
 	}
 }
 
+func TestMarshalVarint(t *testing.T) {
+	varintTests := []struct {
+		Value       interface{}
+		Marshaled   []byte
+		Unmarshaled *big.Int
+	}{
+		{
+			Value:       int8(0),
+			Marshaled:   []byte("\x00"),
+			Unmarshaled: big.NewInt(0),
+		},
+		{
+			Value:       uint8(255),
+			Marshaled:   []byte("\x00\xFF"),
+			Unmarshaled: big.NewInt(255),
+		},
+		{
+			Value:       int8(-1),
+			Marshaled:   []byte("\xFF"),
+			Unmarshaled: big.NewInt(-1),
+		},
+		{
+			Value:       big.NewInt(math.MaxInt32),
+			Marshaled:   []byte("\x7F\xFF\xFF\xFF"),
+			Unmarshaled: big.NewInt(math.MaxInt32),
+		},
+		{
+			Value:       big.NewInt(int64(math.MaxInt32) + 1),
+			Marshaled:   []byte("\x00\x80\x00\x00\x00"),
+			Unmarshaled: big.NewInt(int64(math.MaxInt32) + 1),
+		},
+		{
+			Value:       big.NewInt(math.MinInt32),
+			Marshaled:   []byte("\x80\x00\x00\x00"),
+			Unmarshaled: big.NewInt(math.MinInt32),
+		},
+		{
+			Value:       big.NewInt(int64(math.MinInt32) - 1),
+			Marshaled:   []byte("\xFF\x7F\xFF\xFF\xFF"),
+			Unmarshaled: big.NewInt(int64(math.MinInt32) - 1),
+		},
+		{
+			Value:       math.MinInt64,
+			Marshaled:   []byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
+			Unmarshaled: big.NewInt(math.MinInt64),
+		},
+		{
+			Value:       uint64(math.MaxInt64) + 1,
+			Marshaled:   []byte("\x00\x80\x00\x00\x00\x00\x00\x00\x00"),
+			Unmarshaled: bigintize("9223372036854775808"),
+		},
+		{
+			Value:       bigintize("2361183241434822606848"), // 2**71
+			Marshaled:   []byte("\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00"),
+			Unmarshaled: bigintize("2361183241434822606848"),
+		},
+		{
+			Value:       bigintize("-9223372036854775809"), // -2**63 - 1
+			Marshaled:   []byte("\xFF\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF"),
+			Unmarshaled: bigintize("-9223372036854775809"),
+		},
+	}
+
+	for i, test := range varintTests {
+		data, err := Marshal(&TypeInfo{Type: TypeVarint}, test.Value)
+		if err != nil {
+			t.Errorf("error marshaling varint: %v (test #%d)", err, i)
+		}
+
+		if !bytes.Equal(test.Marshaled, data) {
+			t.Errorf("marshaled varint mismatch: expected %v, got %v (test #%d)", test.Marshaled, data, i)
+		}
+
+		binder := new(big.Int)
+		err = Unmarshal(&TypeInfo{Type: TypeVarint}, test.Marshaled, binder)
+		if err != nil {
+			t.Errorf("error unmarshaling varint: %v (test #%d)", err, i)
+		}
+
+		if test.Unmarshaled.Cmp(binder) != 0 {
+			t.Errorf("unmarshaled varint mismatch: expected %v, got %v (test #%d)", test.Unmarshaled, binder, i)
+		}
+	}
+}
+
 type CustomString string
 
 func (c CustomString) MarshalCQL(info *TypeInfo) ([]byte, error) {