parser.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package parser
  2. import (
  3. "errors"
  4. "fmt"
  5. "go/token"
  6. "os"
  7. "path/filepath"
  8. "strings"
  9. "unicode"
  10. "unicode/utf8"
  11. "github.com/emicklei/proto"
  12. )
  // defaultProtoParser is the built-in .proto parser. It is an empty struct,
// so it carries no state and any value of it can be used directly.
type defaultProtoParser struct{}

// NewDefaultProtoParser returns a new *defaultProtoParser.
func NewDefaultProtoParser() *defaultProtoParser {
	return new(defaultProtoParser)
}
  19. func (p *defaultProtoParser) Parse(src string) (Proto, error) {
  20. var ret Proto
  21. abs, err := filepath.Abs(src)
  22. if err != nil {
  23. return Proto{}, err
  24. }
  25. r, err := os.Open(abs)
  26. if err != nil {
  27. return ret, err
  28. }
  29. defer r.Close()
  30. parser := proto.NewParser(r)
  31. set, err := parser.Parse()
  32. if err != nil {
  33. return ret, err
  34. }
  35. var serviceList []Service
  36. proto.Walk(
  37. set,
  38. proto.WithImport(func(i *proto.Import) {
  39. ret.Import = append(ret.Import, Import{Import: i})
  40. }),
  41. proto.WithMessage(func(message *proto.Message) {
  42. ret.Message = append(ret.Message, Message{Message: message})
  43. }),
  44. proto.WithPackage(func(p *proto.Package) {
  45. ret.Package = Package{Package: p}
  46. }),
  47. proto.WithService(func(service *proto.Service) {
  48. serv := Service{Service: service}
  49. elements := service.Elements
  50. for _, el := range elements {
  51. v, _ := el.(*proto.RPC)
  52. if v == nil {
  53. continue
  54. }
  55. serv.RPC = append(serv.RPC, &RPC{RPC: v})
  56. }
  57. serviceList = append(serviceList, serv)
  58. }),
  59. proto.WithOption(func(option *proto.Option) {
  60. if option.Name == "go_package" {
  61. ret.GoPackage = option.Constant.Source
  62. }
  63. }),
  64. )
  65. if len(serviceList) == 0 {
  66. return ret, errors.New("rpc service not found")
  67. }
  68. if len(serviceList) > 1 {
  69. return ret, errors.New("only one service expected")
  70. }
  71. service := serviceList[0]
  72. name := filepath.Base(abs)
  73. for _, rpc := range service.RPC {
  74. if strings.Contains(rpc.RequestType, ".") {
  75. return ret, fmt.Errorf("line %v:%v, request type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
  76. }
  77. if strings.Contains(rpc.ReturnsType, ".") {
  78. return ret, fmt.Errorf("line %v:%v, returns type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
  79. }
  80. }
  81. if len(ret.GoPackage) == 0 {
  82. ret.GoPackage = ret.Package.Name
  83. }
  84. ret.PbPackage = GoSanitized(filepath.Base(ret.GoPackage))
  85. ret.Src = abs
  86. ret.Name = name
  87. ret.Service = service
  88. return ret, nil
  89. }
  // GoSanitized turns s into a valid Go identifier.
//
// Every rune that is neither a Unicode letter nor a Unicode digit is replaced
// with '_'. If the result is a Go keyword, or does not begin with a letter
// (this includes the empty string), it is prefixed with '_'.
//
// Mirrors google.golang.org/protobuf@v1.25.0/internal/strs/strings.go:71.
func GoSanitized(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 1)
	for _, ch := range s {
		if !unicode.IsLetter(ch) && !unicode.IsDigit(ch) {
			ch = '_'
		}
		b.WriteRune(ch)
	}
	ident := b.String()

	// An identifier is usable as-is only if it starts with a letter and is not
	// reserved. DecodeRuneInString yields RuneError for "", which is not a letter.
	first, _ := utf8.DecodeRuneInString(ident)
	if unicode.IsLetter(first) && !token.Lookup(ident).IsKeyword() {
		return ident
	}
	return "_" + ident
}
  // CamelCase converts a protobuf-style identifier such as "foo_bar_baz" into
// a Go-style name such as "FooBarBaz".
//
// Adapted from github.com/golang/protobuf@v1.4.2/protoc-gen-go/generator/generator.go:2648.
// The input is processed byte by byte, ASCII only:
//   - a leading '_' is replaced by 'X' so the result starts with a capital;
//   - an '_' immediately followed by a lowercase letter is dropped, and that
//     letter is upper-cased as the start of the next word;
//   - digits are copied through unchanged;
//   - any other byte starts a word: a lowercase letter is upper-cased, and the
//     run of lowercase letters following it is copied unchanged.
//
// Any other '_' and any non-ASCII bytes are kept as-is. The empty string maps
// to "".
func CamelCase(s string) string {
	if len(s) == 0 {
		return ""
	}
	out := make([]byte, 0, 32)
	pos := 0
	if s[pos] == '_' {
		// A leading underscore cannot be capitalized; emit 'X' in its place.
		out = append(out, 'X')
		pos++
	}
	for pos < len(s) {
		ch := s[pos]
		next := pos + 1
		switch {
		case ch == '_' && next < len(s) && isASCIILower(s[next]):
			// Drop the separator; the following letter starts a new word.
		case isASCIIDigit(ch):
			out = append(out, ch)
		default:
			if isASCIILower(ch) {
				ch -= 'a' - 'A'
			}
			out = append(out, ch)
			// Copy the rest of the lowercase word verbatim.
			for next < len(s) && isASCIILower(s[next]) {
				out = append(out, s[next])
				next++
			}
		}
		pos = next
	}
	return string(out)
}

// isASCIILower reports whether c is an ASCII lowercase letter ('a'..'z').
func isASCIILower(c byte) bool {
	return c >= 'a' && c <= 'z'
}

// isASCIIDigit reports whether c is an ASCII decimal digit ('0'..'9').
func isASCIIDigit(c byte) bool {
	return c >= '0' && c <= '9'
}