Browse Source

Merge pull request #242 from dankennedy/master

Support for marshalling inet values using strings and net.IP
Ben Hood 11 years ago
parent
commit
00af47871a
5 changed files with 136 additions and 13 deletions
  1. 1 0
      AUTHORS
  2. 67 6
      cassandra_test.go
  3. 1 1
      helpers.go
  4. 26 6
      marshal.go
  5. 41 0
      marshal_test.go

+ 1 - 0
AUTHORS

@@ -30,3 +30,4 @@ Dan Simmons <dan@simmons.io>
 Muir Manders <muir@retailnext.net>
 Sankar P <sankar.curiosity@gmail.com>
 Julien Da Silva <julien.dasilva@gmail.com>
+Dan Kennedy <daniel@firstcs.co.uk>

+ 67 - 6
cassandra_test.go

@@ -11,6 +11,7 @@ import (
 	"log"
 	"math"
 	"math/big"
+	"net"
 	"reflect"
 	"strconv"
 	"strings"
@@ -425,7 +426,8 @@ func TestSliceMap(t *testing.T) {
 			testlist       list<text>,
 			testset        set<int>,
 			testmap        map<varchar, varchar>,
-			testvarint     varint
+			testvarint     varint,
+			testinet			 inet
 		)`); err != nil {
 		t.Fatal("create table:", err)
 	}
@@ -450,9 +452,10 @@ func TestSliceMap(t *testing.T) {
 	m["testset"] = []int{1, 2, 3, 4, 5, 6, 7, 8, 9}
 	m["testmap"] = map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"}
 	m["testvarint"] = bigInt
+	m["testinet"] = "213.212.2.19"
 	sliceMap := []map[string]interface{}{m}
-	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testlist, testset, testmap, testvarint) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
-		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testlist"], m["testset"], m["testmap"], m["testvarint"]).Exec(); err != nil {
+	if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+		m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testlist"], m["testset"], m["testmap"], m["testvarint"], m["testinet"]).Exec(); err != nil {
 		t.Fatal("insert:", err)
 	}
 	if returned, retErr := session.Query(`SELECT * FROM slice_map_table`).Iter().SliceMap(); retErr != nil {
@@ -507,6 +510,10 @@ func TestSliceMap(t *testing.T) {
 		if expectedBigInt.Cmp(returnedBigInt) != 0 {
 			t.Fatal("returned testvarint did not match")
 		}
+
+		if sliceMap[0]["testinet"] != returned[0]["testinet"] {
+			t.Fatal("returned testinet did not match")
+		}
 	}
 
 	// Test for MapScan()
@@ -538,8 +545,8 @@ func TestSliceMap(t *testing.T) {
 	if sliceMap[0]["testdouble"] != testMap["testdouble"] {
 		t.Fatal("returned testdouble did not match")
 	}
-	if sliceMap[0]["testint"] != testMap["testint"] {
-		t.Fatal("returned testint did not match")
+	if sliceMap[0]["testinet"] != testMap["testinet"] {
+		t.Fatal("returned testinet did not match")
 	}
 
 	expectedDecimal := sliceMap[0]["testdecimal"].(*inf.Dec)
@@ -558,7 +565,9 @@ func TestSliceMap(t *testing.T) {
 	if !reflect.DeepEqual(sliceMap[0]["testmap"], testMap["testmap"]) {
 		t.Fatal("returned testmap did not match")
 	}
-
+	if sliceMap[0]["testint"] != testMap["testint"] {
+		t.Fatal("returned testint did not match")
+	}
 }
 
 func TestScanWithNilArguments(t *testing.T) {
@@ -1084,6 +1093,58 @@ func TestMarshalFloat64Ptr(t *testing.T) {
 	}
 }
 
+//TestMarshalInet tests to see that a pointer to a float64 is marshalled correctly.
+func TestMarshalInet(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+
+	if err := createTable(session, "CREATE TABLE inet_test (ip inet, name text, primary key (ip))"); err != nil {
+		t.Fatal("create table:", err)
+	}
+	stringIp := "123.34.45.56"
+	if err := session.Query(`INSERT INTO inet_test (ip,name) VALUES (?,?)`, stringIp, "Test IP 1").Exec(); err != nil {
+		t.Fatal("insert string inet:", err)
+	}
+	var stringResult string
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&stringResult); err != nil {
+		t.Fatalf("select for string from inet_test 1 failed: %v", err)
+	}
+	if stringResult != stringIp {
+		t.Errorf("Expected %s, was %s", stringIp, stringResult)
+	}
+
+	var ipResult net.IP
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&ipResult); err != nil {
+		t.Fatalf("select for net.IP from inet_test 1 failed: %v", err)
+	}
+	if ipResult.String() != stringIp {
+		t.Errorf("Expected %s, was %s", stringIp, ipResult.String())
+	}
+
+	if err := session.Query(`DELETE FROM inet_test WHERE ip = ?`, stringIp).Exec(); err != nil {
+		t.Fatal("delete inet table:", err)
+	}
+
+	netIp := net.ParseIP("222.43.54.65")
+	if err := session.Query(`INSERT INTO inet_test (ip,name) VALUES (?,?)`, netIp, "Test IP 2").Exec(); err != nil {
+		t.Fatal("insert netIp inet:", err)
+	}
+
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&stringResult); err != nil {
+		t.Fatalf("select for string from inet_test 2 failed: %v", err)
+	}
+	if stringResult != netIp.String() {
+		t.Errorf("Expected %s, was %s", netIp.String(), stringResult)
+	}
+	if err := session.Query("SELECT ip FROM inet_test").Scan(&ipResult); err != nil {
+		t.Fatalf("select for net.IP from inet_test 2 failed: %v", err)
+	}
+	if ipResult.String() != netIp.String() {
+		t.Errorf("Expected %s, was %s", netIp.String(), ipResult.String())
+	}
+
+}
+
 func TestVarint(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()

+ 1 - 1
helpers.go

@@ -26,7 +26,7 @@ func (t *TypeInfo) New() interface{} {
 
 func goType(t *TypeInfo) reflect.Type {
 	switch t.Type {
-	case TypeVarchar, TypeAscii:
+	case TypeVarchar, TypeAscii, TypeInet:
 		return reflect.TypeOf(*new(string))
 	case TypeBigInt, TypeCounter:
 		return reflect.TypeOf(*new(int64))

+ 26 - 6
marshal.go

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"math"
 	"math/big"
+	"net"
 	"reflect"
 	"time"
 
@@ -67,6 +68,8 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 		return marshalUUID(info, value)
 	case TypeVarint:
 		return marshalVarint(info, value)
+	case TypeInet:
+		return marshalInet(info, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
@@ -1040,22 +1043,39 @@ func unmarshalTimeUUID(info *TypeInfo, data []byte, value interface{}) error {
 	}
 }
 
+func marshalInet(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch val := value.(type) {
+	case net.IP:
+		return val, nil
+	case []byte:
+		return val, nil
+	case string:
+		b := net.ParseIP(val)
+		if b != nil {
+			return b[:], nil
+		}
+		return nil, marshalErrorf("cannot marshal. invalid ip string %s", val)
+	}
+	return nil, marshalErrorf("cannot marshal %T into %s", value, info)
+}
+
 func unmarshalInet(info *TypeInfo, data []byte, value interface{}) error {
 	switch v := value.(type) {
 	case Unmarshaler:
 		return v.UnmarshalCQL(info, data)
+	case *net.IP:
+		*v = net.IP(data)
+		return nil
 	case *string:
 		if len(data) == 0 {
 			*v = ""
 			return nil
 		}
-		if len(data) == 4 {
-			*v = fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3])
-			return nil
-		}
-		// TODO: support IPv6
+		ip := net.IP(data)
+		*v = ip.String()
+		return nil
 	}
-	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+	return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
 }
 
 // TypeInfo describes a Cassandra specific data type.

+ 41 - 0
marshal_test.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"math"
 	"math/big"
+	"net"
 	"reflect"
 	"strings"
 	"testing"
@@ -272,6 +273,46 @@ var marshalTests = []struct {
 		[]byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"),
 		bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite
 	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\x7F\x00\x00\x01"),
+		net.ParseIP("127.0.0.1"),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
+		net.ParseIP("255.255.255.255"),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\x7F\x00\x00\x01"),
+		"127.0.0.1",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
+		"255.255.255.255",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
+		"21da:d3:0:2f3b:2aa:ff:fe28:9c5a",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
+		"fe80::202:b3ff:fe1e:8329",
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"),
+		net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"),
+	},
+	{
+		&TypeInfo{Type: TypeInet},
+		[]byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"),
+		net.ParseIP("fe80::202:b3ff:fe1e:8329"),
+	},
 }
 
 func decimalize(s string) *inf.Dec {