Ver código fonte

added marshaling code for all basic types

Christoph Hack 12 anos atrás
pai
commit
a33d49c4ca
2 arquivos alterados com 843 adições e 67 exclusões
  1. 680 67
      marshal.go
  2. 163 0
      marshal_test.go

+ 680 - 67
marshal.go

@@ -6,58 +6,46 @@ package gocql
 
 import (
 	"fmt"
+	"math"
+	"reflect"
 	"time"
 )
 
 // Marshaler is the interface implemented by objects that can marshal
 // themselves into values understood by Cassandra.
 type Marshaler interface {
-	MarshalCQL(info *TypeInfo, value interface{}) ([]byte, error)
+	MarshalCQL(info *TypeInfo) ([]byte, error)
 }
 
 // Unmarshaler is the interface implemented by objects that can unmarshal
 // a Cassandra specific description of themselves.
 type Unmarshaler interface {
-	UnmarshalCQL(info *TypeInfo, data []byte, value interface{}) error
+	UnmarshalCQL(info *TypeInfo, data []byte) error
 }
 
 // Marshal returns the CQL encoding of the value for the Cassandra
 // internal type described by the info parameter.
 func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 	if v, ok := value.(Marshaler); ok {
-		return v.MarshalCQL(info, value)
+		return v.MarshalCQL(info)
 	}
 	switch info.Type {
 	case TypeVarchar, TypeAscii, TypeBlob:
-		switch v := value.(type) {
-		case string:
-			return []byte(v), nil
-		case []byte:
-			return v, nil
-		}
+		return marshalVarchar(info, value)
 	case TypeBoolean:
-		if v, ok := value.(bool); ok {
-			if v {
-				return []byte{1}, nil
-			} else {
-				return []byte{0}, nil
-			}
-		}
+		return marshalBool(info, value)
 	case TypeInt:
-		switch v := value.(type) {
-		case int:
-			x := int32(v)
-			return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}, nil
-		}
+		return marshalInt(info, value)
+	case TypeBigInt:
+		return marshalBigInt(info, value)
+	case TypeFloat:
+		return marshalFloat(info, value)
+	case TypeDouble:
+		return marshalDouble(info, value)
 	case TypeTimestamp:
-		if v, ok := value.(time.Time); ok {
-			x := v.In(time.UTC).UnixNano() / int64(time.Millisecond)
-			return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40),
-				byte(x >> 32), byte(x >> 24), byte(x >> 16),
-				byte(x >> 8), byte(x)}, nil
-		}
+		return marshalTimestamp(info, value)
 	}
-	// TODO(tux21b): add reflection and a lot of other types
+	// TODO(tux21b): add the remaining types
 	return nil, fmt.Errorf("can not marshal %T into %s", value, info)
 }
 
@@ -66,54 +54,657 @@ func Marshal(info *TypeInfo, value interface{}) ([]byte, error) {
 // value pointed by value.
 func Unmarshal(info *TypeInfo, data []byte, value interface{}) error {
 	if v, ok := value.(Unmarshaler); ok {
-		return v.UnmarshalCQL(info, data, value)
+		return v.UnmarshalCQL(info, data)
 	}
 	switch info.Type {
 	case TypeVarchar, TypeAscii, TypeBlob:
-		switch v := value.(type) {
-		case *string:
-			*v = string(data)
-			return nil
-		case *[]byte:
-			val := make([]byte, len(data))
-			copy(val, data)
-			*v = val
-			return nil
-		}
+		return unmarshalVarchar(info, data, value)
 	case TypeBoolean:
-		if v, ok := value.(*bool); ok && len(data) == 1 {
-			*v = data[0] != 0
-			return nil
-		}
-	case TypeBigInt:
-		if v, ok := value.(*int); ok && len(data) == 8 {
-			*v = int(data[0])<<56 | int(data[1])<<48 | int(data[2])<<40 |
-				int(data[3])<<32 | int(data[4])<<24 | int(data[5])<<16 |
-				int(data[6])<<8 | int(data[7])
-			return nil
-		}
+		return unmarshalBool(info, data, value)
 	case TypeInt:
-		if v, ok := value.(*int); ok && len(data) == 4 {
-			*v = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 |
-				int(data[3])
-			return nil
-		}
+		return unmarshalInt(info, data, value)
+	case TypeBigInt, TypeCounter:
+		return unmarshalBigInt(info, data, value)
+	case TypeFloat:
+		return unmarshalFloat(info, data, value)
+	case TypeDouble:
+		return unmarshalDouble(info, data, value)
 	case TypeTimestamp:
-		if v, ok := value.(*time.Time); ok && len(data) == 8 {
-			x := int64(data[0])<<56 | int64(data[1])<<48 |
-				int64(data[2])<<40 | int64(data[3])<<32 |
-				int64(data[4])<<24 | int64(data[5])<<16 |
-				int64(data[6])<<8 | int64(data[7])
-			sec := x / 1000
-			nsec := (x - sec*1000) * 1000000
-			*v = time.Unix(sec, nsec)
-			return nil
-		}
-	}
-	// TODO(tux21b): add reflection and a lot of other basic types
+		return unmarshalTimestamp(info, data, value)
+	}
+	// TODO(tux21b): add the remaining types
 	return fmt.Errorf("can not unmarshal %s into %T", info, value)
 }
 
+func marshalVarchar(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case string:
+		return []byte(v), nil
+	case []byte:
+		return v, nil
+	}
+	rv := reflect.ValueOf(value)
+	t := rv.Type()
+	k := t.Kind()
+	switch {
+	case k == reflect.String:
+		return []byte(rv.String()), nil
+	case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8:
+		return rv.Bytes(), nil
+	case k == reflect.Ptr:
+		return marshalVarchar(info, rv.Elem().Interface())
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func unmarshalVarchar(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *string:
+		*v = string(data)
+		return nil
+	case *[]byte:
+		var dataCopy []byte
+		if data != nil {
+			dataCopy = make([]byte, len(data))
+			copy(dataCopy, data)
+		}
+		*v = dataCopy
+		return nil
+	}
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+	rv = rv.Elem()
+	t := rv.Type()
+	k := t.Kind()
+	switch {
+	case k == reflect.String:
+		rv.SetString(string(data))
+		return nil
+	case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8:
+		var dataCopy []byte
+		if data != nil {
+			dataCopy = make([]byte, len(data))
+			copy(dataCopy, data)
+		}
+		rv.SetBytes(dataCopy)
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func marshalInt(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case int:
+		if v > math.MaxInt32 || v < math.MinInt32 {
+			return nil, marshalErrorf("marshal int: value %d out of range", v)
+		}
+		return encInt(int32(v)), nil
+	case uint:
+		if v > math.MaxInt32 {
+			return nil, marshalErrorf("marshal int: value %d out of range", v)
+		}
+		return encInt(int32(v)), nil
+	case int64:
+		if v > math.MaxInt32 || v < math.MinInt32 {
+			return nil, marshalErrorf("marshal int: value %d out of range", v)
+		}
+		return encInt(int32(v)), nil
+	case uint64:
+		if v > math.MaxInt32 {
+			return nil, marshalErrorf("marshal int: value %d out of range", v)
+		}
+		return encInt(int32(v)), nil
+	case int32:
+		return encInt(v), nil
+	case uint32:
+		if v > math.MaxInt32 {
+			return nil, marshalErrorf("marshal int: value %d out of range", v)
+		}
+		return encInt(int32(v)), nil
+	case int16:
+		return encInt(int32(v)), nil
+	case uint16:
+		return encInt(int32(v)), nil
+	case int8:
+		return encInt(int32(v)), nil
+	case uint8:
+		return encInt(int32(v)), nil
+	}
+	rv := reflect.ValueOf(value)
+	switch rv.Type().Kind() {
+	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
+		v := rv.Int()
+		if v > math.MaxInt32 || v < math.MinInt32 {
+			return nil, marshalErrorf("marshal int: value %d out of range", v)
+		}
+		return encInt(int32(v)), nil
+	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
+		v := rv.Uint()
+		if v > math.MaxInt32 {
+			return nil, marshalErrorf("marshal int: value %d out of range", v)
+		}
+		return encInt(int32(v)), nil
+	case reflect.Ptr:
+		return marshalInt(info, rv.Elem().Interface())
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func encInt(x int32) []byte {
+	return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}
+}
+
+func unmarshalInt(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *int:
+		*v = int(decInt(data))
+		return nil
+	case *uint:
+		x := decInt(data)
+		if x < 0 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint(x)
+		return nil
+	case *int64:
+		*v = int64(decInt(data))
+		return nil
+	case *uint64:
+		x := decInt(data)
+		if x < 0 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint64(x)
+		return nil
+	case *int32:
+		*v = decInt(data)
+		return nil
+	case *uint32:
+		x := decInt(data)
+		if x < 0 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint32(x)
+		return nil
+	case *int16:
+		x := decInt(data)
+		if x < math.MinInt16 || x > math.MaxInt16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = int16(x)
+		return nil
+	case *uint16:
+		x := decInt(data)
+		if x < 0 || x > math.MaxUint16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint16(x)
+		return nil
+	case *int8:
+		x := decInt(data)
+		if x < math.MinInt8 || x > math.MaxInt8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = int8(x)
+		return nil
+	case *uint8:
+		x := decInt(data)
+		if x < 0 || x > math.MaxUint8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint8(x)
+		return nil
+	}
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+	rv = rv.Elem()
+	switch rv.Type().Kind() {
+	case reflect.Int, reflect.Int64, reflect.Int32:
+		rv.SetInt(int64(decInt(data)))
+		return nil
+	case reflect.Int16:
+		x := decInt(data)
+		if x < math.MinInt16 || x > math.MaxInt16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetInt(int64(x))
+		return nil
+	case reflect.Int8:
+		x := decInt(data)
+		if x < math.MinInt8 || x > math.MaxInt8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetInt(int64(x))
+		return nil
+	case reflect.Uint, reflect.Uint64, reflect.Uint32:
+		x := decInt(data)
+		if x < 0 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	case reflect.Uint16:
+		x := decInt(data)
+		if x < 0 || x > math.MaxUint16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	case reflect.Uint8:
+		x := decInt(data)
+		if x < 0 || x > math.MaxUint8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func decInt(x []byte) int32 {
+	if len(x) != 4 {
+		return 0
+	}
+	return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3])
+}
+
+func marshalBigInt(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case int:
+		return encBigInt(int64(v)), nil
+	case uint:
+		if v > math.MaxInt64 {
+			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
+		}
+		return encBigInt(int64(v)), nil
+	case int64:
+		return encBigInt(v), nil
+	case uint64:
+		if v > math.MaxInt64 {
+			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
+		}
+		return encBigInt(int64(v)), nil
+	case int32:
+		return encBigInt(int64(v)), nil
+	case uint32:
+		return encBigInt(int64(v)), nil
+	case int16:
+		return encBigInt(int64(v)), nil
+	case uint16:
+		return encBigInt(int64(v)), nil
+	case int8:
+		return encBigInt(int64(v)), nil
+	case uint8:
+		return encBigInt(int64(v)), nil
+	}
+	rv := reflect.ValueOf(value)
+	switch rv.Type().Kind() {
+	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
+		v := rv.Int()
+		return encBigInt(v), nil
+	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
+		v := rv.Uint()
+		if v > math.MaxInt64 {
+			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
+		}
+		return encBigInt(int64(v)), nil
+	case reflect.Ptr:
+		return marshalBigInt(info, rv.Elem().Interface())
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func encBigInt(x int64) []byte {
+	return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32),
+		byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}
+}
+
+func unmarshalBigInt(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *int:
+		x := decBigInt(data)
+		if ^uint(0) == math.MaxUint32 && (x < math.MinInt32 || x > math.MaxInt32) {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = int(x)
+		return nil
+	case *uint:
+		x := decBigInt(data)
+		if x < 0 || (^uint(0) == math.MaxUint32 && x > math.MaxUint32) {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint(x)
+		return nil
+	case *int64:
+		*v = decBigInt(data)
+		return nil
+	case *uint64:
+		x := decBigInt(data)
+		if x < 0 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint64(x)
+		return nil
+	case *int32:
+		x := decBigInt(data)
+		if x < math.MinInt32 || x > math.MaxInt32 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = int32(x)
+		return nil
+	case *uint32:
+		x := decBigInt(data)
+		if x < 0 || x > math.MaxUint32 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint32(x)
+		return nil
+	case *int16:
+		x := decBigInt(data)
+		if x < math.MinInt16 || x > math.MaxInt16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = int16(x)
+		return nil
+	case *uint16:
+		x := decBigInt(data)
+		if x < 0 || x > math.MaxUint16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint16(x)
+		return nil
+	case *int8:
+		x := decBigInt(data)
+		if x < math.MinInt8 || x > math.MaxInt8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = int8(x)
+		return nil
+	case *uint8:
+		x := decBigInt(data)
+		if x < 0 || x > math.MaxUint8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		*v = uint8(x)
+		return nil
+	}
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+	rv = rv.Elem()
+	switch rv.Type().Kind() {
+	case reflect.Int:
+		x := decBigInt(data)
+		if ^uint(0) == math.MaxUint32 && (x < math.MinInt32 || x > math.MaxInt32) {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetInt(x)
+		return nil
+	case reflect.Int64:
+		rv.SetInt(decBigInt(data))
+		return nil
+	case reflect.Int32:
+		x := decBigInt(data)
+		if x < math.MinInt32 || x > math.MaxInt32 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetInt(x)
+		return nil
+	case reflect.Int16:
+		x := decBigInt(data)
+		if x < math.MinInt16 || x > math.MaxInt16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetInt(x)
+		return nil
+	case reflect.Int8:
+		x := decBigInt(data)
+		if x < math.MinInt8 || x > math.MaxInt8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetInt(x)
+		return nil
+	case reflect.Uint:
+		x := decBigInt(data)
+		if x < 0 || (^uint(0) == math.MaxUint32 && x > math.MaxUint32) {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	case reflect.Uint64:
+		x := decBigInt(data)
+		if x < 0 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	case reflect.Uint32:
+		x := decBigInt(data)
+		if x < 0 || x > math.MaxUint32 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	case reflect.Uint16:
+		x := decBigInt(data)
+		if x < 0 || x > math.MaxUint16 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	case reflect.Uint8:
+		x := decBigInt(data)
+		if x < 0 || x > math.MaxUint8 {
+			return unmarshalErrorf("unmarshal int: value %d out of range", x)
+		}
+		rv.SetUint(uint64(x))
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func decBigInt(data []byte) int64 {
+	if len(data) != 8 {
+		return 0
+	}
+	return int64(data[0])<<56 | int64(data[1])<<48 |
+		int64(data[2])<<40 | int64(data[3])<<32 |
+		int64(data[4])<<24 | int64(data[5])<<16 |
+		int64(data[6])<<8 | int64(data[7])
+}
+
+func marshalBool(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case bool:
+		return encBool(v), nil
+	}
+	rv := reflect.ValueOf(value)
+	switch rv.Type().Kind() {
+	case reflect.Bool:
+		return encBool(rv.Bool()), nil
+	case reflect.Ptr:
+		return marshalBool(info, rv.Elem().Interface())
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func encBool(v bool) []byte {
+	if v {
+		return []byte{1}
+	}
+	return []byte{0}
+}
+
+func unmarshalBool(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *bool:
+		*v = decBool(data)
+		return nil
+	}
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+	rv = rv.Elem()
+	switch rv.Type().Kind() {
+	case reflect.Bool:
+		rv.SetBool(decBool(data))
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func decBool(v []byte) bool {
+	if len(v) == 0 {
+		return false
+	}
+	return v[0] != 0
+}
+
+func marshalFloat(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case float32:
+		return encInt(int32(math.Float32bits(v))), nil
+	}
+	rv := reflect.ValueOf(value)
+	switch rv.Type().Kind() {
+	case reflect.Float32:
+		return encInt(int32(math.Float32bits(float32(rv.Float())))), nil
+	case reflect.Ptr:
+		return marshalFloat(info, rv.Elem().Interface())
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func unmarshalFloat(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *float32:
+		*v = math.Float32frombits(uint32(decInt(data)))
+		return nil
+	}
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+	rv = rv.Elem()
+	switch rv.Type().Kind() {
+	case reflect.Float32:
+		rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data)))))
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func marshalDouble(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case float64:
+		return encBigInt(int64(math.Float64bits(v))), nil
+	}
+	rv := reflect.ValueOf(value)
+	switch rv.Type().Kind() {
+	case reflect.Float64:
+		return encBigInt(int64(math.Float64bits(rv.Float()))), nil
+	case reflect.Ptr:
+		return marshalFloat(info, rv.Elem().Interface())
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func unmarshalDouble(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *float64:
+		*v = math.Float64frombits(uint64(decBigInt(data)))
+		return nil
+	}
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+	rv = rv.Elem()
+	switch rv.Type().Kind() {
+	case reflect.Float64:
+		rv.SetFloat(math.Float64frombits(uint64(decBigInt(data))))
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
+func marshalTimestamp(info *TypeInfo, value interface{}) ([]byte, error) {
+	switch v := value.(type) {
+	case Marshaler:
+		return v.MarshalCQL(info)
+	case int64:
+		return encBigInt(v), nil
+	case time.Time:
+		x := v.In(time.UTC).UnixNano() / int64(time.Millisecond)
+		return encBigInt(x), nil
+	}
+	rv := reflect.ValueOf(value)
+	switch rv.Type().Kind() {
+	case reflect.Int64:
+		return encBigInt(rv.Int()), nil
+	case reflect.Ptr:
+		return marshalTimestamp(info, rv.Elem().Interface())
+	}
+	return nil, marshalErrorf("can not marshal %T into %s", value, info)
+}
+
+func unmarshalTimestamp(info *TypeInfo, data []byte, value interface{}) error {
+	switch v := value.(type) {
+	case Unmarshaler:
+		return v.UnmarshalCQL(info, data)
+	case *int64:
+		*v = decBigInt(data)
+		return nil
+	case *time.Time:
+		x := decBigInt(data)
+		sec := x / 1000
+		nsec := (x - sec*1000) * 1000000
+		*v = time.Unix(sec, nsec).In(time.UTC)
+		return nil
+	}
+	rv := reflect.ValueOf(value)
+	if rv.Kind() != reflect.Ptr {
+		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+	}
+	rv = rv.Elem()
+	switch rv.Type().Kind() {
+	case reflect.Int64:
+		rv.SetInt(decBigInt(data))
+		return nil
+	}
+	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
+}
+
 // TypeInfo describes a Cassandra specific data type.
 type TypeInfo struct {
 	Type   Type
@@ -178,6 +769,8 @@ func (t Type) String() string {
 		return "counter"
 	case TypeDecimal:
 		return "decimal"
+	case TypeDouble:
+		return "double"
 	case TypeFloat:
 		return "float"
 	case TypeInt:
@@ -202,3 +795,23 @@ func (t Type) String() string {
 		return "unknown"
 	}
 }
+
+type MarshalError string
+
+func (m MarshalError) Error() string {
+	return string(m)
+}
+
+func marshalErrorf(format string, args ...interface{}) MarshalError {
+	return MarshalError(fmt.Sprintf(format, args...))
+}
+
+type UnmarshalError string
+
+func (m UnmarshalError) Error() string {
+	return string(m)
+}
+
+func unmarshalErrorf(format string, args ...interface{}) UnmarshalError {
+	return UnmarshalError(fmt.Sprintf(format, args...))
+}

+ 163 - 0
marshal_test.go

@@ -0,0 +1,163 @@
+package gocql
+
+import (
+	"bytes"
+	"math"
+	"reflect"
+	"strings"
+	"testing"
+	"time"
+)
+
+var marshalTests = []struct {
+	Info  *TypeInfo
+	Data  []byte
+	Value interface{}
+}{
+	{
+		&TypeInfo{Type: TypeVarchar},
+		[]byte("hello world"),
+		[]byte("hello world"),
+	},
+	{
+		&TypeInfo{Type: TypeVarchar},
+		[]byte("hello world"),
+		"hello world",
+	},
+	{
+		&TypeInfo{Type: TypeVarchar},
+		[]byte(nil),
+		[]byte(nil),
+	},
+	{
+		&TypeInfo{Type: TypeVarchar},
+		[]byte("hello world"),
+		MyString("hello world"),
+	},
+	{
+		&TypeInfo{Type: TypeVarchar},
+		[]byte("HELLO WORLD"),
+		CustomString("hello world"),
+	},
+	{
+		&TypeInfo{Type: TypeBlob},
+		[]byte("hello\x00"),
+		[]byte("hello\x00"),
+	},
+	{
+		&TypeInfo{Type: TypeBlob},
+		[]byte(nil),
+		[]byte(nil),
+	},
+	{
+		&TypeInfo{Type: TypeInt},
+		[]byte("\x00\x00\x00\x00"),
+		0,
+	},
+	{
+		&TypeInfo{Type: TypeInt},
+		[]byte("\x01\x02\x03\x04"),
+		16909060,
+	},
+	{
+		&TypeInfo{Type: TypeInt},
+		[]byte("\x80\x00\x00\x00"),
+		int32(math.MinInt32),
+	},
+	{
+		&TypeInfo{Type: TypeInt},
+		[]byte("\x7f\xff\xff\xff"),
+		int32(math.MaxInt32),
+	},
+	{
+		&TypeInfo{Type: TypeBigInt},
+		[]byte("\x00\x00\x00\x00\x00\x00\x00\x00"),
+		0,
+	},
+	{
+		&TypeInfo{Type: TypeBigInt},
+		[]byte("\x01\x02\x03\x04\x05\x06\x07\x08"),
+		72623859790382856,
+	},
+	{
+		&TypeInfo{Type: TypeBigInt},
+		[]byte("\x80\x00\x00\x00\x00\x00\x00\x00"),
+		int64(math.MinInt64),
+	},
+	{
+		&TypeInfo{Type: TypeBigInt},
+		[]byte("\x7f\xff\xff\xff\xff\xff\xff\xff"),
+		int64(math.MaxInt64),
+	},
+	{
+		&TypeInfo{Type: TypeBoolean},
+		[]byte("\x00"),
+		false,
+	},
+	{
+		&TypeInfo{Type: TypeBoolean},
+		[]byte("\x01"),
+		true,
+	},
+	{
+		&TypeInfo{Type: TypeFloat},
+		[]byte("\x40\x49\x0f\xdb"),
+		float32(3.14159265),
+	},
+	{
+		&TypeInfo{Type: TypeDouble},
+		[]byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"),
+		float64(3.14159265),
+	},
+	{
+		&TypeInfo{Type: TypeTimestamp},
+		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
+		time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC),
+	},
+	{
+		&TypeInfo{Type: TypeTimestamp},
+		[]byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"),
+		int64(1376387523000),
+	},
+}
+
+func TestMarshal(t *testing.T) {
+	for i, test := range marshalTests {
+		data, err := Marshal(test.Info, test.Value)
+		if err != nil {
+			t.Errorf("marshalTest[%d]: %v", i, err)
+			continue
+		}
+		if !bytes.Equal(data, test.Data) {
+			t.Errorf("marshalTest[%d]: expected %q, got %q.", i, test.Data, data)
+		}
+	}
+}
+
+func TestUnmarshal(t *testing.T) {
+	for i, test := range marshalTests {
+		v := reflect.New(reflect.TypeOf(test.Value))
+		err := Unmarshal(test.Info, test.Data, v.Interface())
+		if err != nil {
+			t.Errorf("marshalTest[%d]: %v", i, err)
+			continue
+		}
+		if !reflect.DeepEqual(v.Elem().Interface(), test.Value) {
+			t.Errorf("marshalTest[%d]: expected %#v, got %#v.", i, test.Value, v.Elem().Interface())
+		}
+	}
+}
+
+type CustomString string
+
+func (c CustomString) MarshalCQL(info *TypeInfo) ([]byte, error) {
+	return []byte(strings.ToUpper(string(c))), nil
+}
+func (c *CustomString) UnmarshalCQL(info *TypeInfo, data []byte) error {
+	*c = CustomString(strings.ToLower(string(data)))
+	return nil
+}
+
+type MyString string
+
+type MyInt int