Browse Source

add support to unmarshal UDT into *map[string]interface{}

Chris Bannister 10 years ago
parent
commit
a1bba7f41d
2 changed files with 151 additions and 20 deletions
  1. 62 20
      marshal.go
  2. 89 0
      udt_test.go

+ 62 - 20
marshal.go

@@ -1368,6 +1368,48 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error {
 			}
 		}
 
+		return nil
+	case *map[string]interface{}:
+		udt := info.(UDTTypeInfo)
+
+		rv := reflect.ValueOf(value)
+		if rv.Kind() != reflect.Ptr {
+			return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+		}
+
+		rv = rv.Elem()
+		t := rv.Type()
+		if t.Kind() != reflect.Map {
+			return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+		} else if data == nil {
+			rv.Set(reflect.Zero(t))
+			return nil
+		}
+
+		rv.Set(reflect.MakeMap(t))
+		m := *v
+
+		for _, e := range udt.Elements {
+			size := readInt(data[:4])
+			data = data[4:]
+
+			val := reflect.New(goType(e.Type))
+
+			var err error
+			if size < 0 {
+				err = Unmarshal(e.Type, nil, val.Interface())
+			} else {
+				err = Unmarshal(e.Type, data[:size], val.Interface())
+				data = data[size:]
+			}
+
+			if err != nil {
+				return err
+			}
+
+			m[e.Name] = val.Elem().Interface()
+		}
+
 		return nil
 	}
 
@@ -1533,26 +1575,26 @@ type Type int
 
 const (
 	TypeCustom    Type = 0x0000
-	TypeAscii          = 0x0001
-	TypeBigInt         = 0x0002
-	TypeBlob           = 0x0003
-	TypeBoolean        = 0x0004
-	TypeCounter        = 0x0005
-	TypeDecimal        = 0x0006
-	TypeDouble         = 0x0007
-	TypeFloat          = 0x0008
-	TypeInt            = 0x0009
-	TypeTimestamp      = 0x000B
-	TypeUUID           = 0x000C
-	TypeVarchar        = 0x000D
-	TypeVarint         = 0x000E
-	TypeTimeUUID       = 0x000F
-	TypeInet           = 0x0010
-	TypeList           = 0x0020
-	TypeMap            = 0x0021
-	TypeSet            = 0x0022
-	TypeUDT            = 0x0030
-	TypeTuple          = 0x0031
+	TypeAscii     Type = 0x0001
+	TypeBigInt    Type = 0x0002
+	TypeBlob      Type = 0x0003
+	TypeBoolean   Type = 0x0004
+	TypeCounter   Type = 0x0005
+	TypeDecimal   Type = 0x0006
+	TypeDouble    Type = 0x0007
+	TypeFloat     Type = 0x0008
+	TypeInt       Type = 0x0009
+	TypeTimestamp Type = 0x000B
+	TypeUUID      Type = 0x000C
+	TypeVarchar   Type = 0x000D
+	TypeVarint    Type = 0x000E
+	TypeTimeUUID  Type = 0x000F
+	TypeInet      Type = 0x0010
+	TypeList      Type = 0x0020
+	TypeMap       Type = 0x0021
+	TypeSet       Type = 0x0022
+	TypeUDT       Type = 0x0030
+	TypeTuple     Type = 0x0031
 )
 
 // String returns the name of the identifier.

+ 89 - 0
udt_test.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"strings"
 	"testing"
+	"time"
 )
 
 type position struct {
@@ -252,3 +253,91 @@ func TestUDT_NullObject(t *testing.T) {
 		t.Errorf("expected empty string to be returned for null udt: got %q", readCol.Owner)
 	}
 }
+
+func TestMapScanUDT(t *testing.T) {
+	if *flagProto < protoVersion3 {
+		t.Skip("UDT are only available on protocol >= 3")
+	}
+
+	session := createSession(t)
+	defer session.Close()
+
+	err := createTable(session, `CREATE TYPE log_entry (
+		created_timestamp timestamp,
+		message text
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = createTable(session, `CREATE TABLE requests_by_id (
+		id uuid PRIMARY KEY,
+		type int,
+		log_entries list<frozen <log_entry>>
+	);`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	entry := []struct {
+		CreatedTimestamp time.Time `cql:"created_timestamp"`
+		Message          string    `cql:"message"`
+	}{
+		{
+			CreatedTimestamp: time.Now().Truncate(time.Millisecond),
+			Message:          "test time now",
+		},
+	}
+
+	id, _ := RandomUUID()
+	const typ = 1
+
+	err = session.Query("INSERT INTO requests_by_id(id, type, log_entries) VALUES (?, ?, ?)", id, typ, entry).Exec()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	rawResult := map[string]interface{}{}
+	err = session.Query(`SELECT * FROM requests_by_id WHERE id = ?`, id).MapScan(rawResult)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	logEntries, ok := rawResult["log_entries"].([]map[string]interface{})
+	if !ok {
+		t.Fatal("log_entries not in scanned map")
+	}
+
+	if len(logEntries) != 1 {
+		t.Fatalf("expected to get 1 log_entry got %d", len(logEntries))
+	}
+
+	logEntry := logEntries[0]
+
+	timestamp, ok := logEntry["created_timestamp"]
+	if !ok {
+		t.Error("created_timestamp not unmarshalled into map")
+	} else {
+		if ts, ok := timestamp.(time.Time); ok {
+			if !ts.In(time.UTC).Equal(entry[0].CreatedTimestamp.In(time.UTC)) {
+				t.Errorf("created_timestamp not equal to stored: got %v expected %v", ts.In(time.UTC), entry[0].CreatedTimestamp.In(time.UTC))
+			}
+		} else {
+			t.Errorf("created_timestamp was not time.Time got: %T", timestamp)
+		}
+	}
+
+	message, ok := logEntry["message"]
+	if !ok {
+		t.Error("message not unmarshalled into map")
+	} else {
+		if ts, ok := message.(string); ok {
+			if ts != message {
+				t.Errorf("message not equal to stored: got %v expected %v", ts, entry[0].Message)
+			}
+		} else {
+			t.Errorf("message was not string got: %T", message)
+		}
+	}
+
+}