rewrite.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. // Copyright 2017 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package pipeline
  5. import (
  6. "bytes"
  7. "fmt"
  8. "go/ast"
  9. "go/constant"
  10. "go/format"
  11. "go/token"
  12. "io"
  13. "os"
  14. "strings"
  15. "golang.org/x/tools/go/loader"
  16. )
  const (
	// printerType is the fully qualified name of the message package's
	// Printer type. Expression types are compared against it with a suffix
	// match, so pointer types such as *Printer match as well.
	printerType = "golang.org/x/text/message.Printer"
)
  18. // Rewrite rewrites the Go files in a single package to use the localization
  19. // machinery and rewrites strings to adopt best practices when possible.
  20. // If w is not nil the generated files are written to it, each files with a
  21. // "--- <filename>" header. Otherwise the files are overwritten.
  22. func Rewrite(w io.Writer, args ...string) error {
  23. conf := &loader.Config{
  24. AllowErrors: true, // Allow unused instances of message.Printer.
  25. }
  26. prog, err := loadPackages(conf, args)
  27. if err != nil {
  28. return wrap(err, "")
  29. }
  30. for _, info := range prog.InitialPackages() {
  31. for _, f := range info.Files {
  32. // Associate comments with nodes.
  33. // Pick up initialized Printers at the package level.
  34. r := rewriter{info: info, conf: conf}
  35. for _, n := range info.InitOrder {
  36. if t := r.info.Types[n.Rhs].Type.String(); strings.HasSuffix(t, printerType) {
  37. r.printerVar = n.Lhs[0].Name()
  38. }
  39. }
  40. ast.Walk(&r, f)
  41. w := w
  42. if w == nil {
  43. var err error
  44. if w, err = os.Create(conf.Fset.File(f.Pos()).Name()); err != nil {
  45. return wrap(err, "open failed")
  46. }
  47. } else {
  48. fmt.Fprintln(w, "---", conf.Fset.File(f.Pos()).Name())
  49. }
  50. if err := format.Node(w, conf.Fset, f); err != nil {
  51. return wrap(err, "go format failed")
  52. }
  53. }
  54. }
  55. return nil
  56. }
  57. type rewriter struct {
  58. info *loader.PackageInfo
  59. conf *loader.Config
  60. printerVar string
  61. }
  62. // print returns Go syntax for the specified node.
  63. func (r *rewriter) print(n ast.Node) string {
  64. var buf bytes.Buffer
  65. format.Node(&buf, r.conf.Fset, n)
  66. return buf.String()
  67. }
  68. func (r *rewriter) Visit(n ast.Node) ast.Visitor {
  69. // Save the state by scope.
  70. if _, ok := n.(*ast.BlockStmt); ok {
  71. r := *r
  72. return &r
  73. }
  74. // Find Printers created by assignment.
  75. stmt, ok := n.(*ast.AssignStmt)
  76. if ok {
  77. for _, v := range stmt.Lhs {
  78. if r.printerVar == r.print(v) {
  79. r.printerVar = ""
  80. }
  81. }
  82. for i, v := range stmt.Rhs {
  83. if t := r.info.Types[v].Type.String(); strings.HasSuffix(t, printerType) {
  84. r.printerVar = r.print(stmt.Lhs[i])
  85. return r
  86. }
  87. }
  88. }
  89. // Find Printers created by variable declaration.
  90. spec, ok := n.(*ast.ValueSpec)
  91. if ok {
  92. for _, v := range spec.Names {
  93. if r.printerVar == r.print(v) {
  94. r.printerVar = ""
  95. }
  96. }
  97. for i, v := range spec.Values {
  98. if t := r.info.Types[v].Type.String(); strings.HasSuffix(t, printerType) {
  99. r.printerVar = r.print(spec.Names[i])
  100. return r
  101. }
  102. }
  103. }
  104. if r.printerVar == "" {
  105. return r
  106. }
  107. call, ok := n.(*ast.CallExpr)
  108. if !ok {
  109. return r
  110. }
  111. // TODO: Handle literal values?
  112. sel, ok := call.Fun.(*ast.SelectorExpr)
  113. if !ok {
  114. return r
  115. }
  116. meth := r.info.Selections[sel]
  117. source := r.print(sel.X)
  118. fun := r.print(sel.Sel)
  119. if meth != nil {
  120. source = meth.Recv().String()
  121. fun = meth.Obj().Name()
  122. }
  123. // TODO: remove cheap hack and check if the type either
  124. // implements some interface or is specifically of type
  125. // "golang.org/x/text/message".Printer.
  126. m, ok := rewriteFuncs[source]
  127. if !ok {
  128. return r
  129. }
  130. rewriteType, ok := m[fun]
  131. if !ok {
  132. return r
  133. }
  134. ident := ast.NewIdent(r.printerVar)
  135. ident.NamePos = sel.X.Pos()
  136. sel.X = ident
  137. if rewriteType.method != "" {
  138. sel.Sel.Name = rewriteType.method
  139. }
  140. // Analyze arguments.
  141. argn := rewriteType.arg
  142. if rewriteType.format || argn >= len(call.Args) {
  143. return r
  144. }
  145. hasConst := false
  146. for _, a := range call.Args[argn:] {
  147. if v := r.info.Types[a].Value; v != nil && v.Kind() == constant.String {
  148. hasConst = true
  149. break
  150. }
  151. }
  152. if !hasConst {
  153. return r
  154. }
  155. sel.Sel.Name = rewriteType.methodf
  156. // We are done if there is only a single string that does not need to be
  157. // escaped.
  158. if len(call.Args) == 1 {
  159. s, ok := constStr(r.info, call.Args[0])
  160. if ok && !strings.Contains(s, "%") && !rewriteType.newLine {
  161. return r
  162. }
  163. }
  164. // Rewrite arguments as format string.
  165. expr := &ast.BasicLit{
  166. ValuePos: call.Lparen,
  167. Kind: token.STRING,
  168. }
  169. newArgs := append(call.Args[:argn:argn], expr)
  170. newStr := []string{}
  171. for i, a := range call.Args[argn:] {
  172. if s, ok := constStr(r.info, a); ok {
  173. newStr = append(newStr, strings.Replace(s, "%", "%%", -1))
  174. } else {
  175. newStr = append(newStr, "%v")
  176. newArgs = append(newArgs, call.Args[argn+i])
  177. }
  178. }
  179. s := strings.Join(newStr, rewriteType.sep)
  180. if rewriteType.newLine {
  181. s += "\n"
  182. }
  183. expr.Value = fmt.Sprintf("%q", s)
  184. call.Args = newArgs
  185. // TODO: consider creating an expression instead of a constant string and
  186. // then wrapping it in an escape function or so:
  187. // call.Args[argn+i] = &ast.CallExpr{
  188. // Fun: &ast.SelectorExpr{
  189. // X: ast.NewIdent("message"),
  190. // Sel: ast.NewIdent("Lookup"),
  191. // },
  192. // Args: []ast.Expr{a},
  193. // }
  194. // }
  195. return r
  196. }
  // rewriteType describes how a call to a print function is mapped onto the
// corresponding method of a message.Printer.
type rewriteType struct {
	// method is the name of the equivalent method on a printer, or "" if it is
	// the same.
	method string

	// methodf is the method to use if the arguments can be rewritten as
	// arguments to a printf-style call.
	methodf string

	// format is true if the method takes a formatting string followed by
	// substitution arguments; the arguments of such calls are left as is.
	format bool

	// arg is the index of the first argument to extract: all arguments from
	// this position onwards may be folded into the format string, while the
	// arguments before it (such as the io.Writer of the Fprint family) are
	// kept unchanged. It is not used when format is true.
	arg int

	// sep is inserted between the extracted operands when they are joined
	// into a format string.
	sep string

	// newLine reports whether a trailing newline is appended to the
	// generated format string.
	newLine bool
}

// rewriteFuncs lists the functions that can be directly mapped to the printer
// functions of the message package, keyed by qualifier (such as the package
// name "fmt") and then by function name.
//
// The Fprint family takes an io.Writer as its first argument, so extraction
// starts at arg 1 for those; otherwise the writer would be folded into the
// format string.
var rewriteFuncs = map[string]map[string]rewriteType{
	// TODO: Printer -> *golang.org/x/text/message.Printer
	"fmt": {
		"Print":    {methodf: "Printf"},
		"Sprint":   {methodf: "Sprintf"},
		"Fprint":   {methodf: "Fprintf", arg: 1},
		"Println":  {methodf: "Printf", sep: " ", newLine: true},
		"Sprintln": {methodf: "Sprintf", sep: " ", newLine: true},
		"Fprintln": {methodf: "Fprintf", arg: 1, sep: " ", newLine: true},
		"Printf":   {method: "Printf", format: true},
		"Sprintf":  {method: "Sprintf", format: true},
		"Fprintf":  {method: "Fprintf", format: true, arg: 1},
	},
}
  229. func constStr(info *loader.PackageInfo, e ast.Expr) (s string, ok bool) {
  230. v := info.Types[e].Value
  231. if v == nil || v.Kind() != constant.String {
  232. return "", false
  233. }
  234. return constant.StringVal(v), true
  235. }