瀏覽代碼

Merge pull request #154 from ChannelMeter/ApacheCassandraTypes

Support for org.apache.cassandra.db.marshal-namespaced custom types (See #151)
Chris Bannister 11 年之前
父節點
當前提交
4256984d6d
共有 4 個文件被更改,包括 95 次插入1 次删除
  1. 12 0
      frame.go
  2. 45 0
      helpers.go
  3. 1 1
      marshal.go
  4. 37 0
      marshal_test.go

+ 12 - 0
frame.go

@@ -59,6 +59,8 @@ const (
 	errUnprepared    = 0x2500
 
 	headerSize = 8
+
+	apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal."
 )
 
 type frame []byte
@@ -243,6 +245,16 @@ func (f *frame) readTypeInfo() *TypeInfo {
 	switch typ.Type {
 	case TypeCustom:
 		typ.Custom = f.readString()
+		if cassType := getApacheCassandraType(typ.Custom); cassType != TypeCustom {
+			typ = &TypeInfo{Type: cassType}
+			switch typ.Type {
+			case TypeMap:
+				typ.Key = f.readTypeInfo()
+				fallthrough
+			case TypeList, TypeSet:
+				typ.Elem = f.readTypeInfo()
+			}
+		}
 	case TypeMap:
 		typ.Key = f.readTypeInfo()
 		fallthrough

+ 45 - 0
helpers.go

@@ -7,6 +7,7 @@ package gocql
 import (
 	"reflect"
 	"speter.net/go/exp/math/dec/inf"
+	"strings"
 	"time"
 )
 
@@ -56,6 +57,50 @@ func dereference(i interface{}) interface{} {
 	return reflect.Indirect(reflect.ValueOf(i)).Interface()
 }
 
+func getApacheCassandraType(class string) Type {
+	if strings.HasPrefix(class, apacheCassandraTypePrefix) {
+		switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
+		case "AsciiType":
+			return TypeAscii
+		case "LongType":
+			return TypeBigInt
+		case "BytesType":
+			return TypeBlob
+		case "BooleanType":
+			return TypeBoolean
+		case "CounterColumnType":
+			return TypeCounter
+		case "DecimalType":
+			return TypeDecimal
+		case "DoubleType":
+			return TypeDouble
+		case "FloatType":
+			return TypeFloat
+		case "Int32Type":
+			return TypeInt
+		case "DateType":
+			return TypeTimestamp
+		case "UUIDType":
+			return TypeUUID
+		case "UTF8Type":
+			return TypeVarchar
+		case "IntegerType":
+			return TypeVarint
+		case "TimeUUIDType":
+			return TypeTimeUUID
+		case "InetAddressType":
+			return TypeInet
+		case "MapType":
+			return TypeMap
+		case "ListType":
+			return TypeInet
+		case "SetType":
+			return TypeInet
+		}
+	}
+	return TypeCustom
+}
+
 func (r *RowData) rowMap(m map[string]interface{}) {
 	for i, column := range r.Columns {
 		m[column] = dereference(r.Values[i])

+ 1 - 1
marshal.go

@@ -1098,7 +1098,7 @@ func (t TypeInfo) String() string {
 	case TypeList, TypeSet:
 		return fmt.Sprintf("%s(%s)", t.Type, t.Elem)
 	case TypeCustom:
-		return fmt.Sprintf("%s(%s)", t.Type, t.Elem)
+		return fmt.Sprintf("%s(%s)", t.Type, t.Custom)
 	}
 	return t.Type.String()
 }

+ 37 - 0
marshal_test.go

@@ -287,3 +287,40 @@ func (c *CustomString) UnmarshalCQL(info *TypeInfo, data []byte) error {
 type MyString string
 
 type MyInt int
+
+var typeLookupTest = []struct {
+	TypeName     string
+	ExpectedType Type
+}{
+	{"AsciiType", TypeAscii},
+	{"LongType", TypeBigInt},
+	{"BytesType", TypeBlob},
+	{"BooleanType", TypeBoolean},
+	{"CounterColumnType", TypeCounter},
+	{"DecimalType", TypeDecimal},
+	{"DoubleType", TypeDouble},
+	{"FloatType", TypeFloat},
+	{"Int32Type", TypeInt},
+	{"DateType", TypeTimestamp},
+	{"UUIDType", TypeUUID},
+	{"UTF8Type", TypeVarchar},
+	{"IntegerType", TypeVarint},
+	{"TimeUUIDType", TypeTimeUUID},
+	{"InetAddressType", TypeInet},
+	{"MapType", TypeMap},
+	{"ListType", TypeInet},
+	{"SetType", TypeInet},
+	{"unknown", TypeCustom},
+}
+
+func testType(t *testing.T, cassType string, expectedType Type) {
+	if computedType := getApacheCassandraType(apacheCassandraTypePrefix + cassType); computedType != expectedType {
+		t.Errorf("Cassandra custom type lookup for %s failed. Expected %s, got %s.", cassType, expectedType.String(), computedType.String())
+	}
+}
+
+func TestLookupCassType(t *testing.T) {
+	for _, lookupTest := range typeLookupTest {
+		testType(t, lookupTest.TypeName, lookupTest.ExpectedType)
+	}
+}