util.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package gogen
import (
	"bytes"
	"fmt"
	goformat "go/format"
	"io"
	"path/filepath"
	"sort"
	"strings"
	"text/template"

	"github.com/tal-tech/go-zero/core/collection"
	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
	"github.com/tal-tech/go-zero/tools/goctl/api/util"
	ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
	"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
  16. type fileGenConfig struct {
  17. dir string
  18. subdir string
  19. filename string
  20. templateName string
  21. category string
  22. templateFile string
  23. builtinTemplate string
  24. data interface{}
  25. }
  26. func genFile(c fileGenConfig) error {
  27. fp, created, err := util.MaybeCreateFile(c.dir, c.subdir, c.filename)
  28. if err != nil {
  29. return err
  30. }
  31. if !created {
  32. return nil
  33. }
  34. defer fp.Close()
  35. var text string
  36. if len(c.category) == 0 || len(c.templateFile) == 0 {
  37. text = c.builtinTemplate
  38. } else {
  39. text, err = ctlutil.LoadTemplate(c.category, c.templateFile, c.builtinTemplate)
  40. if err != nil {
  41. return err
  42. }
  43. }
  44. t := template.Must(template.New(c.templateName).Parse(text))
  45. buffer := new(bytes.Buffer)
  46. err = t.Execute(buffer, c.data)
  47. if err != nil {
  48. return err
  49. }
  50. code := formatCode(buffer.String())
  51. _, err = fp.WriteString(code)
  52. return err
  53. }
  54. func getParentPackage(dir string) (string, error) {
  55. abs, err := filepath.Abs(dir)
  56. if err != nil {
  57. return "", err
  58. }
  59. projectCtx, err := ctx.Prepare(abs)
  60. if err != nil {
  61. return "", err
  62. }
  63. return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
  64. }
  65. func writeProperty(writer io.Writer, name, tag, comment string, tp spec.Type, indent int) error {
  66. util.WriteIndent(writer, indent)
  67. var err error
  68. if len(comment) > 0 {
  69. comment = strings.TrimPrefix(comment, "//")
  70. comment = "//" + comment
  71. _, err = fmt.Fprintf(writer, "%s %s %s %s\n", strings.Title(name), tp.Name(), tag, comment)
  72. } else {
  73. _, err = fmt.Fprintf(writer, "%s %s %s\n", strings.Title(name), tp.Name(), tag)
  74. }
  75. return err
  76. }
  77. func getAuths(api *spec.ApiSpec) []string {
  78. authNames := collection.NewSet()
  79. for _, g := range api.Service.Groups {
  80. jwt := g.GetAnnotation("jwt")
  81. if len(jwt) > 0 {
  82. authNames.Add(jwt)
  83. }
  84. signature := g.GetAnnotation("signature")
  85. if len(signature) > 0 {
  86. authNames.Add(signature)
  87. }
  88. }
  89. return authNames.KeysStr()
  90. }
  91. func getMiddleware(api *spec.ApiSpec) []string {
  92. result := collection.NewSet()
  93. for _, g := range api.Service.Groups {
  94. middleware := g.GetAnnotation("middleware")
  95. if len(middleware) > 0 {
  96. for _, item := range strings.Split(middleware, ",") {
  97. result.Add(strings.TrimSpace(item))
  98. }
  99. }
  100. }
  101. return result.KeysStr()
  102. }
  103. func formatCode(code string) string {
  104. ret, err := goformat.Source([]byte(code))
  105. if err != nil {
  106. return code
  107. }
  108. return string(ret)
  109. }
  110. func responseGoTypeName(r spec.Route, pkg ...string) string {
  111. if r.ResponseType == nil {
  112. return ""
  113. }
  114. resp := golangExpr(r.ResponseType, pkg...)
  115. switch r.ResponseType.(type) {
  116. case spec.DefineStruct:
  117. if !strings.HasPrefix(resp, "*") {
  118. return "*" + resp
  119. }
  120. }
  121. return resp
  122. }
  123. func requestGoTypeName(r spec.Route, pkg ...string) string {
  124. if r.RequestType == nil {
  125. return ""
  126. }
  127. return golangExpr(r.RequestType, pkg...)
  128. }
  129. func golangExpr(ty spec.Type, pkg ...string) string {
  130. switch v := ty.(type) {
  131. case spec.PrimitiveType:
  132. return v.RawName
  133. case spec.DefineStruct:
  134. if len(pkg) > 1 {
  135. panic("package cannot be more than 1")
  136. }
  137. if len(pkg) == 0 {
  138. return v.RawName
  139. }
  140. return fmt.Sprintf("%s.%s", pkg[0], strings.Title(v.RawName))
  141. case spec.ArrayType:
  142. if len(pkg) > 1 {
  143. panic("package cannot be more than 1")
  144. }
  145. if len(pkg) == 0 {
  146. return v.RawName
  147. }
  148. return fmt.Sprintf("[]%s", golangExpr(v.Value, pkg...))
  149. case spec.MapType:
  150. if len(pkg) > 1 {
  151. panic("package cannot be more than 1")
  152. }
  153. if len(pkg) == 0 {
  154. return v.RawName
  155. }
  156. return fmt.Sprintf("map[%s]%s", v.Key, golangExpr(v.Value, pkg...))
  157. case spec.PointerType:
  158. if len(pkg) > 1 {
  159. panic("package cannot be more than 1")
  160. }
  161. if len(pkg) == 0 {
  162. return v.RawName
  163. }
  164. return fmt.Sprintf("*%s", golangExpr(v.Type, pkg...))
  165. case spec.InterfaceType:
  166. return v.RawName
  167. }
  168. return ""
  169. }