genserver.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  7. "github.com/tal-tech/go-zero/tools/goctl/util"
  8. )
  9. const (
  10. serverTemplate = `{{.head}}
  11. package server
  12. import (
  13. "context"
  14. {{.imports}}
  15. )
  16. type {{.types}}
  17. func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
  18. return &{{.server}}Server{
  19. svcCtx: svcCtx,
  20. }
  21. }
  22. {{.funcs}}
  23. `
  24. functionTemplate = `
  25. {{if .hasComment}}{{.comment}}{{end}}
  26. func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
  27. l := logic.New{{.logicName}}(ctx,s.svcCtx)
  28. return l.{{.method}}(in)
  29. }
  30. `
  31. typeFmt = `%sServer struct {
  32. svcCtx *svc.ServiceContext
  33. }`
  34. )
  35. func (g *defaultRpcGenerator) genHandler() error {
  36. serverPath := g.dirM[dirServer]
  37. file := g.ast
  38. pkg := file.Package
  39. pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
  40. logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
  41. svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
  42. imports := []string{
  43. pbImport,
  44. logicImport,
  45. svcImport,
  46. }
  47. head := util.GetHead(g.Ctx.ProtoSource)
  48. for _, service := range file.Service {
  49. filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
  50. serverFile := filepath.Join(serverPath, filename)
  51. funcList, err := g.genFunctions(service)
  52. if err != nil {
  53. return err
  54. }
  55. err = util.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{
  56. "head": head,
  57. "types": fmt.Sprintf(typeFmt, service.Name.Title()),
  58. "server": service.Name.Title(),
  59. "imports": strings.Join(imports, "\n\t"),
  60. "funcs": strings.Join(funcList, "\n"),
  61. }, serverFile, true)
  62. if err != nil {
  63. return err
  64. }
  65. }
  66. return nil
  67. }
  68. func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, error) {
  69. file := g.ast
  70. pkg := file.Package
  71. var functionList []string
  72. for _, method := range service.Funcs {
  73. buffer, err := util.With("func").Parse(functionTemplate).Execute(map[string]interface{}{
  74. "server": service.Name.Title(),
  75. "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
  76. "method": method.Name.Title(),
  77. "package": pkg,
  78. "request": method.InType,
  79. "response": method.OutType,
  80. "hasComment": len(method.Document),
  81. "comment": strings.Join(method.Document, "\n"),
  82. })
  83. if err != nil {
  84. return nil, err
  85. }
  86. functionList = append(functionList, buffer.String())
  87. }
  88. return functionList, nil
  89. }