genmongomodel.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package gen
  2. import (
  3. "fmt"
  4. "strings"
  5. "text/template"
  6. "zero/tools/goctl/api/spec"
  7. "zero/tools/goctl/api/util"
  8. "zero/tools/goctl/model/mongomodel/utils"
  9. )
  10. const (
  11. functionTypeGet = "get" //GetByField return single model
  12. functionTypeFind = "find" // findByField return many model
  13. functionTypeSet = "set" // SetField only set specified field
  14. TagOperate = "o" //字段函数的tag
  15. TagComment = "c" //字段注释的tag
  16. )
  17. type (
  18. FunctionDesc struct {
  19. Type string // get,find,set
  20. FieldName string // 字段名字 eg:Age
  21. FieldType string // 字段类型 eg: string,int64 等
  22. }
  23. )
  24. func GenMongoModel(goFilePath string, needCache bool) error {
  25. structs, imports, err := utils.ParseGoFile(goFilePath)
  26. if err != nil {
  27. return err
  28. }
  29. if len(structs) != 1 {
  30. return fmt.Errorf("only 1 struct should be provided")
  31. }
  32. structStr, err := genStructs(structs)
  33. if err != nil {
  34. return err
  35. }
  36. fp, err := util.ClearAndOpenFile(goFilePath)
  37. if err != nil {
  38. return err
  39. }
  40. defer fp.Close()
  41. var myTemplate string
  42. if needCache {
  43. myTemplate = cacheTemplate
  44. } else {
  45. myTemplate = noCacheTemplate
  46. }
  47. structName := getStructName(structs)
  48. functionList := getFunctionList(structs)
  49. for _, fun := range functionList {
  50. funTmp := genMethodTemplate(fun, needCache)
  51. if funTmp == "" {
  52. continue
  53. }
  54. myTemplate += "\n"
  55. myTemplate += funTmp
  56. myTemplate += "\n"
  57. }
  58. t := template.Must(template.New("mongoTemplate").Parse(myTemplate))
  59. return t.Execute(fp, map[string]string{
  60. "modelName": structName,
  61. "importArray": getImports(imports, needCache),
  62. "modelFields": structStr,
  63. })
  64. }
  65. func getFunctionList(structs []utils.Struct) []FunctionDesc {
  66. var list []FunctionDesc
  67. for _, field := range structs[0].Fields {
  68. tagMap := parseTag(field.Tag)
  69. if fun, ok := tagMap[TagOperate]; ok {
  70. funList := strings.Split(fun, ",")
  71. for _, o := range funList {
  72. var f FunctionDesc
  73. f.FieldType = field.Type
  74. f.FieldName = field.Name
  75. f.Type = o
  76. list = append(list, f)
  77. }
  78. }
  79. }
  80. return list
  81. }
  82. func getStructName(structs []utils.Struct) string {
  83. for _, structItem := range structs {
  84. return structItem.Name
  85. }
  86. return ""
  87. }
  88. func genStructs(structs []utils.Struct) (string, error) {
  89. if len(structs) > 1 {
  90. return "", fmt.Errorf("input .go file must only one struct")
  91. }
  92. modelFields := `Id bson.ObjectId ` + quotationMark + `bson:"_id" json:"id,omitempty"` + quotationMark + "\n\t"
  93. for _, structItem := range structs {
  94. for _, field := range structItem.Fields {
  95. modelFields += getFieldLine(field)
  96. }
  97. }
  98. modelFields += "\t" + `CreateTime time.Time ` + quotationMark + `json:"createTime,omitempty" bson:"createTime"` + quotationMark + "\n\t"
  99. modelFields += "\t" + `UpdateTime time.Time ` + quotationMark + `json:"updateTime,omitempty" bson:"updateTime"` + quotationMark
  100. return modelFields, nil
  101. }
  102. func getFieldLine(member spec.Member) string {
  103. if member.Name == "CreateTime" || member.Name == "UpdateTime" || member.Name == "Id" {
  104. return ""
  105. }
  106. jsonName := utils.UpperCamelToLower(member.Name)
  107. result := "\t" + member.Name + ` ` + member.Type + ` ` + quotationMark + `json:"` + jsonName + `,omitempty"` + ` bson:"` + jsonName + `"` + quotationMark
  108. tagMap := parseTag(member.Tag)
  109. if comment, ok := tagMap[TagComment]; ok {
  110. result += ` //` + comment + "\n\t"
  111. } else {
  112. result += "\n\t"
  113. }
  114. return result
  115. }
  116. // tag like `o:"find,get,update" c:"姓名"`
  117. func parseTag(tag string) map[string]string {
  118. var result = make(map[string]string, 0)
  119. tags := strings.Split(tag, " ")
  120. for _, kv := range tags {
  121. temp := strings.Split(kv, ":")
  122. if len(temp) > 1 {
  123. key := strings.ReplaceAll(strings.ReplaceAll(temp[0], "\"", ""), "`", "")
  124. value := strings.ReplaceAll(strings.ReplaceAll(temp[1], "\"", ""), "`", "")
  125. result[key] = value
  126. }
  127. }
  128. return result
  129. }
  130. func getImports(imports []string, needCache bool) string {
  131. importStr := strings.Join(imports, "\n\t")
  132. importStr += "\"errors\"\n\t"
  133. importStr += "\"time\"\n\t"
  134. importStr += "\n\t\"zero/core/stores/cache\"\n\t"
  135. importStr += "\"zero/core/stores/mongoc\"\n\t"
  136. importStr += "\n\t\"github.com/globalsign/mgo/bson\""
  137. return importStr
  138. }