util.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package gogen
  2. import (
  3. "bytes"
  4. "fmt"
  5. goformat "go/format"
  6. "io"
  7. "path/filepath"
  8. "strings"
  9. "text/template"
  10. "github.com/tal-tech/go-zero/core/collection"
  11. "github.com/tal-tech/go-zero/tools/goctl/api/spec"
  12. "github.com/tal-tech/go-zero/tools/goctl/api/util"
  13. ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
  14. "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
  15. )
  16. type fileGenConfig struct {
  17. dir string
  18. subdir string
  19. filename string
  20. templateName string
  21. category string
  22. templateFile string
  23. builtinTemplate string
  24. data interface{}
  25. }
  26. func genFile(c fileGenConfig) error {
  27. fp, created, err := util.MaybeCreateFile(c.dir, c.subdir, c.filename)
  28. if err != nil {
  29. return err
  30. }
  31. if !created {
  32. return nil
  33. }
  34. defer fp.Close()
  35. var text string
  36. if len(c.category) == 0 || len(c.templateFile) == 0 {
  37. text = c.builtinTemplate
  38. } else {
  39. text, err = ctlutil.LoadTemplate(c.category, c.templateFile, c.builtinTemplate)
  40. if err != nil {
  41. return err
  42. }
  43. }
  44. t := template.Must(template.New(c.templateName).Parse(text))
  45. buffer := new(bytes.Buffer)
  46. err = t.Execute(buffer, c.data)
  47. if err != nil {
  48. return err
  49. }
  50. code := formatCode(buffer.String())
  51. _, err = fp.WriteString(code)
  52. return err
  53. }
  54. func getParentPackage(dir string) (string, error) {
  55. abs, err := filepath.Abs(dir)
  56. if err != nil {
  57. return "", err
  58. }
  59. projectCtx, err := ctx.Prepare(abs)
  60. if err != nil {
  61. return "", err
  62. }
  63. return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
  64. }
  65. func writeProperty(writer io.Writer, name, tag, comment string, tp spec.Type, indent int) error {
  66. util.WriteIndent(writer, indent)
  67. var err error
  68. if len(comment) > 0 {
  69. comment = strings.TrimPrefix(comment, "//")
  70. comment = "//" + comment
  71. _, err = fmt.Fprintf(writer, "%s %s %s %s\n", strings.Title(name), tp.Name(), tag, comment)
  72. } else {
  73. _, err = fmt.Fprintf(writer, "%s %s %s\n", strings.Title(name), tp.Name(), tag)
  74. }
  75. return err
  76. }
  77. func getAuths(api *spec.ApiSpec) []string {
  78. authNames := collection.NewSet()
  79. for _, g := range api.Service.Groups {
  80. jwt := g.GetAnnotation("jwt")
  81. if len(jwt) > 0 {
  82. authNames.Add(jwt)
  83. }
  84. }
  85. return authNames.KeysStr()
  86. }
  87. func getMiddleware(api *spec.ApiSpec) []string {
  88. result := collection.NewSet()
  89. for _, g := range api.Service.Groups {
  90. middleware := g.GetAnnotation("middleware")
  91. if len(middleware) > 0 {
  92. for _, item := range strings.Split(middleware, ",") {
  93. result.Add(strings.TrimSpace(item))
  94. }
  95. }
  96. }
  97. return result.KeysStr()
  98. }
  99. func formatCode(code string) string {
  100. ret, err := goformat.Source([]byte(code))
  101. if err != nil {
  102. return code
  103. }
  104. return string(ret)
  105. }
  106. func responseGoTypeName(r spec.Route, pkg ...string) string {
  107. if r.ResponseType == nil {
  108. return ""
  109. }
  110. resp := golangExpr(r.ResponseType, pkg...)
  111. switch r.ResponseType.(type) {
  112. case spec.DefineStruct:
  113. if !strings.HasPrefix(resp, "*") {
  114. return "*" + resp
  115. }
  116. }
  117. return resp
  118. }
  119. func requestGoTypeName(r spec.Route, pkg ...string) string {
  120. if r.RequestType == nil {
  121. return ""
  122. }
  123. return golangExpr(r.RequestType, pkg...)
  124. }
  125. func golangExpr(ty spec.Type, pkg ...string) string {
  126. switch v := ty.(type) {
  127. case spec.PrimitiveType:
  128. return v.RawName
  129. case spec.DefineStruct:
  130. if len(pkg) > 1 {
  131. panic("package cannot be more than 1")
  132. }
  133. if len(pkg) == 0 {
  134. return v.RawName
  135. }
  136. return fmt.Sprintf("%s.%s", pkg[0], strings.Title(v.RawName))
  137. case spec.ArrayType:
  138. if len(pkg) > 1 {
  139. panic("package cannot be more than 1")
  140. }
  141. if len(pkg) == 0 {
  142. return v.RawName
  143. }
  144. return fmt.Sprintf("[]%s", golangExpr(v.Value, pkg...))
  145. case spec.MapType:
  146. if len(pkg) > 1 {
  147. panic("package cannot be more than 1")
  148. }
  149. if len(pkg) == 0 {
  150. return v.RawName
  151. }
  152. return fmt.Sprintf("map[%s]%s", v.Key, golangExpr(v.Value, pkg...))
  153. case spec.PointerType:
  154. if len(pkg) > 1 {
  155. panic("package cannot be more than 1")
  156. }
  157. if len(pkg) == 0 {
  158. return v.RawName
  159. }
  160. return fmt.Sprintf("*%s", golangExpr(v.Type, pkg...))
  161. case spec.InterfaceType:
  162. return v.RawName
  163. }
  164. return ""
  165. }