reverse.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. // Copyright 2017 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package main
  5. import (
  6. "bytes"
  7. "fmt"
  8. "io/ioutil"
  9. "os"
  10. "path"
  11. "path/filepath"
  12. "regexp"
  13. "strconv"
  14. "strings"
  15. "text/template"
  16. "github.com/go-xorm/core"
  17. "github.com/go-xorm/xorm"
  18. "github.com/lunny/log"
  19. _ "github.com/denisenkom/go-mssqldb"
  20. _ "github.com/go-sql-driver/mysql"
  21. _ "github.com/lib/pq"
  22. _ "github.com/ziutek/mymysql/godrv"
  23. )
  24. var CmdReverse = &Command{
  25. UsageLine: "reverse [-s] driverName datasourceName tmplPath [generatedPath] [tableFilterReg]",
  26. Short: "reverse a db to codes",
  27. Long: `
  28. according database's tables and columns to generate codes for Go, C++ and etc.
  29. -s Generated one go file for every table
  30. driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres
  31. datasourceName Database connection uri, for detail infomation please visit driver's project page
  32. tmplPath Template dir for generated. the default templates dir has provide 1 template
  33. generatedPath This parameter is optional, if blank, the default value is models, then will
  34. generated all codes in models dir
  35. tableFilterReg Table name filter regexp
  36. `,
  37. }
  38. func init() {
  39. CmdReverse.Run = runReverse
  40. CmdReverse.Flags = map[string]bool{
  41. "-s": false,
  42. "-l": false,
  43. }
  44. }
  45. var (
  46. genJson bool = false
  47. ignoreColumnsJSON, created, updated, deleted []string = []string{}, []string{"created_at"}, []string{"updated_at"}, []string{"deleted_at"}
  48. )
  49. func printReversePrompt(flag string) {
  50. }
  51. type Tmpl struct {
  52. Tables []*core.Table
  53. Imports map[string]string
  54. Models string
  55. }
  56. func dirExists(dir string) bool {
  57. d, e := os.Stat(dir)
  58. switch {
  59. case e != nil:
  60. return false
  61. case !d.IsDir():
  62. return false
  63. }
  64. return true
  65. }
  66. func runReverse(cmd *Command, args []string) {
  67. num := checkFlags(cmd.Flags, args, printReversePrompt)
  68. if num == -1 {
  69. return
  70. }
  71. args = args[num:]
  72. if len(args) < 3 {
  73. fmt.Println("params error, please see xorm help reverse")
  74. return
  75. }
  76. var isMultiFile bool = true
  77. if use, ok := cmd.Flags["-s"]; ok {
  78. isMultiFile = !use
  79. }
  80. curPath, err := os.Getwd()
  81. if err != nil {
  82. fmt.Println(err)
  83. return
  84. }
  85. var genDir string
  86. var model string
  87. var filterPat *regexp.Regexp
  88. if len(args) >= 4 {
  89. genDir, err = filepath.Abs(args[3])
  90. if err != nil {
  91. fmt.Println(err)
  92. return
  93. }
  94. //[SWH|+] 经测试,path.Base不能解析windows下的“\”,需要替换为“/”
  95. genDir = strings.Replace(genDir, "\\", "/", -1)
  96. model = path.Base(genDir)
  97. if len(args) >= 5 {
  98. filterPat, err = regexp.Compile(args[4])
  99. if err != nil {
  100. fmt.Println(err)
  101. return
  102. }
  103. }
  104. } else {
  105. model = "models"
  106. genDir = path.Join(curPath, model)
  107. }
  108. dir, err := filepath.Abs(args[2])
  109. if err != nil {
  110. log.Errorf("%v", err)
  111. return
  112. }
  113. if !dirExists(dir) {
  114. log.Errorf("Template %v path is not exist", dir)
  115. return
  116. }
  117. var langTmpl LangTmpl
  118. var ok bool
  119. var lang string = "go"
  120. var prefix string = "" //[SWH|+]
  121. cfgPath := path.Join(dir, "config")
  122. info, err := os.Stat(cfgPath)
  123. var configs map[string]string
  124. if err == nil && !info.IsDir() {
  125. configs = loadConfig(cfgPath)
  126. if l, ok := configs["lang"]; ok {
  127. lang = l
  128. }
  129. if j, ok := configs["genJson"]; ok {
  130. genJson, err = strconv.ParseBool(j)
  131. }
  132. //[SWH|+]
  133. if j, ok := configs["prefix"]; ok {
  134. prefix = j
  135. }
  136. if j, ok := configs["ignoreColumnsJSON"]; ok {
  137. ignoreColumnsJSON = strings.Split(j, ",")
  138. }
  139. if j, ok := configs["created"]; ok {
  140. created = strings.Split(j, ",")
  141. }
  142. if j, ok := configs["updated"]; ok {
  143. updated = strings.Split(j, ",")
  144. }
  145. if j, ok := configs["deleted"]; ok {
  146. deleted = strings.Split(j, ",")
  147. }
  148. }
  149. if langTmpl, ok = langTmpls[lang]; !ok {
  150. fmt.Println("Unsupported programing language", lang)
  151. return
  152. }
  153. os.MkdirAll(genDir, os.ModePerm)
  154. supportComment = (args[0] == "mysql" || args[0] == "mymysql")
  155. Orm, err := xorm.NewEngine(args[0], args[1])
  156. if err != nil {
  157. log.Errorf("%v", err)
  158. return
  159. }
  160. tables, err := Orm.DBMetas()
  161. if err != nil {
  162. log.Errorf("%v", err)
  163. return
  164. }
  165. if filterPat != nil && len(tables) > 0 {
  166. size := 0
  167. for _, t := range tables {
  168. if filterPat.MatchString(t.Name) {
  169. tables[size] = t
  170. size++
  171. }
  172. }
  173. tables = tables[:size]
  174. }
  175. filepath.Walk(dir, func(f string, info os.FileInfo, err error) error {
  176. if info.IsDir() {
  177. return nil
  178. }
  179. if info.Name() == "config" {
  180. return nil
  181. }
  182. bs, err := ioutil.ReadFile(f)
  183. if err != nil {
  184. log.Errorf("%v", err)
  185. return err
  186. }
  187. t := template.New(f)
  188. t.Funcs(langTmpl.Funcs)
  189. tmpl, err := t.Parse(string(bs))
  190. if err != nil {
  191. log.Errorf("%v", err)
  192. return err
  193. }
  194. var w *os.File
  195. fileName := info.Name()
  196. newFileName := fileName[:len(fileName)-4]
  197. ext := path.Ext(newFileName)
  198. if !isMultiFile {
  199. w, err = os.Create(path.Join(genDir, newFileName))
  200. if err != nil {
  201. log.Errorf("%v", err)
  202. return err
  203. }
  204. imports := langTmpl.GenImports(tables)
  205. tbls := make([]*core.Table, 0)
  206. for _, table := range tables {
  207. //[SWH|+]
  208. if prefix != "" {
  209. table.Name = strings.TrimPrefix(table.Name, prefix)
  210. }
  211. tbls = append(tbls, table)
  212. }
  213. newbytes := bytes.NewBufferString("")
  214. t := &Tmpl{Tables: tbls, Imports: imports, Models: model}
  215. err = tmpl.Execute(newbytes, t)
  216. if err != nil {
  217. log.Errorf("%v", err)
  218. return err
  219. }
  220. tplcontent, err := ioutil.ReadAll(newbytes)
  221. if err != nil {
  222. log.Errorf("%v", err)
  223. return err
  224. }
  225. var source string
  226. if langTmpl.Formater != nil {
  227. source, err = langTmpl.Formater(string(tplcontent))
  228. if err != nil {
  229. log.Errorf("%v", err)
  230. return err
  231. }
  232. } else {
  233. source = string(tplcontent)
  234. }
  235. w.WriteString(source)
  236. w.Close()
  237. } else {
  238. for _, table := range tables {
  239. //[SWH|+]
  240. if prefix != "" {
  241. table.Name = strings.TrimPrefix(table.Name, prefix)
  242. }
  243. // imports
  244. tbs := []*core.Table{table}
  245. imports := langTmpl.GenImports(tbs)
  246. w, err := os.Create(path.Join(genDir, table.Name+ext))
  247. if err != nil {
  248. log.Errorf("%v", err)
  249. return err
  250. }
  251. defer w.Close()
  252. newbytes := bytes.NewBufferString("")
  253. t := &Tmpl{Tables: tbs, Imports: imports, Models: model}
  254. err = tmpl.Execute(newbytes, t)
  255. if err != nil {
  256. log.Errorf("%v", err)
  257. return err
  258. }
  259. tplcontent, err := ioutil.ReadAll(newbytes)
  260. if err != nil {
  261. log.Errorf("%v", err)
  262. return err
  263. }
  264. var source string
  265. if langTmpl.Formater != nil {
  266. source, err = langTmpl.Formater(string(tplcontent))
  267. if err != nil {
  268. log.Errorf("%v-%v", err, string(tplcontent))
  269. return err
  270. }
  271. } else {
  272. source = string(tplcontent)
  273. }
  274. w.WriteString(source)
  275. w.Close()
  276. }
  277. }
  278. return nil
  279. })
  280. }