encode.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. package toml
  2. // TODO: Build a decent encoder.
  3. // Interestingly, this isn't as trivial as recursing down the type of the
  4. // value given and outputting the corresponding TOML. In particular, multiple
  5. // TOML types (especially if tuples are added) can map to a single Go type, so
  6. // that the reverse correspondence isn't clear.
  7. //
  8. // One possible avenue is to choose a reasonable default (like structs map
  9. // to hashes), but allow the user to override with struct tags. But this seems
  10. // like a mess.
  11. //
  12. // The other possibility is to scrap an encoder altogether. After all, TOML
  13. // is a configuration file format, and not a data exchange format.
  14. import (
  15. "bufio"
  16. "encoding"
  17. "errors"
  18. "fmt"
  19. "io"
  20. "reflect"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. )
  25. var (
  26. ErrArrayMixedElementTypes = errors.New(
  27. "can't encode array with mixed element types")
  28. ErrArrayNilElement = errors.New(
  29. "can't encode array with nil element")
  30. )
  31. type Encoder struct {
  32. // A single indentation level. By default it is two spaces.
  33. Indent string
  34. w *bufio.Writer
  35. // hasWritten is whether we have written any output to w yet.
  36. hasWritten bool
  37. }
  38. func NewEncoder(w io.Writer) *Encoder {
  39. return &Encoder{
  40. w: bufio.NewWriter(w),
  41. Indent: " ",
  42. }
  43. }
  44. func (enc *Encoder) Encode(v interface{}) error {
  45. rv := eindirect(reflect.ValueOf(v))
  46. if err := enc.encode(Key([]string{}), rv); err != nil {
  47. return err
  48. }
  49. return enc.w.Flush()
  50. }
  51. func (enc *Encoder) encode(key Key, rv reflect.Value) error {
  52. // Special case. If we can marshal the type to text, then we used that.
  53. if _, ok := rv.Interface().(encoding.TextMarshaler); ok {
  54. err := enc.eKeyEq(key)
  55. if err != nil {
  56. return err
  57. }
  58. return enc.eElement(rv)
  59. }
  60. k := rv.Kind()
  61. switch k {
  62. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
  63. reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
  64. reflect.Uint64,
  65. reflect.Float32, reflect.Float64,
  66. reflect.String, reflect.Bool:
  67. err := enc.eKeyEq(key)
  68. if err != nil {
  69. return err
  70. }
  71. return enc.eElement(rv)
  72. case reflect.Array, reflect.Slice:
  73. return enc.eArrayOrSlice(key, rv)
  74. case reflect.Interface:
  75. if rv.IsNil() {
  76. return nil
  77. }
  78. return enc.encode(key, rv.Elem())
  79. case reflect.Map:
  80. if rv.IsNil() {
  81. return nil
  82. }
  83. return enc.eTable(key, rv)
  84. case reflect.Ptr:
  85. if rv.IsNil() {
  86. return nil
  87. }
  88. return enc.encode(key, rv.Elem())
  89. case reflect.Struct:
  90. return enc.eTable(key, rv)
  91. }
  92. return e("Unsupported type for key '%s': %s", key, k)
  93. }
  94. // eElement encodes any value that can be an array element (primitives and
  95. // arrays).
  96. func (enc *Encoder) eElement(rv reflect.Value) error {
  97. ws := func(s string) error {
  98. _, err := io.WriteString(enc.w, s)
  99. return err
  100. }
  101. // By the TOML spec, all floats must have a decimal with at least one
  102. // number on either side.
  103. floatAddDecimal := func(fstr string) string {
  104. if !strings.Contains(fstr, ".") {
  105. return fstr + ".0"
  106. }
  107. return fstr
  108. }
  109. // Special case. Use text marshaler if it's available for this value.
  110. if v, ok := rv.Interface().(encoding.TextMarshaler); ok {
  111. s, err := v.MarshalText()
  112. if err != nil {
  113. return err
  114. }
  115. return ws(string(s))
  116. }
  117. var err error
  118. k := rv.Kind()
  119. switch k {
  120. case reflect.Bool:
  121. err = ws(strconv.FormatBool(rv.Bool()))
  122. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  123. err = ws(strconv.FormatInt(rv.Int(), 10))
  124. case reflect.Uint, reflect.Uint8, reflect.Uint16,
  125. reflect.Uint32, reflect.Uint64:
  126. err = ws(strconv.FormatUint(rv.Uint(), 10))
  127. case reflect.Float32:
  128. err = ws(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
  129. case reflect.Float64:
  130. err = ws(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
  131. case reflect.Array, reflect.Slice:
  132. return enc.eArrayOrSliceElement(rv)
  133. case reflect.Interface:
  134. return enc.eElement(rv.Elem())
  135. case reflect.String:
  136. s := rv.String()
  137. s = strings.NewReplacer(
  138. "\t", "\\t",
  139. "\n", "\\n",
  140. "\r", "\\r",
  141. "\"", "\\\"",
  142. "\\", "\\\\",
  143. ).Replace(s)
  144. err = ws("\"" + s + "\"")
  145. default:
  146. return e("Unexpected primitive type: %s", k)
  147. }
  148. return err
  149. }
  150. func (enc *Encoder) eArrayOrSlice(key Key, rv reflect.Value) error {
  151. // Determine whether this is an array of tables or of primitives.
  152. elemV := reflect.ValueOf(nil)
  153. if rv.Len() > 0 {
  154. elemV = rv.Index(0)
  155. }
  156. isTableType, err := isTOMLTableType(rv.Type().Elem(), elemV)
  157. if err != nil {
  158. return err
  159. }
  160. if len(key) > 0 && isTableType {
  161. return enc.eArrayOfTables(key, rv)
  162. }
  163. err = enc.eKeyEq(key)
  164. if err != nil {
  165. return err
  166. }
  167. return enc.eArrayOrSliceElement(rv)
  168. }
  169. func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) error {
  170. if _, err := enc.w.Write([]byte{'['}); err != nil {
  171. return err
  172. }
  173. length := rv.Len()
  174. if length > 0 {
  175. arrayElemType, isNil := tomlTypeName(rv.Index(0))
  176. if isNil {
  177. return ErrArrayNilElement
  178. }
  179. for i := 0; i < length; i++ {
  180. elem := rv.Index(i)
  181. // Ensure that the array's elements each have the same TOML type.
  182. elemType, isNil := tomlTypeName(elem)
  183. if isNil {
  184. return ErrArrayNilElement
  185. }
  186. if elemType != arrayElemType {
  187. return ErrArrayMixedElementTypes
  188. }
  189. if err := enc.eElement(elem); err != nil {
  190. return err
  191. }
  192. if i != length-1 {
  193. if _, err := enc.w.Write([]byte(", ")); err != nil {
  194. return err
  195. }
  196. }
  197. }
  198. }
  199. if _, err := enc.w.Write([]byte{']'}); err != nil {
  200. return err
  201. }
  202. return nil
  203. }
  204. func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) error {
  205. if enc.hasWritten {
  206. _, err := enc.w.Write([]byte{'\n'})
  207. if err != nil {
  208. return err
  209. }
  210. }
  211. for i := 0; i < rv.Len(); i++ {
  212. trv := rv.Index(i)
  213. if isNil(trv) {
  214. continue
  215. }
  216. _, err := fmt.Fprintf(enc.w, "%s[[%s]]\n",
  217. strings.Repeat(enc.Indent, len(key)-1), key.String())
  218. if err != nil {
  219. return err
  220. }
  221. err = enc.eMapOrStruct(key, trv)
  222. if err != nil {
  223. return err
  224. }
  225. if i != rv.Len()-1 {
  226. if _, err := enc.w.Write([]byte("\n\n")); err != nil {
  227. return err
  228. }
  229. }
  230. enc.hasWritten = true
  231. }
  232. return nil
  233. }
  234. func isStructOrMap(rv reflect.Value) bool {
  235. switch rv.Kind() {
  236. case reflect.Interface, reflect.Ptr:
  237. return isStructOrMap(rv.Elem())
  238. case reflect.Map, reflect.Struct:
  239. return true
  240. default:
  241. return false
  242. }
  243. }
  244. func (enc *Encoder) eTable(key Key, rv reflect.Value) error {
  245. if enc.hasWritten {
  246. _, err := enc.w.Write([]byte{'\n'})
  247. if err != nil {
  248. return err
  249. }
  250. }
  251. if len(key) > 0 {
  252. _, err := fmt.Fprintf(enc.w, "%s[%s]\n",
  253. strings.Repeat(enc.Indent, len(key)-1), key.String())
  254. if err != nil {
  255. return err
  256. }
  257. }
  258. return enc.eMapOrStruct(key, rv)
  259. }
  260. func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) error {
  261. switch rv.Kind() {
  262. case reflect.Map:
  263. return enc.eMap(key, rv)
  264. case reflect.Struct:
  265. return enc.eStruct(key, rv)
  266. case reflect.Ptr, reflect.Interface:
  267. return enc.eMapOrStruct(key, rv.Elem())
  268. default:
  269. panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
  270. }
  271. }
  272. func (enc *Encoder) eMap(key Key, rv reflect.Value) error {
  273. rt := rv.Type()
  274. if rt.Key().Kind() != reflect.String {
  275. return errors.New("can't encode a map with non-string key type")
  276. }
  277. // Sort keys so that we have deterministic output. And write keys directly
  278. // underneath this key first, before writing sub-structs or sub-maps.
  279. var mapKeysDirect, mapKeysSub []string
  280. for _, mapKey := range rv.MapKeys() {
  281. k := mapKey.String()
  282. mrv := rv.MapIndex(mapKey)
  283. if isStructOrMap(mrv) {
  284. mapKeysSub = append(mapKeysSub, k)
  285. } else {
  286. mapKeysDirect = append(mapKeysDirect, k)
  287. }
  288. }
  289. var writeMapKeys = func(mapKeys []string) error {
  290. sort.Strings(mapKeys)
  291. for i, mapKey := range mapKeys {
  292. mrv := rv.MapIndex(reflect.ValueOf(mapKey))
  293. if isNil(mrv) {
  294. // Don't write anything for nil fields.
  295. continue
  296. }
  297. if err := enc.encode(key.add(mapKey), mrv); err != nil {
  298. return err
  299. }
  300. if i != len(mapKeys)-1 {
  301. if _, err := enc.w.Write([]byte{'\n'}); err != nil {
  302. return err
  303. }
  304. }
  305. enc.hasWritten = true
  306. }
  307. return nil
  308. }
  309. err := writeMapKeys(mapKeysDirect)
  310. if err != nil {
  311. return err
  312. }
  313. err = writeMapKeys(mapKeysSub)
  314. if err != nil {
  315. return err
  316. }
  317. return nil
  318. }
  319. func (enc *Encoder) eStruct(key Key, rv reflect.Value) error {
  320. // Write keys for fields directly under this key first, because if we write
  321. // a field that creates a new table, then all keys under it will be in that
  322. // table (not the one we're writing here).
  323. rt := rv.Type()
  324. var fieldsDirect, fieldsSub [][]int
  325. var addFields func(rt reflect.Type, rv reflect.Value, start []int)
  326. addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
  327. for i := 0; i < rt.NumField(); i++ {
  328. f := rt.Field(i)
  329. frv := rv.Field(i)
  330. if f.Anonymous {
  331. t := frv.Type()
  332. if t.Kind() == reflect.Ptr {
  333. t = t.Elem()
  334. frv = frv.Elem()
  335. }
  336. addFields(t, frv, f.Index)
  337. } else if isStructOrMap(frv) {
  338. fieldsSub = append(fieldsSub, append(start, f.Index...))
  339. } else {
  340. fieldsDirect = append(fieldsDirect, append(start, f.Index...))
  341. }
  342. }
  343. }
  344. addFields(rt, rv, nil)
  345. var writeFields = func(fields [][]int) error {
  346. for i, fieldIndex := range fields {
  347. sft := rt.FieldByIndex(fieldIndex)
  348. sf := rv.FieldByIndex(fieldIndex)
  349. if isNil(sf) {
  350. // Don't write anything for nil fields.
  351. continue
  352. }
  353. keyName := sft.Tag.Get("toml")
  354. if keyName == "-" {
  355. continue
  356. }
  357. if keyName == "" {
  358. keyName = sft.Name
  359. }
  360. if err := enc.encode(key.add(keyName), sf); err != nil {
  361. return err
  362. }
  363. if i != len(fields)-1 {
  364. if _, err := enc.w.Write([]byte{'\n'}); err != nil {
  365. return err
  366. }
  367. }
  368. enc.hasWritten = true
  369. }
  370. return nil
  371. }
  372. err := writeFields(fieldsDirect)
  373. if err != nil {
  374. return err
  375. }
  376. if len(fieldsDirect) > 0 && len(fieldsSub) > 0 {
  377. _, err = enc.w.Write([]byte{'\n'})
  378. if err != nil {
  379. return err
  380. }
  381. }
  382. err = writeFields(fieldsSub)
  383. if err != nil {
  384. return err
  385. }
  386. return nil
  387. }
  388. // tomlTypeName returns the TOML type name of the Go value's type. It is used to
  389. // determine whether the types of array elements are mixed (which is forbidden).
  390. // If the Go value is nil, then it is illegal for it to be an array element, and
  391. // valueIsNil is returned as true.
  392. func tomlTypeName(rv reflect.Value) (typeName string, valueIsNil bool) {
  393. if isNil(rv) {
  394. return "", true
  395. }
  396. k := rv.Kind()
  397. switch k {
  398. case reflect.Bool:
  399. return "bool", false
  400. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
  401. reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
  402. reflect.Uint64:
  403. return "integer", false
  404. case reflect.Float32, reflect.Float64:
  405. return "float", false
  406. case reflect.Array, reflect.Slice:
  407. return "array", false
  408. case reflect.Ptr, reflect.Interface:
  409. return tomlTypeName(rv.Elem())
  410. case reflect.String:
  411. return "string", false
  412. case reflect.Map, reflect.Struct:
  413. return "table", false
  414. default:
  415. panic("unexpected reflect.Kind: " + k.String())
  416. }
  417. }
  418. // isTOMLTableType returns whether this type and value represents a TOML table
  419. // type (true) or element type (false). Both rt and rv are needed to determine
  420. // this, in case the Go type is interface{} or in case rv is nil. If there is
  421. // some other impossible situation detected, an error is returned.
  422. func isTOMLTableType(rt reflect.Type, rv reflect.Value) (bool, error) {
  423. k := rt.Kind()
  424. switch k {
  425. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
  426. reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
  427. reflect.Uint64,
  428. reflect.Float32, reflect.Float64,
  429. reflect.String, reflect.Bool:
  430. return false, nil
  431. case reflect.Array, reflect.Slice:
  432. // Make sure that these eventually contain an underlying non-table type
  433. // element.
  434. elemV := reflect.ValueOf(nil)
  435. if rv.Len() > 0 {
  436. elemV = rv.Index(0)
  437. }
  438. hasUnderlyingTableType, err := isTOMLTableType(rt.Elem(), elemV)
  439. if err != nil {
  440. return false, err
  441. }
  442. if hasUnderlyingTableType {
  443. return true, errors.New("TOML array element can't contain a table")
  444. }
  445. return false, nil
  446. case reflect.Ptr:
  447. return isTOMLTableType(rt.Elem(), rv.Elem())
  448. case reflect.Interface:
  449. if rv.Kind() == reflect.Interface {
  450. return false, nil
  451. }
  452. return isTOMLTableType(rv.Type(), rv)
  453. case reflect.Map, reflect.Struct:
  454. return true, nil
  455. default:
  456. panic("unexpected reflect.Kind: " + k.String())
  457. }
  458. }
  459. func isNil(rv reflect.Value) bool {
  460. switch rv.Kind() {
  461. case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
  462. return rv.IsNil()
  463. default:
  464. return false
  465. }
  466. }
  467. func (enc *Encoder) eKeyEq(key Key) error {
  468. _, err := io.WriteString(enc.w, strings.Repeat(enc.Indent, len(key)-1))
  469. if err != nil {
  470. return err
  471. }
  472. _, err = io.WriteString(enc.w, key[len(key)-1]+" = ")
  473. if err != nil {
  474. return err
  475. }
  476. return nil
  477. }
  478. func eindirect(v reflect.Value) reflect.Value {
  479. if v.Kind() != reflect.Ptr {
  480. return v
  481. }
  482. return eindirect(reflect.Indirect(v))
  483. }