genpb.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package gen
  2. import (
  3. "errors"
  4. "fmt"
  5. "io/ioutil"
  6. "path/filepath"
  7. "strings"
  8. "github.com/dsymonds/gotoc/parser"
  9. "github.com/tal-tech/go-zero/core/lang"
  10. "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
  11. astParser "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
  12. "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
  13. )
  14. func (g *defaultRpcGenerator) genPb() error {
  15. importPath, filename := filepath.Split(g.Ctx.ProtoFileSrc)
  16. tree, err := parser.ParseFiles([]string{filename}, []string{importPath})
  17. if err != nil {
  18. return err
  19. }
  20. if len(tree.Files) == 0 {
  21. return errors.New("proto ast parse failed")
  22. }
  23. file := tree.Files[0]
  24. if len(file.Package) == 0 {
  25. return errors.New("expected package, but nothing found")
  26. }
  27. targetStruct := make(map[string]lang.PlaceholderType)
  28. for _, item := range file.Messages {
  29. if len(item.Messages) > 0 {
  30. return fmt.Errorf(`line %v: unexpected inner message near: "%v""`, item.Messages[0].Position.Line, item.Messages[0].Name)
  31. }
  32. name := stringx.From(item.Name)
  33. if _, ok := targetStruct[name.Lower()]; ok {
  34. return fmt.Errorf("line %v: duplicate %v", item.Position.Line, name)
  35. }
  36. targetStruct[name.Lower()] = lang.Placeholder
  37. }
  38. pbPath := g.dirM[dirPb]
  39. protoFileName := filepath.Base(g.Ctx.ProtoFileSrc)
  40. err = g.protocGenGo(pbPath)
  41. if err != nil {
  42. return err
  43. }
  44. pbGo := strings.TrimSuffix(protoFileName, ".proto") + ".pb.go"
  45. pbFile := filepath.Join(pbPath, pbGo)
  46. bts, err := ioutil.ReadFile(pbFile)
  47. if err != nil {
  48. return err
  49. }
  50. aspParser := astParser.NewAstParser(bts, targetStruct, g.Ctx.Console)
  51. ast, err := aspParser.Parse()
  52. if err != nil {
  53. return err
  54. }
  55. if len(ast.Service) == 0 {
  56. return fmt.Errorf("service not found")
  57. }
  58. g.ast = ast
  59. return nil
  60. }
  61. func (g *defaultRpcGenerator) protocGenGo(target string) error {
  62. src := filepath.Dir(g.Ctx.ProtoFileSrc)
  63. sh := fmt.Sprintf(`protoc -I=%s --go_out=plugins=grpc:%s %s`, src, target, g.Ctx.ProtoFileSrc)
  64. stdout, err := execx.Run(sh, "")
  65. if err != nil {
  66. return err
  67. }
  68. if len(stdout) > 0 {
  69. g.Ctx.Info(stdout)
  70. }
  71. return nil
  72. }