瀏覽代碼

metadata: fix all types being NativeType (#1052)

Parse the correct TypeInfo from the cassandra string in the db. Fixing
representation to recursivly parse the nested types.
Chris Bannister 7 年之前
父節點
當前提交
dd47639f78
共有 4 個文件被更改,包括 145 次插入23 次删除
  1. 47 11
      helpers.go
  2. 74 0
      helpers_test.go
  3. 11 0
      marshal.go
  4. 13 12
      metadata.go

+ 47 - 11
helpers.go

@@ -68,7 +68,7 @@ func dereference(i interface{}) interface{} {
 	return reflect.Indirect(reflect.ValueOf(i)).Interface()
 }
 
-func getCassandraType(name string) Type {
+func getCassandraBaseType(name string) Type {
 	switch name {
 	case "ascii":
 		return TypeAscii
@@ -92,8 +92,10 @@ func getCassandraType(name string) Type {
 		return TypeTimestamp
 	case "uuid":
 		return TypeUUID
-	case "varchar", "text":
+	case "varchar":
 		return TypeVarchar
+	case "text":
+		return TypeText
 	case "varint":
 		return TypeVarint
 	case "timeuuid":
@@ -109,19 +111,53 @@ func getCassandraType(name string) Type {
 	case "TupleType":
 		return TypeTuple
 	default:
-		if strings.HasPrefix(name, "set") {
-			return TypeSet
-		} else if strings.HasPrefix(name, "list") {
-			return TypeList
-		} else if strings.HasPrefix(name, "map") {
-			return TypeMap
-		} else if strings.HasPrefix(name, "tuple") {
-			return TypeTuple
-		}
 		return TypeCustom
 	}
 }
 
+func getCassandraType(name string) TypeInfo {
+	if strings.HasPrefix(name, "frozen<") {
+		return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"))
+	} else if strings.HasPrefix(name, "set<") {
+		return CollectionType{
+			NativeType: NativeType{typ: TypeSet},
+			Elem:       getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<")),
+		}
+	} else if strings.HasPrefix(name, "list<") {
+		return CollectionType{
+			NativeType: NativeType{typ: TypeList},
+			Elem:       getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<")),
+		}
+	} else if strings.HasPrefix(name, "map<") {
+		names := strings.Split(strings.TrimPrefix(name[:len(name)-1], "map<"), ", ")
+		if len(names) != 2 {
+			panic(fmt.Sprintf("invalid map type: %v", name))
+		}
+
+		return CollectionType{
+			NativeType: NativeType{typ: TypeMap},
+			Key:        getCassandraType(names[0]),
+			Elem:       getCassandraType(names[1]),
+		}
+	} else if strings.HasPrefix(name, "tuple<") {
+		names := strings.Split(strings.TrimPrefix(name[:len(name)-1], "tuple<"), ", ")
+		types := make([]TypeInfo, len(names))
+
+		for i, name := range names {
+			types[i] = getCassandraType(name)
+		}
+
+		return TupleTypeInfo{
+			NativeType: NativeType{typ: TypeTuple},
+			Elems:      types,
+		}
+	} else {
+		return NativeType{
+			typ: getCassandraBaseType(name),
+		}
+	}
+}
+
 func getApacheCassandraType(class string) Type {
 	switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
 	case "AsciiType":

+ 74 - 0
helpers_test.go

@@ -0,0 +1,74 @@
+package gocql
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestGetCassandraType_Set(t *testing.T) {
+	typ := getCassandraType("set<text>")
+	set, ok := typ.(CollectionType)
+	if !ok {
+		t.Fatalf("expected CollectionType got %T", typ)
+	} else if set.typ != TypeSet {
+		t.Fatalf("expected type %v got %v", TypeSet, set.typ)
+	}
+
+	inner, ok := set.Elem.(NativeType)
+	if !ok {
+		t.Fatalf("expected to get NativeType got %T", set.Elem)
+	} else if inner.typ != TypeText {
+		t.Fatalf("expected to get %v got %v for set value", TypeText, set.typ)
+	}
+}
+
+func TestGetCassandraType(t *testing.T) {
+	tests := []struct {
+		input string
+		exp   TypeInfo
+	}{
+		{
+			"set<text>", CollectionType{
+				NativeType: NativeType{typ: TypeSet},
+
+				Elem: NativeType{typ: TypeText},
+			},
+		},
+		{
+			"map<text, varchar>", CollectionType{
+				NativeType: NativeType{typ: TypeMap},
+
+				Key:  NativeType{typ: TypeText},
+				Elem: NativeType{typ: TypeVarchar},
+			},
+		},
+		{
+			"list<int>", CollectionType{
+				NativeType: NativeType{typ: TypeList},
+				Elem:       NativeType{typ: TypeInt},
+			},
+		},
+		{
+			"tuple<int, int, text>", TupleTypeInfo{
+				NativeType: NativeType{typ: TypeTuple},
+
+				Elems: []TypeInfo{
+					NativeType{typ: TypeInt},
+					NativeType{typ: TypeInt},
+					NativeType{typ: TypeText},
+				},
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.input, func(t *testing.T) {
+			got := getCassandraType(test.input)
+
+			// TODO(zariel): define an equal method on the types?
+			if !reflect.DeepEqual(got, test.exp) {
+				t.Fatalf("expected %v got %v", test.exp, got)
+			}
+		})
+	}
+}

+ 11 - 0
marshal.go

@@ -2053,6 +2053,17 @@ type TupleTypeInfo struct {
 	Elems []TypeInfo
 }
 
+func (t TupleTypeInfo) String() string {
+	var buf bytes.Buffer
+	buf.WriteString(fmt.Sprintf("%s(", t.typ))
+	for _, elem := range t.Elems {
+		buf.WriteString(fmt.Sprintf("%s, ", elem))
+	}
+	buf.Truncate(buf.Len() - 2)
+	buf.WriteByte(')')
+	return buf.String()
+}
+
 func (t TupleTypeInfo) New() interface{} {
 	return reflect.New(goType(t)).Interface()
 }

+ 13 - 12
metadata.go

@@ -226,25 +226,26 @@ func compileMetadata(
 
 	// add columns from the schema data
 	for i := range columns {
+		col := &columns[i]
 		// decode the validator for TypeInfo and order
-		if columns[i].ClusteringOrder != "" { // Cassandra 3.x+
-			columns[i].Type = NativeType{typ: getCassandraType(columns[i].Validator)}
-			columns[i].Order = ASC
-			if columns[i].ClusteringOrder == "desc" {
-				columns[i].Order = DESC
+		if col.ClusteringOrder != "" { // Cassandra 3.x+
+			col.Type = getCassandraType(col.Validator)
+			col.Order = ASC
+			if col.ClusteringOrder == "desc" {
+				col.Order = DESC
 			}
 		} else {
-			validatorParsed := parseType(columns[i].Validator)
-			columns[i].Type = validatorParsed.types[0]
-			columns[i].Order = ASC
+			validatorParsed := parseType(col.Validator)
+			col.Type = validatorParsed.types[0]
+			col.Order = ASC
 			if validatorParsed.reversed[0] {
-				columns[i].Order = DESC
+				col.Order = DESC
 			}
 		}
 
-		table := keyspace.Tables[columns[i].Table]
-		table.Columns[columns[i].Name] = &columns[i]
-		table.OrderedColumns = append(table.OrderedColumns, columns[i].Name)
+		table := keyspace.Tables[col.Table]
+		table.Columns[col.Name] = col
+		table.OrderedColumns = append(table.OrderedColumns, col.Name)
 	}
 
 	if protoVersion == protoVersion1 {