Browse Source

support for marshalling inet values using strings and net.IP with IPv6 tests
added integration test and updated authors
fixed formatting error returned by vet
added SliceMap test as per comments

dankennedy 11 years ago
parent
commit
051c1957ac
5 changed files with 136 additions and 13 deletions
  1. 1 0
      AUTHORS
  2. 67 6
      cassandra_test.go
  3. 1 1
      helpers.go
  4. 26 6
      marshal.go
  5. 41 0
      marshal_test.go

+ 1 - 0
AUTHORS

@@ -30,3 +30,4 @@ Dan Simmons <dan@simmons.io>
 Muir Manders <muir@retailnext.net>
 Muir Manders <muir@retailnext.net>
 Sankar P <sankar.curiosity@gmail.com>
 Sankar P <sankar.curiosity@gmail.com>
 Julien Da Silva <julien.dasilva@gmail.com>
 Julien Da Silva <julien.dasilva@gmail.com>
+Dan Kennedy <daniel@firstcs.co.uk>

+ 67 - 6
cassandra_test.go

@@ -11,6 +11,7 @@ import (
 	"log"
 	"log"
 	"math"
 	"math"
 	"math/big"
 	"math/big"
+	"net"
 	"reflect"
 	"reflect"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -425,7 +426,8 @@ func TestSliceMap(t *testing.T) {
 			testlist       list<text>,
 			testlist       list<text>,
 			testset        set<int>,
 			testset        set<int>,
 			testmap        map<varchar, varchar>,
 			testmap        map<varchar, varchar>,
-			testvarint     varint
+			testvarint     varint,
+			testinet			 inet
 		)`); err != nil {
 		)`); err != nil {
 		t.Fatal("create table:", err)
 		t.Fatal("create table:", err)
 	}
 	}
@@ -450,9 +452,10 @@ func TestSliceMap(t *testing.T) {
 	m["testset"] = []int{1, 2, 3, 4, 5, 6, 7, 8, 9}
 	m["testset"] = []int{1, 2, 3, 4, 5, 6, 7, 8, 9}
 	m["testmap"] = map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"}
 	m["testmap"] = map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"}
 	m["testvarint"] = bigInt
 	m["testvarint"] = bigInt
+	m["testinet"] = "213.212.2.19"
 	sliceMap := []map[string]interface{}{m}
 	sliceMap := []map[string]interface{}{m}
-	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testlist, testset, testmap, testvarint) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
-		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testlist"], m["testset"], m["testmap"], m["testvarint"]).Exec(); err != nil {
+	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testlist"], m["testset"], m["testmap"], m["testvarint"], m["testinet"]).Exec(); err != nil {
 		t.Fatal("insert:", err)
 		t.Fatal("insert:", err)
 	}
 	}
 	if returned, retErr := session.Query(`SELECT * FROM slice_map_table`).Iter().SliceMap(); retErr != nil {
 	if returned, retErr := session.Query(`SELECT * FROM slice_map_table`).Iter().SliceMap(); retErr != nil {
@@ -507,6 +510,10 @@ func TestSliceMap(t *testing.T) {
 		if expectedBigInt.Cmp(returnedBigInt) != 0 {
 		if expectedBigInt.Cmp(returnedBigInt) != 0 {
 			t.Fatal("returned testvarint did not match")
 			t.Fatal("returned testvarint did not match")
 		}
 		}
+
+		if sliceMap[0]["testinet"] != returned[0]["testinet"] {
+			t.Fatal("returned testinet did not match")
+		}
 	}
 	}
 
 
 	// Test for MapScan()
 	// Test for MapScan()
@@ -538,8 +545,8 @@ func TestSliceMap(t *testing.T) {
 	if sliceMap[0]["testdouble"] != testMap["testdouble"] {
 	if sliceMap[0]["testdouble"] != testMap["testdouble"] {
 		t.Fatal("returned testdouble did not match")
 		t.Fatal("returned testdouble did not match")
 	}
 	}
-	if sliceMap[0]["testint"] != testMap["testint"] {
-		t.Fatal("returned testint did not match")
+	if sliceMap[0]["testinet"] != testMap["testinet"] {
+		t.Fatal("returned testinet did not match")
 	}
 	}
 
 
 	expectedDecimal := sliceMap[0]["testdecimal"].(*inf.Dec)
 	expectedDecimal := sliceMap[0]["testdecimal"].(*inf.Dec)
@@ -558,7 +565,9 @@ func TestSliceMap(t *testing.T) {
 	if !reflect.DeepEqual(sliceMap[0]["testmap"], testMap["testmap"]) {
 	if !reflect.DeepEqual(sliceMap[0]["testmap"], testMap["testmap"]) {
 		t.Fatal("returned testmap did not match")
 		t.Fatal("returned testmap did not match")
 	}
 	}
-
+	if sliceMap[0]["testint"] != testMap["testint"] {
+		t.Fatal("returned testint did not match")
+	}
 }
 }
 
 
 func TestScanWithNilArguments(t *testing.T) {
 func TestScanWithNilArguments(t *testing.T) {
@@ -1084,6 +1093,58 @@ func TestMarshalFloat64Ptr(t *testing.T) {
 	}
 	}
 }
 }
 
 
+//TestMarshalInet tests to see that a pointer to a float64 is marshalled correctly.
+func TestMarshalInet(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE inet_test (ip inet, name text, primary key (ip))"); err != nil {
+		t.Fatal("create table:", err)
+	}
+	stringIp := "123.34.45.56"
+	if err := session.Query(`INSERT INTO inet_test (ip,name) VALUES (?,?)`, stringIp, "Test IP 1").Exec(); err != nil {
+		t.Fatal("insert string inet:", err)
+	}
+	var stringResult string
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&stringResult); err != nil {
+		t.Fatalf("select for string from inet_test 1 failed: %v", err)
+	}
+	if stringResult != stringIp {
+		t.Errorf("Expected %s, was %s", stringIp, stringResult)
+	}
+
+	var ipResult net.IP
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&ipResult); err != nil {
+		t.Fatalf("select for net.IP from inet_test 1 failed: %v", err)
+	}
+	if ipResult.String() != stringIp {
+		t.Errorf("Expected %s, was %s", stringIp, ipResult.String())
+	}
+
+	if err := session.Query(`DELETE FROM inet_test WHERE ip = ?`, stringIp).Exec(); err != nil {
+		t.Fatal("delete inet table:", err)
+	}
+
+	netIp := net.ParseIP("222.43.54.65")
+	if err := session.Query(`INSERT INTO inet_test (ip,name) VALUES (?,?)`, netIp, "Test IP 2").Exec(); err != nil {
+		t.Fatal("insert netIp inet:", err)
+	}
+
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&stringResult); err != nil {
+		t.Fatalf("select for string from inet_test 2 failed: %v", err)
+	}
+	if stringResult != netIp.String() {
+		t.Errorf("Expected %s, was %s", netIp.String(), stringResult)
+	}
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&ipResult); err != nil {
+		t.Fatalf("select for net.IP from inet_test 2 failed: %v", err)
+	}
+	if ipResult.String() != netIp.String() {
+		t.Errorf("Expected %s, was %s", netIp.String(), ipResult.String())
+	}
+
+}
+
 func TestVarint(t *testing.T) {
 func TestVarint(t *testing.T) {
 	session := createSession(t)
 	session := createSession(t)
 	defer session.Close()
 	defer session.Close()

+ 1 - 1
helpers.go

@@ -26,7 +26,7 @@ func (t *TypeInfo) New() interface{} {
 
 
 func goType(t *TypeInfo) reflect.Type {
 func goType(t *TypeInfo) reflect.Type {
 	switch t.Type {
 	switch t.Type {
-	case TypeVarchar, TypeAscii:
+	case TypeVarchar, TypeAscii, TypeInet:
 		return reflect.TypeOf(*new(string))
 		return reflect.TypeOf(*new(string))
 	case TypeBigInt, TypeCounter:
 	case TypeBigInt, TypeCounter:
 		return reflect.TypeOf(*new(int64))
 		return reflect.TypeOf(*new(int64))

+ 26 - 6
marshal.go

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"fmt"
 	"math"
 	"math"
 	"math/big"
 	"math/big"
+	"net"
 	"reflect"
 	"reflect"
 	"time"
 	"time"
 
 
@@ -67,6 +68,8 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 		return marshalUUID(info, value)
 		return marshalUUID(info, value)
 	case TypeVarint:
 	case TypeVarint:
 		return marshalVarint(info, value)
 		return marshalVarint(info, value)
+	case TypeInet:
+		return marshalInet(info, value)
 	}
 	}
 	// TODO(tux21b): add the remaining types
 	// TODO(tux21b): add the remaining types
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
@@ -1040,22 +1043,39 @@ func unmarshalTimeUUID(info *TypeInfo, data []byte, value interface{}) error {
 	}
 	}
 }
 }
 
 
+func marshalInet(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch val := value.(type) {
+	case net.IP:
+		return val, nil
+	case []byte:
+		return val, nil
+	case string:
+		b := net.ParseIP(val)
+		if b != nil {
+			return b[:], nil
+		}
+		return nil, marshalErrorf("cannot marshal. invalid ip string %s", val)
+	}
+	return nil, marshalErrorf("cannot marshal %T into %s", value, info)
+}
+
 func unmarshalInet(info *TypeInfo, data []byte, value interface{}) error {
 func unmarshalInet(info *TypeInfo, data []byte, value interface{}) error {
 	switch v := value.(type) {
 	switch v := value.(type) {
 	case Unmarshaler:
 	case Unmarshaler:
 		return v.UnmarshalCQL(info, data)
 		return v.UnmarshalCQL(info, data)
+	case *net.IP:
+		*v = net.IP(data)
+		return nil
 	case *string:
 	case *string:
 		if len(data) == 0 {
 		if len(data) == 0 {
 			*v = ""
 			*v = ""
 			return nil
 			return nil
 		}
 		}
-		if len(data) == 4 {
-			*v = fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3])
-			return nil
-		}
-		// TODO: support IPv6
+		ip := net.IP(data)
+		*v = ip.String()
+		return nil
 	}
 	}
-	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+	return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
 }
 }
 
 
 // TypeInfo describes a Cassandra specific data type.
 // TypeInfo describes a Cassandra specific data type.

+ 41 - 0
marshal_test.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"bytes"
 	"math"
 	"math"
 	"math/big"
 	"math/big"
+	"net"
 	"reflect"
 	"reflect"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
@@ -272,6 +273,46 @@ var marshalTests = []struct {
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 	},
 	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\x7F\x00\x00\x01"),
+		net.ParseIP("127.0.0.1"),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
+		net.ParseIP("255.255.255.255"),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\x7F\x00\x00\x01"),
+		"127.0.0.1",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
+		"255.255.255.255",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
+		"21da:d3:0:2f3b:2aa:ff:fe28:9c5a",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
+		"fe80::202:b3ff:fe1e:8329",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
+		net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
+		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
+	},
 }
 }
 
 
 func decimalize(s string) *inf.Dec {
 func decimalize(s string) *inf.Dec {