genlogic.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/util"
  9. )
// Text templates used to scaffold one logic file per RPC method.
const (
	// logicTemplate renders a complete logic source file. It is filled by
	// genLogic with:
	//   .imports   - newline-joined import specs: the pb package (aliased to the
	//                proto package name) and the svc package
	//   .logicName - struct name, "<Title-cased method name>Logic"
	//   .functions - the method stub rendered by genLogicFunction
	// The referenced svc.ServiceContext type comes from the svc import supplied
	// through .imports. genLogic renders this template with GoFmt(true), which
	// presumably normalizes the unindented layout below — confirm in util.
	logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
	// logicFunctionTemplate renders the stub method for a single RPC. It is
	// filled by genLogicFunction with:
	//   .hasComment/.comment - whether the method has Document lines, and those
	//                          lines joined by "\n" (emitted verbatim above the func)
	//   .logicName           - receiver struct name, "<Title-cased method>Logic"
	//   .method              - Title-cased method name
	//   .package             - Go package alias for the generated pb types
	//   .request/.response   - the proto input and output message type names
	// The generated body is a placeholder that returns an empty response.
	logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.package}}.{{.response}}{}, nil
}
`
)
  38. func (g *defaultRpcGenerator) genLogic() error {
  39. logicPath := g.dirM[dirLogic]
  40. protoPkg := g.ast.Package
  41. service := g.ast.Service
  42. for _, item := range service {
  43. for _, method := range item.Funcs {
  44. logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
  45. filename := filepath.Join(logicPath, logicName)
  46. functions, err := genLogicFunction(protoPkg, method)
  47. if err != nil {
  48. return err
  49. }
  50. imports := collection.NewSet()
  51. pbImport := fmt.Sprintf(`%v "%v"`, protoPkg, g.mustGetPackage(dirPb))
  52. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  53. imports.AddStr(pbImport, svcImport)
  54. err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
  55. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  56. "functions": functions,
  57. "imports": strings.Join(imports.KeysStr(), "\n"),
  58. }, filename, false)
  59. if err != nil {
  60. return err
  61. }
  62. }
  63. }
  64. return nil
  65. }
  66. func genLogicFunction(packageName string, method *parser.Func) (string, error) {
  67. var functions = make([]string, 0)
  68. buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
  69. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  70. "method": method.Name.Title(),
  71. "package": packageName,
  72. "request": method.InType,
  73. "response": method.OutType,
  74. "hasComment": len(method.Document) > 0,
  75. "comment": strings.Join(method.Document, "\n"),
  76. })
  77. if err != nil {
  78. return "", err
  79. }
  80. functions = append(functions, buffer.String())
  81. return strings.Join(functions, "\n"), nil
  82. }