go.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. // Copyright 2017 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package main
  5. import (
  6. "errors"
  7. "fmt"
  8. "go/format"
  9. "reflect"
  10. "sort"
  11. "strings"
  12. "text/template"
  13. "github.com/go-xorm/core"
  14. )
  15. var (
  16. supportComment bool
  17. GoLangTmpl LangTmpl = LangTmpl{
  18. template.FuncMap{"Mapper": mapper.Table2Obj,
  19. "Type": typestring,
  20. "Tag": tag,
  21. "UnTitle": unTitle,
  22. "gt": gt,
  23. "getCol": getCol,
  24. "UpperTitle": upTitle,
  25. },
  26. formatGo,
  27. genGoImports,
  28. }
  29. )
  // Errors reported by the template comparison helpers eq, lt, le and gt.
var (
	errBadComparisonType = errors.New("invalid type for comparison")
	errBadComparison     = errors.New("incompatible types for comparison")
	errNoComparison      = errors.New("missing argument for comparison")
)

// kind is the coarse classification of a value used by the comparison
// helpers; values are only compared when their kinds are compatible.
type kind int

const (
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	integerKind
	stringKind
	uintKind
)

// basicKind classifies v into one of the kinds above. Anything that is not a
// bool, integer, float, complex or string value yields errBadComparisonType.
// This includes the zero Value produced by reflect.ValueOf(nil).
func basicKind(v reflect.Value) (kind, error) {
	switch v.Kind() {
	case reflect.Bool:
		return boolKind, nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return intKind, nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return uintKind, nil
	case reflect.Float32, reflect.Float64:
		return floatKind, nil
	case reflect.Complex64, reflect.Complex128:
		return complexKind, nil
	case reflect.String:
		return stringKind, nil
	}
	return invalidKind, errBadComparisonType
}

// eq evaluates the comparison a == b || a == c || ...
//
// At least one argument after arg1 is required (errNoComparison otherwise).
// Values of different kinds are incompatible (errBadComparison), with one
// exception: signed and unsigned integers may be compared. In that case the
// arithmetic values are compared, so a negative integer never equals an
// unsigned one.
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
	v1 := reflect.ValueOf(arg1)
	k1, err := basicKind(v1)
	if err != nil {
		return false, err
	}
	if len(arg2) == 0 {
		return false, errNoComparison
	}
	for _, arg := range arg2 {
		v2 := reflect.ValueOf(arg)
		k2, err := basicKind(v2)
		if err != nil {
			return false, err
		}
		truth := false
		if k1 != k2 {
			// Only mixed signed/unsigned integers are comparable across kinds.
			switch {
			case k1 == intKind && k2 == uintKind:
				truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
			case k1 == uintKind && k2 == intKind:
				truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
			default:
				return false, errBadComparison
			}
		} else {
			switch k1 {
			case boolKind:
				truth = v1.Bool() == v2.Bool()
			case complexKind:
				truth = v1.Complex() == v2.Complex()
			case floatKind:
				truth = v1.Float() == v2.Float()
			case intKind:
				truth = v1.Int() == v2.Int()
			case stringKind:
				truth = v1.String() == v2.String()
			case uintKind:
				truth = v1.Uint() == v2.Uint()
			default:
				// Unreachable: basicKind rejects every other kind.
				panic("invalid kind")
			}
		}
		if truth {
			return true, nil
		}
	}
	return false, nil
}

// lt evaluates the comparison a < b.
//
// Bools and complex numbers are not ordered (errBadComparisonType). Values of
// different kinds are incompatible (errBadComparison), except signed vs
// unsigned integers. Those are compared by arithmetic value, so every
// negative integer is less than every unsigned integer.
func lt(arg1, arg2 interface{}) (bool, error) {
	v1 := reflect.ValueOf(arg1)
	k1, err := basicKind(v1)
	if err != nil {
		return false, err
	}
	v2 := reflect.ValueOf(arg2)
	k2, err := basicKind(v2)
	if err != nil {
		return false, err
	}
	if k1 != k2 {
		switch {
		case k1 == intKind && k2 == uintKind:
			return v1.Int() < 0 || uint64(v1.Int()) < v2.Uint(), nil
		case k1 == uintKind && k2 == intKind:
			return v2.Int() >= 0 && v1.Uint() < uint64(v2.Int()), nil
		default:
			return false, errBadComparison
		}
	}
	truth := false
	switch k1 {
	case boolKind, complexKind:
		return false, errBadComparisonType
	case floatKind:
		truth = v1.Float() < v2.Float()
	case intKind:
		truth = v1.Int() < v2.Int()
	case stringKind:
		truth = v1.String() < v2.String()
	case uintKind:
		truth = v1.Uint() < v2.Uint()
	default:
		// Unreachable: basicKind rejects every other kind.
		panic("invalid kind")
	}
	return truth, nil
}

// le evaluates the comparison a <= b, with the same type rules as lt and eq.
func le(arg1, arg2 interface{}) (bool, error) {
	// <= is < or ==.
	lessThan, err := lt(arg1, arg2)
	if lessThan || err != nil {
		return lessThan, err
	}
	return eq(arg1, arg2)
}

// gt evaluates the comparison a > b, with the same type rules as lt and eq.
// It is the helper registered as "gt" in GoLangTmpl's FuncMap.
func gt(arg1, arg2 interface{}) (bool, error) {
	// > is the inverse of <=.
	lessOrEqual, err := le(arg1, arg2)
	if err != nil {
		return false, err
	}
	return !lessOrEqual, nil
}
  155. func getCol(cols map[string]*core.Column, name string) *core.Column {
  156. return cols[strings.ToLower(name)]
  157. }
  // formatGo runs src through go/format (gofmt). It returns the formatted
// source, or "" and the parse/format error if src is not valid Go.
func formatGo(src string) (string, error) {
	formatted, err := format.Source([]byte(src))
	if err == nil {
		return string(formatted), nil
	}
	return "", err
}
  165. func genGoImports(tables []*core.Table) map[string]string {
  166. imports := make(map[string]string)
  167. for _, table := range tables {
  168. for _, col := range table.Columns() {
  169. if typestring(col) == "time.Time" {
  170. imports["time"] = "time"
  171. }
  172. }
  173. }
  174. return imports
  175. }
  176. func typestring(col *core.Column) string {
  177. st := col.SQLType
  178. t := core.SQLType2Type(st)
  179. s := t.String()
  180. if s == "[]uint8" {
  181. return "[]byte"
  182. }
  183. return s
  184. }
  185. func tag(table *core.Table, col *core.Column) string {
  186. isNameId := (mapper.Table2Obj(col.Name) == "Id")
  187. isIdPk := isNameId && typestring(col) == "int64"
  188. var res []string
  189. if !col.Nullable {
  190. if !isIdPk {
  191. res = append(res, "not null")
  192. }
  193. }
  194. if col.IsPrimaryKey {
  195. res = append(res, "pk")
  196. }
  197. if col.Default != "" {
  198. res = append(res, "default "+col.Default)
  199. }
  200. if col.IsAutoIncrement {
  201. res = append(res, "autoincr")
  202. }
  203. if col.SQLType.IsTime() && include(created, col.Name) {
  204. res = append(res, "created")
  205. }
  206. if col.SQLType.IsTime() && include(updated, col.Name) {
  207. res = append(res, "updated")
  208. }
  209. if col.SQLType.IsTime() && include(deleted, col.Name) {
  210. res = append(res, "deleted")
  211. }
  212. if supportComment && col.Comment != "" {
  213. res = append(res, fmt.Sprintf("comment('%s')", col.Comment))
  214. }
  215. names := make([]string, 0, len(col.Indexes))
  216. for name := range col.Indexes {
  217. names = append(names, name)
  218. }
  219. sort.Strings(names)
  220. for _, name := range names {
  221. index := table.Indexes[name]
  222. var uistr string
  223. if index.Type == core.UniqueType {
  224. uistr = "unique"
  225. } else if index.Type == core.IndexType {
  226. uistr = "index"
  227. }
  228. if len(index.Cols) > 1 {
  229. uistr += "(" + index.Name + ")"
  230. }
  231. res = append(res, uistr)
  232. }
  233. nstr := col.SQLType.Name
  234. if col.Length != 0 {
  235. if col.Length2 != 0 {
  236. nstr += fmt.Sprintf("(%v,%v)", col.Length, col.Length2)
  237. } else {
  238. nstr += fmt.Sprintf("(%v)", col.Length)
  239. }
  240. } else if len(col.EnumOptions) > 0 { //enum
  241. nstr += "("
  242. opts := ""
  243. enumOptions := make([]string, 0, len(col.EnumOptions))
  244. for enumOption := range col.EnumOptions {
  245. enumOptions = append(enumOptions, enumOption)
  246. }
  247. sort.Strings(enumOptions)
  248. for _, v := range enumOptions {
  249. opts += fmt.Sprintf(",'%v'", v)
  250. }
  251. nstr += strings.TrimLeft(opts, ",")
  252. nstr += ")"
  253. } else if len(col.SetOptions) > 0 { //enum
  254. nstr += "("
  255. opts := ""
  256. setOptions := make([]string, 0, len(col.SetOptions))
  257. for setOption := range col.SetOptions {
  258. setOptions = append(setOptions, setOption)
  259. }
  260. sort.Strings(setOptions)
  261. for _, v := range setOptions {
  262. opts += fmt.Sprintf(",'%v'", v)
  263. }
  264. nstr += strings.TrimLeft(opts, ",")
  265. nstr += ")"
  266. }
  267. res = append(res, nstr)
  268. var tags []string
  269. if genJson {
  270. if include(ignoreColumnsJSON, col.Name) {
  271. tags = append(tags, "json:\"-\"")
  272. } else {
  273. tags = append(tags, "json:\""+col.Name+"\"")
  274. }
  275. }
  276. if len(res) > 0 {
  277. tags = append(tags, "xorm:\""+strings.Join(res, " ")+"\"")
  278. }
  279. if len(tags) > 0 {
  280. return "`" + strings.Join(tags, " ") + "`"
  281. } else {
  282. return ""
  283. }
  284. }
  // include reports whether target occurs in source, using exact string
// comparison. An empty or nil source never contains anything.
func include(source []string, target string) bool {
	for i := range source {
		if source[i] == target {
			return true
		}
	}
	return false
}