123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- package parser
- import (
- "errors"
- "fmt"
- "go/token"
- "os"
- "path/filepath"
- "strings"
- "unicode"
- "unicode/utf8"
- "github.com/emicklei/proto"
- )
// defaultProtoParser is a stateless parser for .proto files. It has no
// fields, so its zero value is ready to use.
type defaultProtoParser struct{}
- func NewDefaultProtoParser() *defaultProtoParser {
- return &defaultProtoParser{}
- }
- func (p *defaultProtoParser) Parse(src string) (Proto, error) {
- var ret Proto
- abs, err := filepath.Abs(src)
- if err != nil {
- return Proto{}, err
- }
- r, err := os.Open(abs)
- if err != nil {
- return ret, err
- }
- defer r.Close()
- parser := proto.NewParser(r)
- set, err := parser.Parse()
- if err != nil {
- return ret, err
- }
- var serviceList []Service
- proto.Walk(
- set,
- proto.WithImport(func(i *proto.Import) {
- ret.Import = append(ret.Import, Import{Import: i})
- }),
- proto.WithMessage(func(message *proto.Message) {
- ret.Message = append(ret.Message, Message{Message: message})
- }),
- proto.WithPackage(func(p *proto.Package) {
- ret.Package = Package{Package: p}
- }),
- proto.WithService(func(service *proto.Service) {
- serv := Service{Service: service}
- elements := service.Elements
- for _, el := range elements {
- v, _ := el.(*proto.RPC)
- if v == nil {
- continue
- }
- serv.RPC = append(serv.RPC, &RPC{RPC: v})
- }
- serviceList = append(serviceList, serv)
- }),
- proto.WithOption(func(option *proto.Option) {
- if option.Name == "go_package" {
- ret.GoPackage = option.Constant.Source
- }
- }),
- )
- if len(serviceList) == 0 {
- return ret, errors.New("rpc service not found")
- }
- if len(serviceList) > 1 {
- return ret, errors.New("only one service expected")
- }
- service := serviceList[0]
- name := filepath.Base(abs)
- for _, rpc := range service.RPC {
- if strings.Contains(rpc.RequestType, ".") {
- return ret, fmt.Errorf("line %v:%v, request type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
- }
- if strings.Contains(rpc.ReturnsType, ".") {
- return ret, fmt.Errorf("line %v:%v, returns type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
- }
- }
- if len(ret.GoPackage) == 0 {
- ret.GoPackage = ret.Package.Name
- }
- ret.PbPackage = GoSanitized(filepath.Base(ret.GoPackage))
- ret.Src = abs
- ret.Name = name
- ret.Service = service
- return ret, nil
- }
// GoSanitized turns s into a string usable as a Go identifier.
//
// Every rune that is neither a Unicode letter nor a Unicode decimal digit is
// replaced by '_'. If the result is a Go keyword, or does not begin with a
// letter (this includes the empty string), an extra '_' is prepended.
//
// see google.golang.org/protobuf@v1.25.0/internal/strs/strings.go:71
func GoSanitized(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, c := range s {
		if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
			c = '_'
		}
		b.WriteRune(c)
	}
	ident := b.String()

	// An identifier must start with a letter and must not collide with a
	// keyword; a leading underscore resolves both cases.
	first, _ := utf8.DecodeRuneInString(ident)
	switch {
	case !unicode.IsLetter(first), token.Lookup(ident).IsKeyword():
		return "_" + ident
	default:
		return ident
	}
}
- // copy from github.com/golang/protobuf@v1.4.2/protoc-gen-go/generator/generator.go:2648
- func CamelCase(s string) string {
- if s == "" {
- return ""
- }
- t := make([]byte, 0, 32)
- i := 0
- if s[0] == '_' {
- // Need a capital letter; drop the '_'.
- t = append(t, 'X')
- i++
- }
- // Invariant: if the next letter is lower case, it must be converted
- // to upper case.
- // That is, we process a word at a time, where words are marked by _ or
- // upper case letter. Digits are treated as words.
- for ; i < len(s); i++ {
- c := s[i]
- if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) {
- continue // Skip the underscore in s.
- }
- if isASCIIDigit(c) {
- t = append(t, c)
- continue
- }
- // Assume we have a letter now - if not, it's a bogus identifier.
- // The next word is a sequence of characters that must start upper case.
- if isASCIILower(c) {
- c ^= ' ' // Make it a capital letter.
- }
- t = append(t, c) // Guaranteed not lower case.
- // Accept lower case sequence that follows.
- for i+1 < len(s) && isASCIILower(s[i+1]) {
- i++
- t = append(t, s[i])
- }
- }
- return string(t)
- }
// isASCIILower reports whether c is an ASCII lower-case letter ('a'..'z').
func isASCIILower(c byte) bool {
	// Byte subtraction wraps around, so any c below 'a' becomes a large value
	// and fails the single comparison.
	return c-'a' <= 'z'-'a'
}
// isASCIIDigit reports whether c is an ASCII decimal digit ('0'..'9').
func isASCIIDigit(c byte) bool {
	// Byte subtraction wraps around, so any c below '0' becomes a large value
	// and fails the single comparison.
	return c-'0' <= '9'-'0'
}
|