浏览代码

#71 sort non string map keys

Tao Wen 8 年之前
父节点
当前提交
0c07128d3c
共有 1 个文件被更改,包括 40 次插入7 次删除
  1. 40 7
      feature_reflect_map.go

+ 40 - 7
feature_reflect_map.go

@@ -156,17 +156,25 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 	realVal := reflect.ValueOf(*realInterface)
 
 	// Extract and sort the keys.
-	var sv stringValues = realVal.MapKeys()
-	sort.Sort(sv)
+	keys := realVal.MapKeys()
+	sv := make([]reflectWithString, len(keys))
+	for i, v := range keys {
+		sv[i].v = v
+		if err := sv[i].resolve(); err != nil {
+			stream.Error = err
+			return
+		}
+	}
+	sort.Slice(sv, func(i, j int) bool { return sv[i].s < sv[j].s })
 
 	stream.WriteObjectStart()
 	for i, key := range sv {
 		if i != 0 {
 			stream.WriteMore()
 		}
-		encodeMapKey(key, stream)
+		stream.WriteString(key.s)
 		stream.writeByte(':')
-		val := realVal.MapIndex(key).Interface()
+		val := realVal.MapIndex(key.v).Interface()
 		encoder.elemEncoder.EncodeInterface(val, stream)
 	}
 	stream.WriteObjectEnd()
@@ -174,12 +182,37 @@ func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 
 // stringValues is a slice of reflect.Value holding *reflect.StringValue.
 // It implements the methods to sort by string.
-type stringValues []reflect.Value
+type stringValues []reflectWithString
+
+type reflectWithString struct {
+	v reflect.Value
+	s string
+}
+
+func (w *reflectWithString) resolve() error {
+	if w.v.Kind() == reflect.String {
+		w.s = w.v.String()
+		return nil
+	}
+	if tm, ok := w.v.Interface().(encoding.TextMarshaler); ok {
+		buf, err := tm.MarshalText()
+		w.s = string(buf)
+		return err
+	}
+	switch w.v.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		w.s = strconv.FormatInt(w.v.Int(), 10)
+		return nil
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+		w.s = strconv.FormatUint(w.v.Uint(), 10)
+		return nil
+	}
+	return &json.UnsupportedTypeError{w.v.Type()}
+}
 
 func (sv stringValues) Len() int           { return len(sv) }
 func (sv stringValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
-func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
-func (sv stringValues) get(i int) string   { return sv[i].String() }
+func (sv stringValues) Less(i, j int) bool { return sv[i].s < sv[j].s }
 
 func (encoder *sortKeysMapEncoder) EncodeInterface(val interface{}, stream *Stream) {
 	WriteToStream(val, stream, encoder)