Przeglądaj źródła

Support encoding.TextMarshaler/Unmarshaler.

Fixes #38.
Gustavo Niemeyer 11 lat temu
rodzic
commit
eca94c41d9
4 zmienionych plików z 37 dodań i 7 usunięć
  1. 12 5
      decode.go
  2. 7 0
      decode_test.go
  3. 11 2
      encode.go
  4. 7 0
      encode_test.go

+ 12 - 5
decode.go

@@ -1,6 +1,7 @@
 package yaml
 
 import (
+	"encoding"
 	"encoding/base64"
 	"fmt"
 	"reflect"
@@ -258,10 +259,8 @@ func (d *decoder) prepare(n *node, out reflect.Value) (newout reflect.Value, unm
 		if out.Kind() == reflect.Ptr {
 			if out.IsNil() {
 				out.Set(reflect.New(out.Type().Elem()))
-				out = out.Elem()
-			} else {
-				out = out.Elem()
 			}
+			out = out.Elem()
 			again = true
 		}
 		if out.CanAddr() {
@@ -351,8 +350,16 @@ func (d *decoder) scalar(n *node, out reflect.Value) (good bool) {
 		} else {
 			out.Set(reflect.Zero(out.Type()))
 		}
-		good = true
-		return
+		return true
+	}
+	if s, ok := resolved.(string); ok && out.CanAddr() {
+		if u, ok := out.Addr().Interface().(encoding.TextUnmarshaler); ok {
+			err := u.UnmarshalText([]byte(s))
+			if err != nil {
+				fail(err)
+			}
+			return true
+		}
 	}
 	switch out.Kind() {
 	case reflect.String:

+ 7 - 0
decode_test.go

@@ -5,6 +5,7 @@ import (
 	. "gopkg.in/check.v1"
 	"gopkg.in/yaml.v2"
 	"math"
+	"net"
 	"reflect"
 	"strings"
 	"time"
@@ -418,6 +419,12 @@ var unmarshalTests = []struct {
 		"a: {b: c}",
 		M{"a": M{"b": "c"}},
 	},
+
+	// Support encoding.TextUnmarshaler.
+	{
+		"a: 1.2.3.4\n",
+		map[string]net.IP{"a": net.IPv4(1, 2, 3, 4)},
+	},
 }
 
 type M map[interface{}]interface{}

+ 11 - 2
encode.go

@@ -1,6 +1,7 @@
 package yaml
 
 import (
+	"encoding"
 	"reflect"
 	"regexp"
 	"sort"
@@ -62,7 +63,8 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
 		e.nilv()
 		return
 	}
-	if m, ok := in.Interface().(Marshaler); ok {
+	iface := in.Interface()
+	if m, ok := iface.(Marshaler); ok {
 		v, err := m.MarshalYAML()
 		if err != nil {
 			fail(err)
@@ -73,6 +75,13 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
 		}
 		in = reflect.ValueOf(v)
 	}
+	if m, ok := iface.(encoding.TextMarshaler); ok {
+		text, err := m.MarshalText()
+		if err != nil {
+			fail(err)
+		}
+		in = reflect.ValueOf(string(text))
+	}
 	switch in.Kind() {
 	case reflect.Interface:
 		if in.IsNil() {
@@ -100,7 +109,7 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
 		e.stringv(tag, in)
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 		if in.Type() == durationType {
-			e.stringv(tag, reflect.ValueOf(in.Interface().(time.Duration).String()))
+			e.stringv(tag, reflect.ValueOf(iface.(time.Duration).String()))
 		} else {
 			e.intv(tag, in)
 		}

+ 7 - 0
encode_test.go

@@ -9,6 +9,7 @@ import (
 
 	. "gopkg.in/check.v1"
 	"gopkg.in/yaml.v2"
+	"net"
 )
 
 var marshalIntTest = 123
@@ -260,6 +261,12 @@ var marshalTests = []struct {
 		map[string]string{"a": "你好"},
 		"a: 你好\n",
 	},
+
+	// Support encoding.TextMarshaler.
+	{
+		map[string]net.IP{"a": net.IPv4(1, 2, 3, 4)},
+		"a: 1.2.3.4\n",
+	},
 }
 
 func (s *S) TestMarshal(c *C) {