gencall.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package gen
import (
	"fmt"
	"path/filepath"
	"sort"
	"strings"

	"github.com/tal-tech/go-zero/core/collection"
	"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
	"github.com/tal-tech/go-zero/tools/goctl/util"
)
  10. const (
  11. typesFilename = "types.go"
  12. callTemplateText = `{{.head}}
  13. //go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
  14. package {{.filePackage}}
  15. import (
  16. "context"
  17. {{.package}}
  18. "github.com/tal-tech/go-zero/core/jsonx"
  19. "github.com/tal-tech/go-zero/zrpc"
  20. )
  21. type (
  22. {{.serviceName}} interface {
  23. {{.interface}}
  24. }
  25. default{{.serviceName}} struct {
  26. cli zrpc.Client
  27. }
  28. )
  29. func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
  30. return &default{{.serviceName}}{
  31. cli: cli,
  32. }
  33. }
  34. {{.functions}}
  35. `
  36. callTemplateTypes = `{{.head}}
  37. package {{.filePackage}}
  38. import "errors"
  39. var errJsonConvert = errors.New("json convert error")
  40. {{.const}}
  41. {{.types}}
  42. `
  43. callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
  44. {{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
  45. callFunctionTemplate = `
  46. {{if .hasComment}}{{.comment}}{{end}}
  47. func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
  48. var request {{.pbRequest}}
  49. bts, err := jsonx.Marshal(in)
  50. if err != nil {
  51. return nil, errJsonConvert
  52. }
  53. err = jsonx.Unmarshal(bts, &request)
  54. if err != nil {
  55. return nil, errJsonConvert
  56. }
  57. client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
  58. resp, err := client.{{.method}}(ctx, &request)
  59. if err != nil{
  60. return nil, err
  61. }
  62. var ret {{.pbResponse}}
  63. bts, err = jsonx.Marshal(resp)
  64. if err != nil{
  65. return nil, errJsonConvert
  66. }
  67. err = jsonx.Unmarshal(bts, &ret)
  68. if err != nil{
  69. return nil, errJsonConvert
  70. }
  71. return &ret, nil
  72. }
  73. `
  74. )
  75. func (g *defaultRpcGenerator) genCall() error {
  76. file := g.ast
  77. if len(file.Service) == 0 {
  78. return nil
  79. }
  80. if len(file.Service) > 1 {
  81. return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service))
  82. }
  83. typeCode, err := file.GenTypesCode()
  84. if err != nil {
  85. return err
  86. }
  87. constLit, err := file.GenEnumCode()
  88. if err != nil {
  89. return err
  90. }
  91. service := file.Service[0]
  92. callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
  93. if err = util.MkdirIfNotExist(callPath); err != nil {
  94. return err
  95. }
  96. filename := filepath.Join(callPath, typesFilename)
  97. head := util.GetHead(g.Ctx.ProtoSource)
  98. err = util.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{
  99. "head": head,
  100. "const": constLit,
  101. "filePackage": service.Name.Lower(),
  102. "serviceName": g.Ctx.ServiceName.Title(),
  103. "lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
  104. "types": typeCode,
  105. }, filename, true)
  106. if err != nil {
  107. return err
  108. }
  109. filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
  110. functions, importList, err := g.genFunction(service)
  111. if err != nil {
  112. return err
  113. }
  114. iFunctions, err := g.getInterfaceFuncs(service)
  115. if err != nil {
  116. return err
  117. }
  118. err = util.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{
  119. "name": service.Name.Lower(),
  120. "head": head,
  121. "filePackage": service.Name.Lower(),
  122. "package": strings.Join(importList, util.NL),
  123. "serviceName": service.Name.Title(),
  124. "functions": strings.Join(functions, util.NL),
  125. "interface": strings.Join(iFunctions, util.NL),
  126. }, filename, true)
  127. return err
  128. }
  129. func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
  130. file := g.ast
  131. pkgName := file.Package
  132. functions := make([]string, 0)
  133. imports := collection.NewSet()
  134. imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
  135. for _, method := range service.Funcs {
  136. imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
  137. buffer, err := util.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{
  138. "rpcServiceName": service.Name.Title(),
  139. "method": method.Name.Title(),
  140. "package": pkgName,
  141. "pbRequestName": method.ParameterIn.Name,
  142. "pbRequest": method.ParameterIn.Expression,
  143. "pbResponse": method.ParameterOut.Name,
  144. "hasComment": method.HaveDoc(),
  145. "comment": method.GetDoc(),
  146. })
  147. if err != nil {
  148. return nil, nil, err
  149. }
  150. functions = append(functions, buffer.String())
  151. }
  152. return functions, imports.KeysStr(), nil
  153. }
  154. func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
  155. functions := make([]string, 0)
  156. for _, method := range service.Funcs {
  157. buffer, err := util.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute(
  158. map[string]interface{}{
  159. "hasComment": method.HaveDoc(),
  160. "comment": method.GetDoc(),
  161. "method": method.Name.Title(),
  162. "pbRequest": method.ParameterIn.Name,
  163. "pbResponse": method.ParameterOut.Name,
  164. })
  165. if err != nil {
  166. return nil, err
  167. }
  168. functions = append(functions, buffer.String())
  169. }
  170. return functions, nil
  171. }