Pārlūkot izejas kodu

Fixes to timestamp support.

Factor the timestamp parsing into its own function, which
can easily be changed later to support the full range of
timestamp formats, and add a quick early check to avoid
going through all the time.Parse calls in this common
path.

Also fix marshaling so that times are still printed unquoted.

For backwards compatibility, when timestamps are unmarshaled
into interface{} values, we unmarshal them as strings.

We also make it possible to unmarshal into any value that implements
TextUnmarshaler, regardless of the node's explicit or implicit tag.
Roger Peppe 8 gadi atpakaļ
vecāks
revīzija
1f2a25ba94
5 mainītis faili ar 222 papildinājumiem un 100 dzēšanām
  1. 53 30
      decode.go
  2. 72 7
      decode_test.go
  3. 46 14
      encode.go
  4. 8 3
      encode_test.go
  5. 43 46
      resolve.go

+ 53 - 30
decode.go

@@ -232,6 +232,7 @@ var (
 	durationType   = reflect.TypeOf(time.Duration(0))
 	defaultMapType = reflect.TypeOf(map[interface{}]interface{}{})
 	ifaceType      = defaultMapType.Elem()
+	timeType       = reflect.TypeOf(time.Time{})
 )
 
 func newDecoder(strict bool) *decoder {
@@ -360,7 +361,7 @@ func resetMap(out reflect.Value) {
 	}
 }
 
-func (d *decoder) scalar(n *node, out reflect.Value) (good bool) {
+func (d *decoder) scalar(n *node, out reflect.Value) bool {
 	var tag string
 	var resolved interface{}
 	if n.tag == "" && !n.implicit {
@@ -384,9 +385,26 @@ func (d *decoder) scalar(n *node, out reflect.Value) (good bool) {
 		}
 		return true
 	}
-	if s, ok := resolved.(string); ok && out.CanAddr() {
-		if u, ok := out.Addr().Interface().(encoding.TextUnmarshaler); ok {
-			err := u.UnmarshalText([]byte(s))
+	if resolvedv := reflect.ValueOf(resolved); out.Type() == resolvedv.Type() {
+		// We've resolved to exactly the type we want, so use that.
+		out.Set(resolvedv)
+		return true
+	}
+	// Perhaps we can use the value as a TextUnmarshaler to
+	// set its value.
+	if out.CanAddr() {
+		u, ok := out.Addr().Interface().(encoding.TextUnmarshaler)
+		if ok {
+			var text []byte
+			if tag == yaml_BINARY_TAG {
+				text = []byte(resolved.(string))
+			} else {
+				// We let any value be unmarshaled into TextUnmarshaler.
+				// That might be more lax than we'd like, but the
+				// TextUnmarshaler itself should reject any dubious values.
+				text = []byte(n.value)
+			}
+			err := u.UnmarshalText(text)
 			if err != nil {
 				fail(err)
 			}
@@ -397,46 +415,53 @@ func (d *decoder) scalar(n *node, out reflect.Value) (good bool) {
 	case reflect.String:
 		if tag == yaml_BINARY_TAG {
 			out.SetString(resolved.(string))
-			good = true
-		} else if resolved != nil {
+			return true
+		}
+		if resolved != nil {
 			out.SetString(n.value)
-			good = true
+			return true
 		}
 	case reflect.Interface:
 		if resolved == nil {
 			out.Set(reflect.Zero(out.Type()))
+		} else if tag == yaml_TIMESTAMP_TAG {
+			// It looks like a timestamp but for backward compatibility
+			// reasons we set it as a string, so that code that unmarshals
+			// timestamp-like values into interface{} will continue to
+			// see a string and not a time.Time.
+			out.Set(reflect.ValueOf(n.value))
 		} else {
 			out.Set(reflect.ValueOf(resolved))
 		}
-		good = true
+		return true
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 		switch resolved := resolved.(type) {
 		case int:
 			if !out.OverflowInt(int64(resolved)) {
 				out.SetInt(int64(resolved))
-				good = true
+				return true
 			}
 		case int64:
 			if !out.OverflowInt(resolved) {
 				out.SetInt(resolved)
-				good = true
+				return true
 			}
 		case uint64:
 			if resolved <= math.MaxInt64 && !out.OverflowInt(int64(resolved)) {
 				out.SetInt(int64(resolved))
-				good = true
+				return true
 			}
 		case float64:
 			if resolved <= math.MaxInt64 && !out.OverflowInt(int64(resolved)) {
 				out.SetInt(int64(resolved))
-				good = true
+				return true
 			}
 		case string:
 			if out.Type() == durationType {
 				d, err := time.ParseDuration(resolved)
 				if err == nil {
 					out.SetInt(int64(d))
-					good = true
+					return true
 				}
 			}
 		}
@@ -445,49 +470,49 @@ func (d *decoder) scalar(n *node, out reflect.Value) (good bool) {
 		case int:
 			if resolved >= 0 && !out.OverflowUint(uint64(resolved)) {
 				out.SetUint(uint64(resolved))
-				good = true
+				return true
 			}
 		case int64:
 			if resolved >= 0 && !out.OverflowUint(uint64(resolved)) {
 				out.SetUint(uint64(resolved))
-				good = true
+				return true
 			}
 		case uint64:
 			if !out.OverflowUint(uint64(resolved)) {
 				out.SetUint(uint64(resolved))
-				good = true
+				return true
 			}
 		case float64:
 			if resolved <= math.MaxUint64 && !out.OverflowUint(uint64(resolved)) {
 				out.SetUint(uint64(resolved))
-				good = true
+				return true
 			}
 		}
 	case reflect.Bool:
 		switch resolved := resolved.(type) {
 		case bool:
 			out.SetBool(resolved)
-			good = true
+			return true
 		}
 	case reflect.Float32, reflect.Float64:
 		switch resolved := resolved.(type) {
 		case int:
 			out.SetFloat(float64(resolved))
-			good = true
+			return true
 		case int64:
 			out.SetFloat(float64(resolved))
-			good = true
+			return true
 		case uint64:
 			out.SetFloat(float64(resolved))
-			good = true
+			return true
 		case float64:
 			out.SetFloat(resolved)
-			good = true
+			return true
 		}
 	case reflect.Struct:
-		if out.Type() == reflect.TypeOf(resolved) {
-			out.Set(reflect.ValueOf(resolved))
-			good = true
+		if resolvedv := reflect.ValueOf(resolved); out.Type() == resolvedv.Type() {
+			out.Set(resolvedv)
+			return true
 		}
 	case reflect.Ptr:
 		if out.Type().Elem() == reflect.TypeOf(resolved) {
@@ -495,13 +520,11 @@ func (d *decoder) scalar(n *node, out reflect.Value) (good bool) {
 			elem := reflect.New(out.Type().Elem())
 			elem.Elem().Set(reflect.ValueOf(resolved))
 			out.Set(elem)
-			good = true
+			return true
 		}
 	}
-	if !good {
-		d.terror(n, tag, out)
-	}
-	return good
+	d.terror(n, tag, out)
+	return false
 }
 
 func settableValueOf(i interface{}) reflect.Value {

+ 72 - 7
decode_test.go

@@ -4,7 +4,6 @@ import (
 	"errors"
 	"io"
 	"math"
-	"net"
 	"reflect"
 	"strings"
 	"time"
@@ -576,24 +575,81 @@ var unmarshalTests = []struct {
 	// Support encoding.TextUnmarshaler.
 	{
 		"a: 1.2.3.4\n",
-		map[string]net.IP{"a": net.IPv4(1, 2, 3, 4)},
+		map[string]textUnmarshaler{"a": textUnmarshaler{S: "1.2.3.4"}},
 	},
 	{
 		"a: 2015-02-24T18:19:39Z\n",
-		map[string]time.Time{"a": time.Unix(1424801979, 0).In(time.UTC)},
+		map[string]textUnmarshaler{"a": textUnmarshaler{"2015-02-24T18:19:39Z"}},
 	},
+
+	// Timestamps
 	{
-		"a: 2015-01-01",
-		map[string]time.Time{"a": time.Unix(1420070400, 0)},
+		// Date only.
+		"a: 2015-01-01\n",
+		map[string]time.Time{"a": time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)},
+	},
+	{
+		// RFC3339
+		"a: 2015-02-24T18:19:39.12Z\n",
+		map[string]time.Time{"a": time.Date(2015, 2, 24, 18, 19, 39, .12e9, time.UTC)},
+	},
+	{
+		// RFC3339-style with single-digit date and time fields.
+		"a: 2015-2-3T3:4:5Z",
+		map[string]time.Time{"a": time.Date(2015, 2, 3, 3, 4, 5, 0, time.UTC)},
+	},
+	{
+		// ISO8601 lower case t
+		"a: 2015-02-24t18:19:39Z\n",
+		map[string]time.Time{"a": time.Date(2015, 2, 24, 18, 19, 39, 0, time.UTC)},
+	},
+	{
+		// space separated, no time zone
+		"a: 2015-02-24 18:19:39\n",
+		map[string]time.Time{"a": time.Date(2015, 2, 24, 18, 19, 39, 0, time.UTC)},
 	},
+	// Some cases not currently handled. Uncomment these when
+	// the code is fixed.
+	//	{
+	//		// space separated with time zone
+	//		"a: 2001-12-14 21:59:43.10 -5",
+	//		map[string]interface{}{"a": time.Date(2001, 12, 14, 21, 59, 43, .1e9, time.UTC)},
+	//	},
+	//	{
+	//		// arbitrary whitespace between fields
+	//		"a: 2001-12-14 \t\t \t21:59:43.10 \t Z",
+	//		map[string]interface{}{"a": time.Date(2001, 12, 14, 21, 59, 43, .1e9, time.UTC)},
+	//	},
 	{
+		// explicit string tag
 		"a: !!str 2015-01-01",
-		map[string]string{"a": "2015-01-01"},
+		map[string]interface{}{"a": "2015-01-01"},
+	},
+	{
+		// explicit timestamp tag on quoted string
+		"a: !!timestamp \"2015-01-01\"",
+		map[string]time.Time{"a": time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)},
+	},
+	{
+		// explicit timestamp tag on unquoted string
+		"a: !!timestamp 2015-01-01",
+		map[string]time.Time{"a": time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)},
 	},
 	{
+		// quoted string that's a valid timestamp
 		"a: \"2015-01-01\"",
 		map[string]interface{}{"a": "2015-01-01"},
 	},
+	{
+		// explicit timestamp tag into interface.
+		"a: !!timestamp \"2015-01-01\"",
+		map[string]interface{}{"a": "2015-01-01"},
+	},
+	{
+		// implicit timestamp tag into interface.
+		"a: 2015-01-01",
+		map[string]interface{}{"a": "2015-01-01"},
+	},
 
 	// Encode empty lists as zero-length slices.
 	{
@@ -668,7 +724,7 @@ func (s *S) TestUnmarshal(c *C) {
 		if _, ok := err.(*yaml.TypeError); !ok {
 			c.Assert(err, IsNil)
 		}
-		c.Assert(value.Elem().Interface(), DeepEquals, item.value)
+		c.Assert(value.Elem().Interface(), DeepEquals, item.value, Commentf("error: %v", err))
 	}
 }
 
@@ -1177,6 +1233,15 @@ func (s *S) TestUnmarshalStrict(c *C) {
 	}
 }
 
+type textUnmarshaler struct {
+	S string
+}
+
+func (t *textUnmarshaler) UnmarshalText(s []byte) error {
+	t.S = string(s)
+	return nil
+}
+
 //var data []byte
 //func init() {
 //	var err error

+ 46 - 14
encode.go

@@ -10,6 +10,7 @@ import (
 	"strconv"
 	"strings"
 	"time"
+	"unicode/utf8"
 )
 
 type encoder struct {
@@ -87,7 +88,8 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
 		return
 	}
 	iface := in.Interface()
-	if m, ok := iface.(Marshaler); ok {
+	switch m := iface.(type) {
+	case Marshaler:
 		v, err := m.MarshalYAML()
 		if err != nil {
 			fail(err)
@@ -97,7 +99,12 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
 			return
 		}
 		in = reflect.ValueOf(v)
-	} else if m, ok := iface.(encoding.TextMarshaler); ok {
+	case time.Time:
+		// Although time.Time implements TextMarshaler,
+		// we don't want to treat it as a string for YAML
+		// purposes because YAML has special support for
+		// timestamps.
+	case encoding.TextMarshaler:
 		text, err := m.MarshalText()
 		if err != nil {
 			fail(err)
@@ -120,7 +127,11 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
 			e.marshal(tag, in.Elem())
 		}
 	case reflect.Struct:
-		e.structv(tag, in)
+		if in.Type() == timeType {
+			e.timev(tag, in)
+		} else {
+			e.structv(tag, in)
+		}
 	case reflect.Slice:
 		if in.Type().Elem() == mapItemType {
 			e.itemsv(tag, in)
@@ -262,23 +273,36 @@ var base60float = regexp.MustCompile(`^[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+(?:\.[0
 func (e *encoder) stringv(tag string, in reflect.Value) {
 	var style yaml_scalar_style_t
 	s := in.String()
-	rtag, rs := resolve("", s)
-	if rtag == yaml_BINARY_TAG {
-		if tag == "" || tag == yaml_STR_TAG {
-			tag = rtag
-			s = rs.(string)
-		} else if tag == yaml_BINARY_TAG {
+	canUsePlain := true
+	switch {
+	case !utf8.ValidString(s):
+		if tag == yaml_BINARY_TAG {
 			failf("explicitly tagged !!binary data must be base64-encoded")
-		} else {
+		}
+		if tag != "" {
 			failf("cannot marshal invalid UTF-8 data as %s", shortTag(tag))
 		}
+		// It can't be encoded directly as YAML so use a binary tag
+		// and encode it as base64.
+		tag = yaml_BINARY_TAG
+		s = encodeBase64(s)
+	case tag == "":
+		// Check to see if it would resolve to a specific
+		// tag when encoded unquoted. If it doesn't,
+		// there's no need to quote it.
+		rtag, _ := resolve("", s)
+		canUsePlain = rtag == yaml_STR_TAG && !isBase60Float(s)
 	}
-	if tag == "" && (rtag != yaml_STR_TAG || isBase60Float(s)) {
-		style = yaml_DOUBLE_QUOTED_SCALAR_STYLE
-	} else if strings.Contains(s, "\n") {
+	// Note: it's possible for user code to emit invalid YAML
+	// if they explicitly specify a tag and a string containing
+	// text that's incompatible with that tag.
+	switch {
+	case strings.Contains(s, "\n"):
 		style = yaml_LITERAL_SCALAR_STYLE
-	} else {
+	case canUsePlain:
 		style = yaml_PLAIN_SCALAR_STYLE
+	default:
+		style = yaml_DOUBLE_QUOTED_SCALAR_STYLE
 	}
 	e.emitScalar(s, "", tag, style)
 }
@@ -303,6 +327,14 @@ func (e *encoder) uintv(tag string, in reflect.Value) {
 	e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE)
 }
 
+func (e *encoder) timev(tag string, in reflect.Value) {
+	t := in.Interface().(time.Time)
+	if tag == "" {
+		tag = yaml_TIMESTAMP_TAG
+	}
+	e.emitScalar(t.Format(time.RFC3339Nano), "", tag, yaml_PLAIN_SCALAR_STYLE)
+}
+
 func (e *encoder) floatv(tag string, in reflect.Value) {
 	// FIXME: Handle 64 bits here.
 	s := strconv.FormatFloat(float64(in.Float()), 'g', -1, 32)

+ 8 - 3
encode_test.go

@@ -305,8 +305,13 @@ var marshalTests = []struct {
 		"a: 1.2.3.4\n",
 	},
 	{
-		map[string]time.Time{"a": time.Unix(1424801979, 0)},
-		"a: 2015-02-24T18:19:39Z\n",
+		map[string]time.Time{"a": time.Date(2015, 2, 24, 18, 19, 39, 0, time.UTC)},
+		"a: !!timestamp 2015-02-24T18:19:39Z\n",
+	},
+	// Ensure timestamp-like strings are quoted.
+	{
+		map[string]string{"a": "2015-02-24T18:19:39Z"},
+		"a: \"2015-02-24T18:19:39Z\"\n",
 	},
 
 	// Ensure strings containing ": " are quoted (reported as PR #43, but not reproducible).
@@ -330,7 +335,7 @@ func (s *S) TestMarshal(c *C) {
 	defer os.Setenv("TZ", os.Getenv("TZ"))
 	os.Setenv("TZ", "UTC")
 	for i, item := range marshalTests {
-		c.Logf("test %d. %q", i, item.data)
+		c.Logf("test %d: %q", i, item.data)
 		data, err := yaml.Marshal(item.value)
 		c.Assert(err, IsNil)
 		c.Assert(string(data), Equals, item.data)

+ 43 - 46
resolve.go

@@ -7,7 +7,6 @@ import (
 	"strconv"
 	"strings"
 	"time"
-	"unicode/utf8"
 )
 
 type resolveMapItem struct {
@@ -126,34 +125,12 @@ func resolve(tag string, in string) (rtag string, out interface{}) {
 
 		case 'D', 'S':
 			// Int, float, or timestamp.
-
-			// Handle custom timestamp formats as described on http://yaml.org/type/timestamp.html
-			// RFC3339 is handled automatically by the time.Time implementation of the
-			// encoding.TextUnmarshaler interface but we are going to explicitly
-			// handle it here. We should only perform timestamp manipulation if
-			// there is either no quotes on the value or there is an explicit !!timestamp tag.
-
-			if shortTag(tag) == shortTag(yaml_TIMESTAMP_TAG) || tag == "" {
-				var possibleTime time.Time
-				if tryTime(time.RFC3339, in, &possibleTime) {
-					return yaml_TIMESTAMP_TAG, possibleTime
-				}
-
-				// valid iso8601
-				if tryTime("2006-01-02t15:04:05.99-07:00", in, &possibleTime) {
-					return yaml_TIMESTAMP_TAG, possibleTime
-				}
-				// space separated
-				if tryTime("2006-01-02 15:04:05.99 -7", in, &possibleTime) {
-					return yaml_TIMESTAMP_TAG, possibleTime
-				}
-				// no time zone
-				if tryTime("2006-01-02 15:04:05.99", in, &possibleTime) {
-					return yaml_TIMESTAMP_TAG, possibleTime
-				}
-				// date (00:00:00Z)
-				if tryTime("2006-01-02", in, &possibleTime) {
-					return yaml_TIMESTAMP_TAG, possibleTime
+			// Only try values as a timestamp if the value is unquoted or there's an explicit
+			// !!timestamp tag.
+			if tag == "" || tag == yaml_TIMESTAMP_TAG {
+				t, ok := parseTimestamp(in)
+				if ok {
+					return yaml_TIMESTAMP_TAG, t
 				}
 			}
 
@@ -203,23 +180,7 @@ func resolve(tag string, in string) (rtag string, out interface{}) {
 			panic("resolveTable item not yet handled: " + string(rune(hint)) + " (with " + in + ")")
 		}
 	}
-	if tag == yaml_BINARY_TAG {
-		return yaml_BINARY_TAG, in
-	}
-	if utf8.ValidString(in) {
-		return yaml_STR_TAG, in
-	}
-	return yaml_BINARY_TAG, encodeBase64(in)
-}
-
-func tryTime(format, value string, t *time.Time) bool {
-	attempt, err := time.Parse(format, value)
-	if err == nil {
-		*t = attempt
-		return true
-	} else {
-		return false
-	}
+	return yaml_STR_TAG, in
 }
 
 // encodeBase64 encodes s as base64 that is broken up into multiple lines
@@ -246,3 +207,39 @@ func encodeBase64(s string) string {
 	}
 	return string(out[:k])
 }
+
+// This is a subset of the formats allowed by the regular expression
+// defined at http://yaml.org/type/timestamp.html.
+var allowedTimestampFormats = []string{
+	"2006-1-2T15:4:5Z07:00",
+	"2006-1-2t15:4:5Z07:00", // RFC3339 with lower-case "t".
+	"2006-1-2 15:4:5",       // space separated with no time zone
+	"2006-1-2",              // date only
+	// Notable exception: time.Parse cannot handle: "2001-12-14 21:59:43.10 -5"
+	// from the set of examples.
+}
+
+// parseTimestamp parses s as a timestamp string and
+// returns the timestamp and reports whether it succeeded.
+// Timestamp formats are defined at http://yaml.org/type/timestamp.html
+func parseTimestamp(s string) (time.Time, bool) {
+	// TODO write code to check all the formats supported by
+	// http://yaml.org/type/timestamp.html instead of using time.Parse.
+
+	// Quick check: all date formats start with YYYY-.
+	i := 0
+	for ; i < len(s); i++ {
+		if c := s[i]; c < '0' || c > '9' {
+			break
+		}
+	}
+	if i != 4 || i == len(s) || s[i] != '-' {
+		return time.Time{}, false
+	}
+	for _, format := range allowedTimestampFormats {
+		if t, err := time.Parse(format, s); err == nil {
+			return t, true
+		}
+	}
+	return time.Time{}, false
+}