Bläddra i källkod

Merge pull request #369 from Zariel/udt

Add support for UDT's
Ben Hood 10 år sedan
förälder
incheckning
1a30be4244
4 ändrade filer med 390 tillägg och 8 borttagningar
  1. 16 0
      frame.go
  2. 230 0
      marshal.go
  3. 5 8
      session.go
  4. 139 0
      udt_test.go

+ 16 - 0
frame.go

@@ -612,6 +612,22 @@ func (f *framer) readTypeInfo() TypeInfo {
 
 		return tuple
 
+	case TypeUDT:
+		udt := UDTTypeInfo{
+			NativeType: simple,
+		}
+		udt.KeySpace = f.readString()
+		udt.Name = f.readString()
+
+		n := f.readShort()
+		udt.Elements = make([]UDTField, n)
+		for i := 0; i < int(n); i++ {
+			field := &udt.Elements[i]
+			field.Name = f.readString()
+			field.Type = f.readTypeInfo()
+		}
+
+		return udt
 	case TypeMap, TypeList, TypeSet:
 		collection := CollectionType{
 			NativeType: simple,

+ 230 - 0
marshal.go

@@ -83,6 +83,8 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
 		return marshalVarint(info, value)
 	case TypeInet:
 		return marshalInet(info, value)
+	case TypeUDT:
+		return marshalUDT(info, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
@@ -130,6 +132,8 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
 		return unmarshalInet(info, data, value)
 	case TypeTuple:
 		return unmarshalTuple(info, data, value)
+	case TypeUDT:
+		return unmarshalUDT(info, data, value)
 	}
 	// TODO(tux21b): add the remaining types
 	return fmt.Errorf("can not unmarshal %s into %T", info, value)
@@ -1198,6 +1202,201 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
 	return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
 }
 
+// UDTMarshaler is an interface which should be implemented by users wishing to
+// handle encoding UDT types to sent to Cassandra. Note: due to current implentations
+// methods defined for this interface must be value receivers not pointer receivers.
+type UDTMarshaler interface {
+	// MarshalUDT will be called for each field in the the UDT returned by Cassandra,
+	// the implementor should marshal the type to return by for example calling
+	// Marshal.
+	MarshalUDT(name string, info TypeInfo) ([]byte, error)
+}
+
+// UDTUnmarshaler should be implemented by users wanting to implement custom
+// UDT unmarshaling.
+type UDTUnmarshaler interface {
+	// UnmarshalUDT will be called for each field in the UDT return by Cassandra,
+	// the implementor should unmarshal the data into the value of their chosing,
+	// for example by calling Unmarshal.
+	UnmarshalUDT(name string, info TypeInfo, data []byte) error
+}
+
+func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
+	udt := info.(UDTTypeInfo)
+
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case UDTMarshaler:
+		var buf []byte
+		for _, e := range udt.Elements {
+			data, err := v.MarshalUDT(e.Name, e.Type)
+			if err != nil {
+				return nil, err
+			}
+
+			n := len(data)
+			buf = append(buf, byte(n<<24),
+				byte(n<<16),
+				byte(n<<8),
+				byte(n))
+
+			buf = append(buf, data...)
+		}
+
+		return buf, nil
+	case map[string]interface{}:
+		var buf []byte
+		for _, e := range udt.Elements {
+			val, ok := v[e.Name]
+			if !ok {
+				return nil, marshalErrorf("missing UDT field in map: %s", e.Name)
+			}
+
+			data, err := Marshal(e.Type, val)
+			if err != nil {
+				return nil, err
+			}
+
+			n := len(data)
+			buf = append(buf, byte(n<<24),
+				byte(n<<16),
+				byte(n<<8),
+				byte(n))
+
+			buf = append(buf, data...)
+		}
+
+		return buf, nil
+	}
+
+	k := reflect.ValueOf(value)
+	if k.Kind() == reflect.Ptr {
+		if k.IsNil() {
+			return nil, marshalErrorf("cannot marshal %T into %s", value, info)
+		}
+		k = k.Elem()
+	}
+
+	if k.Kind() != reflect.Struct || !k.IsValid() {
+		return nil, marshalErrorf("cannot marshal %T into %s", value, info)
+	}
+
+	fields := make(map[string]reflect.Value)
+	t := reflect.TypeOf(value)
+	for i := 0; i < t.NumField(); i++ {
+		sf := t.Field(i)
+
+		if tag := sf.Tag.Get("cql"); tag != "" {
+			fields[tag] = k.Field(i)
+		}
+	}
+
+	var buf []byte
+	for _, e := range udt.Elements {
+		f, ok := fields[e.Name]
+		if !ok {
+			f = k.FieldByName(e.Name)
+		}
+
+		if !f.IsValid() {
+			return nil, marshalErrorf("cannot marshal %T into %s", value, info)
+		} else if f.Kind() == reflect.Ptr {
+			f = f.Elem()
+		}
+
+		data, err := Marshal(e.Type, f.Interface())
+		if err != nil {
+			return nil, err
+		}
+
+		n := len(data)
+		buf = append(buf, byte(n<<24),
+			byte(n<<16),
+			byte(n<<8),
+			byte(n))
+
+		buf = append(buf, data...)
+	}
+
+	return buf, nil
+
+}
+
+func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case UDTUnmarshaler:
+		udt := info.(UDTTypeInfo)
+
+		for _, e := range udt.Elements {
+			size := readInt(data[:4])
+			data = data[4:]
+
+			var err error
+			if size < 0 {
+				err = v.UnmarshalUDT(e.Name, e.Type, nil)
+			} else {
+				err = v.UnmarshalUDT(e.Name, e.Type, data[:size])
+				data = data[size:]
+			}
+
+			if err != nil {
+				return err
+			}
+		}
+
+		return nil
+	}
+
+	k := reflect.ValueOf(value).Elem()
+	if k.Kind() != reflect.Struct || !k.IsValid() {
+		return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
+	}
+
+	fields := make(map[string]reflect.Value)
+	t := k.Type()
+	for i := 0; i < t.NumField(); i++ {
+		sf := t.Field(i)
+
+		if tag := sf.Tag.Get("cql"); tag != "" {
+			fields[tag] = k.Field(i)
+		}
+	}
+
+	udt := info.(UDTTypeInfo)
+
+	for _, e := range udt.Elements {
+		size := readInt(data[:4])
+		data = data[4:]
+
+		var err error
+		if size >= 0 {
+			f, ok := fields[e.Name]
+			if !ok {
+				f = k.FieldByName(e.Name)
+			}
+
+			if !f.IsValid() || !f.CanAddr() {
+				return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
+			}
+
+			fk := f.Addr().Interface()
+			if err := Unmarshal(e.Type, data[:size], fk); err != nil {
+				return err
+			}
+			data = data[size:]
+		}
+
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
 // TypeInfo describes a Cassandra specific data type.
 type TypeInfo interface {
 	Type() Type
@@ -1268,6 +1467,37 @@ type TupleTypeInfo struct {
 	Elems []TypeInfo
 }
 
+type UDTField struct {
+	Name string
+	Type TypeInfo
+}
+
+type UDTTypeInfo struct {
+	NativeType
+	KeySpace string
+	Name     string
+	Elements []UDTField
+}
+
+func (u UDTTypeInfo) String() string {
+	buf := &bytes.Buffer{}
+
+	fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name)
+	first := true
+	for _, e := range u.Elements {
+		if !first {
+			fmt.Fprint(buf, ",")
+		} else {
+			first = false
+		}
+
+		fmt.Fprintf(buf, "%s=%v", e.Name, e.Type)
+	}
+	fmt.Fprint(buf, "}")
+
+	return buf.String()
+}
+
 // String returns a human readable name for the Cassandra datatype
 // described by t.
 // Type is the identifier of a Cassandra internal datatype.

+ 5 - 8
session.go

@@ -674,8 +674,8 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 			continue
 		}
 
-		// how can we allow users to pass in a single struct to unmarshal into
-		if col.TypeInfo.Type() == TypeTuple {
+		switch col.TypeInfo.Type() {
+		case TypeTuple:
 			// this will panic, actually a bug, please report
 			tuple := col.TypeInfo.(TupleTypeInfo)
 
@@ -683,18 +683,15 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
 			// here we pass in a slice of the struct which has the number number of
 			// values as elements in the tuple
 			iter.err = Unmarshal(col.TypeInfo, iter.rows[iter.pos][c], dest[i:i+count])
-			if iter.err != nil {
-				return false
-			}
 			i += count
-			continue
+		default:
+			iter.err = Unmarshal(col.TypeInfo, iter.rows[iter.pos][c], dest[i])
+			i++
 		}
 
-		iter.err = Unmarshal(col.TypeInfo, iter.rows[iter.pos][c], dest[i])
 		if iter.err != nil {
 			return false
 		}
-		i++
 	}
 
 	iter.pos++

+ 139 - 0
udt_test.go

@@ -0,0 +1,139 @@
+// +build all integration
+
+package gocql
+
+import (
+	"fmt"
+	"testing"
+)
+
+type position struct {
+	Lat int
+	Lon int
+}
+
+// NOTE: due to current implementation details it is not currently possible to use
+// a pointer receiver type for the UDTMarshaler interface to handle UDT's
+func (p position) MarshalUDT(name string, info TypeInfo) ([]byte, error) {
+	switch name {
+	case "lat":
+		return Marshal(info, p.Lat)
+	case "lon":
+		return Marshal(info, p.Lon)
+	default:
+		return nil, fmt.Errorf("unknown column for position: %q", name)
+	}
+}
+
+func (p *position) UnmarshalUDT(name string, info TypeInfo, data []byte) error {
+	switch name {
+	case "lat":
+		return Unmarshal(info, data, &p.Lat)
+	case "lon":
+		return Unmarshal(info, data, &p.Lon)
+	default:
+		return fmt.Errorf("unknown column for position: %q", name)
+	}
+}
+
+func TestUDT_Marshaler(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE position(
+		lat int,
+		lon int);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE houses(
+		id int,
+		name text,
+		loc frozen<position>,
+
+		primary key(id)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	const (
+		expLat = -1
+		expLon = 2
+	)
+
+	err = session.Query("INSERT INTO houses(id, name, loc) VALUES(?, ?, ?)", 1, "test", &position{expLat, expLon}).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	pos := &position{}
+
+	err = session.Query("SELECT loc FROM houses WHERE id = ?", 1).Scan(pos)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if pos.Lat != expLat {
+		t.Errorf("expeceted lat to be be %d got %d", expLat, pos.Lat)
+	}
+	if pos.Lon != expLon {
+		t.Errorf("expeceted lon to be be %d got %d", expLon, pos.Lon)
+	}
+}
+func TestUDT_Reflect(t *testing.T) {
+	// Uses reflection instead of implementing the marshaling type
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE horse(
+		name text,
+		owner text);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE horse_race(
+		position int,
+		horse frozen<horse>,
+
+		primary key(position)
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	type horse struct {
+		Name  string `cql:"name"`
+		Owner string `cql:"owner"`
+	}
+
+	insertedHorse := &horse{
+		Name:  "pony",
+		Owner: "jim",
+	}
+
+	err = session.Query("INSERT INTO horse_race(position, horse) VALUES(?, ?)", 1, insertedHorse).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	retrievedHorse := &horse{}
+	err = session.Query("SELECT horse FROM horse_race WHERE position = ?", 1).Scan(retrievedHorse)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if *retrievedHorse != *insertedHorse {
+		t.Fatal("exepcted to get %+v got %+v", insertedHorse, retrievedHorse)
+	}
+}