gencall.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package gen
  2. import (
  3. "fmt"
  4. "path/filepath"
  5. "strings"
  6. "github.com/tal-tech/go-zero/core/collection"
  7. "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  8. "github.com/tal-tech/go-zero/tools/goctl/templatex"
  9. "github.com/tal-tech/go-zero/tools/goctl/util"
  10. )
  11. const (
  12. typesFilename = "types.go"
  13. callTemplateText = `{{.head}}
  14. //go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
  15. package {{.filePackage}}
  16. import (
  17. "context"
  18. {{.package}}
  19. "github.com/tal-tech/go-zero/core/jsonx"
  20. "github.com/tal-tech/go-zero/zrpc"
  21. )
  22. type (
  23. {{.serviceName}} interface {
  24. {{.interface}}
  25. }
  26. default{{.serviceName}} struct {
  27. cli zrpc.Client
  28. }
  29. )
  30. func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
  31. return &default{{.serviceName}}{
  32. cli: cli,
  33. }
  34. }
  35. {{.functions}}
  36. `
  37. callTemplateTypes = `{{.head}}
  38. package {{.filePackage}}
  39. import "errors"
  40. var errJsonConvert = errors.New("json convert error")
  41. {{.const}}
  42. {{.types}}
  43. `
  44. callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
  45. {{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
  46. callFunctionTemplate = `
  47. {{if .hasComment}}{{.comment}}{{end}}
  48. func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
  49. var request {{.pbRequest}}
  50. bts, err := jsonx.Marshal(in)
  51. if err != nil {
  52. return nil, errJsonConvert
  53. }
  54. err = jsonx.Unmarshal(bts, &request)
  55. if err != nil {
  56. return nil, errJsonConvert
  57. }
  58. client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
  59. resp, err := client.{{.method}}(ctx, &request)
  60. if err != nil{
  61. return nil, err
  62. }
  63. var ret {{.pbResponse}}
  64. bts, err = jsonx.Marshal(resp)
  65. if err != nil{
  66. return nil, errJsonConvert
  67. }
  68. err = jsonx.Unmarshal(bts, &ret)
  69. if err != nil{
  70. return nil, errJsonConvert
  71. }
  72. return &ret, nil
  73. }
  74. `
  75. )
  76. func (g *defaultRpcGenerator) genCall() error {
  77. file := g.ast
  78. if len(file.Service) == 0 {
  79. return nil
  80. }
  81. if len(file.Service) > 1 {
  82. return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service))
  83. }
  84. typeCode, err := file.GenTypesCode()
  85. if err != nil {
  86. return err
  87. }
  88. constLit, err := file.GenEnumCode()
  89. if err != nil {
  90. return err
  91. }
  92. service := file.Service[0]
  93. callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
  94. if err = util.MkdirIfNotExist(callPath); err != nil {
  95. return err
  96. }
  97. filename := filepath.Join(callPath, typesFilename)
  98. head := templatex.GetHead(g.Ctx.ProtoSource)
  99. err = templatex.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{
  100. "head": head,
  101. "const": constLit,
  102. "filePackage": service.Name.Lower(),
  103. "serviceName": g.Ctx.ServiceName.Title(),
  104. "lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
  105. "types": typeCode,
  106. }, filename, true)
  107. if err != nil {
  108. return err
  109. }
  110. filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
  111. functions, importList, err := g.genFunction(service)
  112. if err != nil {
  113. return err
  114. }
  115. iFunctions, err := g.getInterfaceFuncs(service)
  116. if err != nil {
  117. return err
  118. }
  119. err = templatex.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{
  120. "name": service.Name.Lower(),
  121. "head": head,
  122. "filePackage": service.Name.Lower(),
  123. "package": strings.Join(importList, util.NL),
  124. "serviceName": service.Name.Title(),
  125. "functions": strings.Join(functions, util.NL),
  126. "interface": strings.Join(iFunctions, util.NL),
  127. }, filename, true)
  128. return err
  129. }
  130. func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
  131. file := g.ast
  132. pkgName := file.Package
  133. functions := make([]string, 0)
  134. imports := collection.NewSet()
  135. imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
  136. for _, method := range service.Funcs {
  137. imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
  138. buffer, err := templatex.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{
  139. "rpcServiceName": service.Name.Title(),
  140. "method": method.Name.Title(),
  141. "package": pkgName,
  142. "pbRequestName": method.ParameterIn.Name,
  143. "pbRequest": method.ParameterIn.Expression,
  144. "pbResponse": method.ParameterOut.Name,
  145. "hasComment": method.HaveDoc(),
  146. "comment": method.GetDoc(),
  147. })
  148. if err != nil {
  149. return nil, nil, err
  150. }
  151. functions = append(functions, buffer.String())
  152. }
  153. return functions, imports.KeysStr(), nil
  154. }
  155. func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
  156. functions := make([]string, 0)
  157. for _, method := range service.Funcs {
  158. buffer, err := templatex.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute(
  159. map[string]interface{}{
  160. "hasComment": method.HaveDoc(),
  161. "comment": method.GetDoc(),
  162. "method": method.Name.Title(),
  163. "pbRequest": method.ParameterIn.Name,
  164. "pbResponse": method.ParameterOut.Name,
  165. })
  166. if err != nil {
  167. return nil, err
  168. }
  169. functions = append(functions, buffer.String())
  170. }
  171. return functions, nil
  172. }