genserver.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/util"
  9. )
const (
	// serverTemplate is the default template for one generated server source
	// file. genHandler fills it with the keys "head", "imports", "types",
	// "server" and "funcs"; the output declares package server, a server
	// type, a New<server>Server constructor that stores the given
	// *svc.ServiceContext, and then the rendered method functions.
	serverTemplate = `{{.head}}
package server
import (
"context"
{{.imports}}
)
type {{.types}}
func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
return &{{.server}}Server{
svcCtx: svcCtx,
}
}
{{.funcs}}
`
	// functionTemplate is the default template for a single server method.
	// genFunctions renders it once per RPC method with the keys "server",
	// "method", "request", "response", "logicName", "hasComment" and
	// "comment". The generated method builds the matching logic object from
	// ctx and s.svcCtx and delegates the call to it. genFunctions also passes
	// a "package" key, which this default text does not reference.
	functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
`
	// typeFmt is a fmt.Sprintf format whose single %s verb receives the
	// title-cased service name; the result is placed after the "type"
	// keyword in serverTemplate to declare the server struct.
	typeFmt = `%sServer struct {
svcCtx *svc.ServiceContext
}`
)
  36. func (g *defaultRpcGenerator) genHandler() error {
  37. serverPath := g.dirM[dirServer]
  38. file := g.ast
  39. logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
  40. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  41. imports := collection.NewSet()
  42. imports.AddStr(logicImport, svcImport)
  43. head := util.GetHead(g.Ctx.ProtoSource)
  44. for _, service := range file.Service {
  45. filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
  46. serverFile := filepath.Join(serverPath, filename)
  47. funcList, importList, err := g.genFunctions(service)
  48. if err != nil {
  49. return err
  50. }
  51. imports.AddStr(importList...)
  52. text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
  53. if err != nil {
  54. return err
  55. }
  56. err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
  57. "head": head,
  58. "types": fmt.Sprintf(typeFmt, service.Name.Title()),
  59. "server": service.Name.Title(),
  60. "imports": strings.Join(imports.KeysStr(), util.NL),
  61. "funcs": strings.Join(funcList, util.NL),
  62. }, serverFile, true)
  63. if err != nil {
  64. return err
  65. }
  66. }
  67. return nil
  68. }
  69. func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) {
  70. file := g.ast
  71. pkg := file.Package
  72. var functionList []string
  73. imports := collection.NewSet()
  74. for _, method := range service.Funcs {
  75. if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg {
  76. imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)))
  77. }
  78. imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
  79. imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
  80. text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
  81. if err != nil {
  82. return nil, nil, err
  83. }
  84. buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
  85. "server": service.Name.Title(),
  86. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  87. "method": method.Name.Title(),
  88. "package": pkg,
  89. "request": method.ParameterIn.StarExpression,
  90. "response": method.ParameterOut.StarExpression,
  91. "hasComment": method.HaveDoc(),
  92. "comment": method.GetDoc(),
  93. })
  94. if err != nil {
  95. return nil, nil, err
  96. }
  97. functionList = append(functionList, buffer.String())
  98. }
  99. return functionList, imports.KeysStr(), nil
  100. }