Ver Fonte

Merge pull request #89 from Zariel/uuid-from-string

Allow uuid's to be unmarsalled into strings and bytes
Christoph Hack há 12 anos atrás
pai
commit
f240f998ba
3 ficheiros alterados com 59 adições e 16 exclusões
  1. 49 10
      marshal.go
  2. 4 1
      marshal_test.go
  3. 6 5
      uuid.go

+ 49 - 10
marshal.go

@@ -903,31 +903,70 @@ func unmarshalMap(info *TypeInfo, data []byte, value interface{}) error {
 }
 
 func marshalUUID(info *TypeInfo, value interface{}) ([]byte, error) {
-	if val, ok := value.([]byte); ok && len(val) == 16 {
-		return val, nil
-	}
-	if val, ok := value.(UUID); ok {
+	switch val := value.(type) {
+	case UUID:
 		return val.Bytes(), nil
+	case []byte:
+		if len(val) == 16 {
+			return val, nil
+		}
+	case string:
+		b, err := ParseUUID(val)
+		if err != nil {
+			return nil, err
+		}
+		return b[:], nil
 	}
 	return nil, marshalErrorf("can not marshal %T into %s", value, info)
 }
 
 func unmarshalUUID(info *TypeInfo, data []byte, value interface{}) error {
+	if data == nil {
+		return nil
+	}
+
 	switch v := value.(type) {
-	case Unmarshaler:
-		return v.UnmarshalCQL(info, data)
+	case *string:
+		u, err := UUIDFromBytes(data)
+		if err != nil {
+			return unmarshalErrorf("Unable to parse UUID: %s", err)
+		}
+
+		*v = u.String()
+		return nil
+	case *[]byte:
+		u, err := UUIDFromBytes(data)
+		if err != nil {
+			return unmarshalErrorf("Unable to parse UUID: %s", err)
+		}
+
+		b := [16]byte(u)
+
+		*v = b[:]
+		return nil
 	case *UUID:
-		*v = UUIDFromBytes(data)
+		u, err := UUIDFromBytes(data)
+		if err != nil {
+			return unmarshalErrorf("Unable to parse UUID: %s", err)
+		}
+
+		*v = u
+
 		return nil
 	}
-	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+
+	return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
 }
 
 func unmarshalTimeUUID(info *TypeInfo, data []byte, value interface{}) error {
 	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
 	case *time.Time:
-		id := UUIDFromBytes(data)
-		if id.Version() != 1 {
+		id, err := UUIDFromBytes(data)
+		if err != nil {
+			return err
+		} else if id.Version() != 1 {
 			return unmarshalErrorf("invalid timeuuid")
 		}
 		*v = id.Time()

+ 4 - 1
marshal_test.go

@@ -52,7 +52,10 @@ var marshalTests = []struct {
 	{
 		&TypeInfo{Type: TypeTimeUUID},
 		[]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0},
-		UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}),
+		func() UUID {
+			x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0})
+			return x
+		}(),
 	},
 	{
 		&TypeInfo{Type: TypeInt},

+ 6 - 5
uuid.go

@@ -10,6 +10,7 @@ package gocql
 
 import (
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -82,15 +83,15 @@ func ParseUUID(input string) (UUID, error) {
 	return u, nil
 }
 
-// UUIDFromBytes converts a raw byte slice to an UUID. It will panic if the
-// slice isn't exactly 16 bytes long.
-func UUIDFromBytes(input []byte) UUID {
+// UUIDFromBytes converts a raw byte slice to an UUID.
+func UUIDFromBytes(input []byte) (UUID, error) {
 	var u UUID
 	if len(input) != 16 {
-		panic("UUIDs must be exactly 16 bytes long")
+		return u, errors.New("UUIDs must be exactly 16 bytes long")
 	}
+
 	copy(u[:], input)
-	return u
+	return u, nil
 }
 
 // RandomUUID generates a totally random UUID (version 4) as described in