parser.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. package parser
  2. import (
  3. "fmt"
  4. "sort"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
  8. "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
  9. "github.com/tal-tech/go-zero/tools/goctl/util/console"
  10. "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
  11. "github.com/xwb1989/sqlparser"
  12. )
  13. const timeImport = "time.Time"
  14. type (
  15. // Table describes a mysql table
  16. Table struct {
  17. Name stringx.String
  18. PrimaryKey Primary
  19. UniqueIndex map[string][]*Field
  20. NormalIndex map[string][]*Field
  21. Fields []*Field
  22. }
  23. // Primary describes a primary key
  24. Primary struct {
  25. Field
  26. AutoIncrement bool
  27. }
  28. // Field describes a table field
  29. Field struct {
  30. Name stringx.String
  31. DataBaseType string
  32. DataType string
  33. Comment string
  34. SeqInIndex int
  35. OrdinalPosition int
  36. }
  37. // KeyType types alias of int
  38. KeyType int
  39. )
  40. // Parse parses ddl into golang structure
  41. func Parse(ddl string) (*Table, error) {
  42. stmt, err := sqlparser.ParseStrictDDL(ddl)
  43. if err != nil {
  44. return nil, err
  45. }
  46. ddlStmt, ok := stmt.(*sqlparser.DDL)
  47. if !ok {
  48. return nil, errUnsupportDDL
  49. }
  50. action := ddlStmt.Action
  51. if action != sqlparser.CreateStr {
  52. return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action)
  53. }
  54. tableName := ddlStmt.NewName.Name.String()
  55. tableSpec := ddlStmt.TableSpec
  56. if tableSpec == nil {
  57. return nil, errTableBodyNotFound
  58. }
  59. columns := tableSpec.Columns
  60. indexes := tableSpec.Indexes
  61. primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes)
  62. if err != nil {
  63. return nil, err
  64. }
  65. primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
  66. if err != nil {
  67. return nil, err
  68. }
  69. var fields []*Field
  70. for _, e := range fieldM {
  71. fields = append(fields, e)
  72. }
  73. var (
  74. uniqueIndex = make(map[string][]*Field)
  75. normalIndex = make(map[string][]*Field)
  76. )
  77. for indexName, each := range uniqueKeyMap {
  78. for _, columnName := range each {
  79. uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
  80. }
  81. }
  82. for indexName, each := range normalKeyMap {
  83. for _, columnName := range each {
  84. normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
  85. }
  86. }
  87. checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
  88. return &Table{
  89. Name: stringx.From(tableName),
  90. PrimaryKey: primaryKey,
  91. UniqueIndex: uniqueIndex,
  92. NormalIndex: normalIndex,
  93. Fields: fields,
  94. }, nil
  95. }
  96. func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
  97. log := console.NewColorConsole()
  98. uniqueSet := collection.NewSet()
  99. for k, i := range uniqueIndex {
  100. var list []string
  101. for _, e := range i {
  102. list = append(list, e.Name.Source())
  103. }
  104. joinRet := strings.Join(list, ",")
  105. if uniqueSet.Contains(joinRet) {
  106. log.Warning("table %s: duplicate unique index %s", tableName, joinRet)
  107. delete(uniqueIndex, k)
  108. continue
  109. }
  110. uniqueSet.AddStr(joinRet)
  111. }
  112. normalIndexSet := collection.NewSet()
  113. for k, i := range normalIndex {
  114. var list []string
  115. for _, e := range i {
  116. list = append(list, e.Name.Source())
  117. }
  118. joinRet := strings.Join(list, ",")
  119. if normalIndexSet.Contains(joinRet) {
  120. log.Warning("table %s: duplicate index %s", tableName, joinRet)
  121. delete(normalIndex, k)
  122. continue
  123. }
  124. normalIndexSet.Add(joinRet)
  125. }
  126. }
  127. func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
  128. var (
  129. primaryKey Primary
  130. fieldM = make(map[string]*Field)
  131. )
  132. for _, column := range columns {
  133. if column == nil {
  134. continue
  135. }
  136. var comment string
  137. if column.Type.Comment != nil {
  138. comment = string(column.Type.Comment.Val)
  139. }
  140. isDefaultNull := true
  141. if column.Type.NotNull {
  142. isDefaultNull = false
  143. } else {
  144. if column.Type.Default == nil {
  145. isDefaultNull = false
  146. } else if string(column.Type.Default.Val) != "null" {
  147. isDefaultNull = false
  148. }
  149. }
  150. dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
  151. if err != nil {
  152. return Primary{}, nil, err
  153. }
  154. var field Field
  155. field.Name = stringx.From(column.Name.String())
  156. field.DataBaseType = column.Type.Type
  157. field.DataType = dataType
  158. field.Comment = comment
  159. if field.Name.Source() == primaryColumn {
  160. primaryKey = Primary{
  161. Field: field,
  162. AutoIncrement: bool(column.Type.Autoincrement),
  163. }
  164. }
  165. fieldM[field.Name.Source()] = &field
  166. }
  167. return primaryKey, fieldM, nil
  168. }
  169. func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) {
  170. var primaryColumn string
  171. uniqueKeyMap := make(map[string][]string)
  172. normalKeyMap := make(map[string][]string)
  173. isCreateTimeOrUpdateTime := func(name string) bool {
  174. camelColumnName := stringx.From(name).ToCamel()
  175. // by default, createTime|updateTime findOne is not used.
  176. return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime"
  177. }
  178. for _, index := range indexes {
  179. info := index.Info
  180. if info == nil {
  181. continue
  182. }
  183. indexName := index.Info.Name.String()
  184. if info.Primary {
  185. if len(index.Columns) > 1 {
  186. return "", nil, nil, errPrimaryKey
  187. }
  188. columnName := index.Columns[0].Column.String()
  189. if isCreateTimeOrUpdateTime(columnName) {
  190. continue
  191. }
  192. primaryColumn = columnName
  193. continue
  194. } else if info.Unique {
  195. for _, each := range index.Columns {
  196. columnName := each.Column.String()
  197. if isCreateTimeOrUpdateTime(columnName) {
  198. break
  199. }
  200. uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName)
  201. }
  202. } else if info.Spatial {
  203. // do nothing
  204. } else {
  205. for _, each := range index.Columns {
  206. columnName := each.Column.String()
  207. if isCreateTimeOrUpdateTime(columnName) {
  208. break
  209. }
  210. normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String())
  211. }
  212. }
  213. }
  214. return primaryColumn, uniqueKeyMap, normalKeyMap, nil
  215. }
  216. // ContainsTime returns true if contains golang type time.Time
  217. func (t *Table) ContainsTime() bool {
  218. for _, item := range t.Fields {
  219. if item.DataType == timeImport {
  220. return true
  221. }
  222. }
  223. return false
  224. }
  225. // ConvertDataType converts mysql data type into golang data type
  226. func ConvertDataType(table *model.Table) (*Table, error) {
  227. isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
  228. primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
  229. if err != nil {
  230. return nil, err
  231. }
  232. var reply Table
  233. reply.UniqueIndex = map[string][]*Field{}
  234. reply.NormalIndex = map[string][]*Field{}
  235. reply.Name = stringx.From(table.Table)
  236. seqInIndex := 0
  237. if table.PrimaryKey.Index != nil {
  238. seqInIndex = table.PrimaryKey.Index.SeqInIndex
  239. }
  240. reply.PrimaryKey = Primary{
  241. Field: Field{
  242. Name: stringx.From(table.PrimaryKey.Name),
  243. DataBaseType: table.PrimaryKey.DataType,
  244. DataType: primaryDataType,
  245. Comment: table.PrimaryKey.Comment,
  246. SeqInIndex: seqInIndex,
  247. OrdinalPosition: table.PrimaryKey.OrdinalPosition,
  248. },
  249. AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
  250. }
  251. fieldM, err := getTableFields(table)
  252. if err != nil {
  253. return nil, err
  254. }
  255. for _, each := range fieldM {
  256. reply.Fields = append(reply.Fields, each)
  257. }
  258. sort.Slice(reply.Fields, func(i, j int) bool {
  259. return reply.Fields[i].OrdinalPosition < reply.Fields[j].OrdinalPosition
  260. })
  261. uniqueIndexSet := collection.NewSet()
  262. log := console.NewColorConsole()
  263. for indexName, each := range table.UniqueIndex {
  264. sort.Slice(each, func(i, j int) bool {
  265. if each[i].Index != nil {
  266. return each[i].Index.SeqInIndex < each[j].Index.SeqInIndex
  267. }
  268. return false
  269. })
  270. if len(each) == 1 {
  271. one := each[0]
  272. if one.Name == table.PrimaryKey.Name {
  273. log.Warning("table %s: duplicate unique index with primary key, %s", table.Table, one.Name)
  274. continue
  275. }
  276. }
  277. var list []*Field
  278. var uniqueJoin []string
  279. for _, c := range each {
  280. list = append(list, fieldM[c.Name])
  281. uniqueJoin = append(uniqueJoin, c.Name)
  282. }
  283. uniqueKey := strings.Join(uniqueJoin, ",")
  284. if uniqueIndexSet.Contains(uniqueKey) {
  285. log.Warning("table %s: duplicate unique index, %s", table.Table, uniqueKey)
  286. continue
  287. }
  288. uniqueIndexSet.AddStr(uniqueKey)
  289. reply.UniqueIndex[indexName] = list
  290. }
  291. normalIndexSet := collection.NewSet()
  292. for indexName, each := range table.NormalIndex {
  293. var list []*Field
  294. var normalJoin []string
  295. for _, c := range each {
  296. list = append(list, fieldM[c.Name])
  297. normalJoin = append(normalJoin, c.Name)
  298. }
  299. normalKey := strings.Join(normalJoin, ",")
  300. if normalIndexSet.Contains(normalKey) {
  301. log.Warning("table %s: duplicate index, %s", table.Table, normalKey)
  302. continue
  303. }
  304. normalIndexSet.AddStr(normalKey)
  305. sort.Slice(list, func(i, j int) bool {
  306. return list[i].SeqInIndex < list[j].SeqInIndex
  307. })
  308. reply.NormalIndex[indexName] = list
  309. }
  310. return &reply, nil
  311. }
  312. func getTableFields(table *model.Table) (map[string]*Field, error) {
  313. fieldM := make(map[string]*Field)
  314. for _, each := range table.Columns {
  315. isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
  316. dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
  317. if err != nil {
  318. return nil, err
  319. }
  320. columnSeqInIndex := 0
  321. if each.Index != nil {
  322. columnSeqInIndex = each.Index.SeqInIndex
  323. }
  324. field := &Field{
  325. Name: stringx.From(each.Name),
  326. DataBaseType: each.DataType,
  327. DataType: dt,
  328. Comment: each.Comment,
  329. SeqInIndex: columnSeqInIndex,
  330. OrdinalPosition: each.OrdinalPosition,
  331. }
  332. fieldM[each.Name] = field
  333. }
  334. return fieldM, nil
  335. }