Просмотр исходного кода

Add initial support for UDT's

Add support to read and write UDT's at the framing and scanning level.

To scan or marshal a UDT type the type must implement the Marshal/
UnmarshalUDT type.
Chris Bannister 10 лет назад
Родитель
Сommit
68ab73427f
4 измененных файлов с 235 добавлено и 8 удалено
  1. 16 0
      frame.go
  2. 126 0
      marshal.go
  3. 5 8
      session.go
  4. 88 0
      udt_test.go

+ 16 - 0
frame.go

@@ -612,6 +612,22 @@ func (f *framer) readTypeInfo() TypeInfo {
 
 		return tuple
 
+	case TypeUDT:
+		udt := UDTTypeInfo{
+			NativeType: simple,
+		}
+		udt.KeySpace = f.readString()
+		udt.Name = f.readString()
+
+		n := f.readShort()
+		udt.Elements = make([]UDTField, n)
+		for i := 0; i < int(n); i++ {
+			field := &udt.Elements[i]
+			field.Name = f.readString()
+			field.Type = f.readTypeInfo()
+		}
+
+		return udt
 	case TypeMap, TypeList, TypeSet:
 		collection := CollectionType{
 			NativeType: simple,

+ 126 - 0
marshal.go

@@ -83,6 +83,8 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
 		return marshalVarint(info, value)
 	case TypeInet:
 		return marshalInet(info, value)
+	case TypeUDT:
+		return marshalUDT(info, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
@@ -130,6 +132,8 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
 		return unmarshalInet(info, data, value)
 	case TypeTuple:
 		return unmarshalTuple(info, data, value)
+	case TypeUDT:
+		return unmarshalUDT(info, data, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return fmt.Errorf("can not unmarshal %s into %T", info, value)
@@ -1198,6 +1202,97 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
 }
 
+// UDTMarshaler is an interface which should be implemented by users wishing to
+// handle encoding UDT types to sent to Cassandra.
+type UDTMarshaler interface {
+	EncodeUDTField(name string, info TypeInfo) ([]byte, error)
+}
+
+type UDTUnmarshaler interface {
+	DecodeUDTField(name string, info TypeInfo, data []byte) error
+}
+
+func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
+	udt := info.(UDTTypeInfo)
+
+	switch v := value.(type) {
+	case UDTMarshaler:
+		var buf []byte
+		for _, e := range udt.Elements {
+			data, err := v.EncodeUDTField(e.Name, e.Type)
+			if err != nil {
+				return nil, err
+			}
+
+			n := len(data)
+			buf = append(buf, byte(n<<24),
+				byte(n<<16),
+				byte(n<<8),
+				byte(n))
+
+			buf = append(buf, data...)
+		}
+
+		return buf, nil
+	case map[string]interface{}:
+		var buf []byte
+		for _, e := range udt.Elements {
+			val, ok := v[e.Name]
+			if !ok {
+				return nil, marshalErrorf("missing UDT field in map: %s", e.Name)
+			}
+
+			data, err := Marshal(e.Type, val)
+			if err != nil {
+				return nil, err
+			}
+
+			n := len(data)
+			buf = append(buf, byte(n<<24),
+				byte(n<<16),
+				byte(n<<8),
+				byte(n))
+
+			buf = append(buf, data...)
+		}
+
+		return buf, nil
+	}
+
+	return nil, marshalErrorf("cannot marshal %T into %s", value, info)
+}
+
+func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case UDTUnmarshaler:
+		udt := info.(UDTTypeInfo)
+
+		for _, e := range udt.Elements {
+			size := readInt(data[:4])
+			data = data[4:]
+
+			var err error
+			if size < 0 {
+				err = v.DecodeUDTField(e.Name, e.Type, nil)
+			} else {
+				err = v.DecodeUDTField(e.Name, e.Type, data[:size])
+				data = data[size:]
+
+			}
+
+			if err != nil {
+				return err
+			}
+		}
+
+		return nil
+	}
+
+	return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
+}
+
 // TypeInfo describes a Cassandra specific data type.
 type TypeInfo interface {
 	Type() Type
@@ -1268,6 +1363,37 @@ type TupleTypeInfo struct {
 	Elems []TypeInfo
 }
 
+type UDTField struct {
+	Name string
+	Type TypeInfo
+}
+
+type UDTTypeInfo struct {
+	NativeType
+	KeySpace string
+	Name     string
+	Elements []UDTField
+}
+
+func (u UDTTypeInfo) String() string {
+	buf := &bytes.Buffer{}
+
+	fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name)
+	first := true
+	for _, e := range u.Elements {
+		if !first {
+			fmt.Fprint(buf, ",")
+		} else {
+			first = false
+		}
+
+		fmt.Fprintf(buf, "%s=%v", e.Name, e.Type)
+	}
+	fmt.Fprint(buf, "}")
+
+	return buf.String()
+}
+
 // String returns a human readable name for the Cassandra datatype
 // described by t.
 // Type is the identifier of a Cassandra internal datatype.

+ 5 - 8
session.go

@@ -660,8 +660,8 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 			continue
 		}
 
-		// how can we allow users to pass in a single struct to unmarshal into
-		if col.TypeInfo.Type() == TypeTuple {
+		switch col.TypeInfo.Type() {
+		case TypeTuple:
 			// this will panic, actually a bug, please report
 			tuple := col.TypeInfo.(TupleTypeInfo)
 
@@ -669,18 +669,15 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 			// here we pass in a slice of the struct which has the number number of
 			// values as elements in the tuple
 			iter.err = Unmarshal(col.TypeInfo, iter.rows[iter.pos][c], dest[i:i+count])
-			if iter.err != nil {
-				return false
-			}
 			i += count
-			continue
+		default:
+			iter.err = Unmarshal(col.TypeInfo, iter.rows[iter.pos][c], dest[i])
+			i++
 		}
 
-		iter.err = Unmarshal(col.TypeInfo, iter.rows[iter.pos][c], dest[i])
 		if iter.err != nil {
 			return false
 		}
-		i++
 	}
 
 	iter.pos++

+ 88 - 0
udt_test.go

@@ -0,0 +1,88 @@
+// +build all integration
+
+package gocql
+
+import (
+	"fmt"
+	"testing"
+)
+
+type position struct {
+	lat int
+	lon int
+}
+
+// NOTE: due to current implementation details it is not currently possible to use
+// a pointer receiver type for the UDTMarshaler interface to handle UDT's
+func (p position) EncodeUDTField(name string, info TypeInfo) ([]byte, error) {
+	switch name {
+	case "lat":
+		return Marshal(info, p.lat)
+	case "lon":
+		return Marshal(info, p.lon)
+	default:
+		return nil, fmt.Errorf("unknown column for position: %q", name)
+	}
+}
+
+func (p *position) DecodeUDTField(name string, info TypeInfo, data []byte) error {
+	switch name {
+	case "lat":
+		return Unmarshal(info, data, &p.lat)
+	case "lon":
+		return Unmarshal(info, data, &p.lon)
+	default:
+		return fmt.Errorf("unknown column for position: %q", name)
+	}
+}
+
+func TestUDT(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE position(
+		lat int,
+		lon int);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE houses(
+		id int,
+		name text,
+		loc frozen<position>,
+
+		primary key(id)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	const (
+		expLat = -1
+		expLon = 2
+	)
+
+	err = session.Query("INSERT INTO houses(id, name, loc) VALUES(?, ?, ?)", 1, "test", &position{expLat, expLon}).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	pos := &position{}
+
+	err = session.Query("SELECT loc FROM houses WHERE id = ?", 1).Scan(pos)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if pos.lat != expLat {
+		t.Errorf("expeceted lat to be be %d got %d", expLat, pos.lat)
+	}
+	if pos.lon != expLon {
+		t.Errorf("expeceted lon to be be %d got %d", expLon, pos.lon)
+	}
+}