pbast.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. package parser
  2. import (
  3. "errors"
  4. "fmt"
  5. "go/ast"
  6. "go/parser"
  7. "go/token"
  8. "io/ioutil"
  9. "sort"
  10. "strings"
  11. "github.com/tal-tech/go-zero/core/lang"
  12. sx "github.com/tal-tech/go-zero/core/stringx"
  13. "github.com/tal-tech/go-zero/tools/goctl/util"
  14. "github.com/tal-tech/go-zero/tools/goctl/util/console"
  15. "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
  16. )
  17. const (
  18. flagStar = "*"
  19. flagDot = "."
  20. suffixServer = "Server"
  21. referenceContext = "context"
  22. unknownPrefix = "XXX_"
  23. ignoreJsonTagExpression = `json:"-"`
  24. )
  25. var (
  26. errorParseError = errors.New("pb parse error")
  27. typeTemplate = `type (
  28. {{.types}}
  29. )`
  30. structTemplate = `{{if .type}}type {{end}}{{.name}} struct {
  31. {{.fields}}
  32. }`
  33. fieldTemplate = `{{if .hasDoc}}{{.doc}}
  34. {{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}`
  35. anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}"
  36. objectM = make(map[string]*Struct)
  37. )
  38. type (
  39. astParser struct {
  40. filterStruct map[string]lang.PlaceholderType
  41. filterEnum map[string]*Enum
  42. console.Console
  43. fileSet *token.FileSet
  44. proto *Proto
  45. }
  46. Field struct {
  47. Name stringx.String
  48. Type Type
  49. JsonTag string
  50. Document []string
  51. Comment []string
  52. }
  53. Struct struct {
  54. Name stringx.String
  55. Document []string
  56. Comment []string
  57. Field []*Field
  58. }
  59. ConstLit struct {
  60. Name stringx.String
  61. Document []string
  62. Comment []string
  63. Lit []*Lit
  64. }
  65. Lit struct {
  66. Key string
  67. Value int
  68. }
  69. Type struct {
  70. // eg:context.Context
  71. Expression string
  72. // eg: *context.Context
  73. StarExpression string
  74. // Invoke Type Expression
  75. InvokeTypeExpression string
  76. // eg:context
  77. Package string
  78. // eg:Context
  79. Name string
  80. }
  81. Func struct {
  82. Name stringx.String
  83. ParameterIn Type
  84. ParameterOut Type
  85. Document []string
  86. }
  87. RpcService struct {
  88. Name stringx.String
  89. Funcs []*Func
  90. }
  91. // parsing for rpc
  92. PbAst struct {
  93. ContainsAny bool
  94. Imports map[string]string
  95. Structure map[string]*Struct
  96. Service []*RpcService
  97. *Proto
  98. }
  99. )
  100. func MustNewAstParser(proto *Proto, log console.Console) *astParser {
  101. return &astParser{
  102. filterStruct: proto.Message,
  103. filterEnum: proto.Enum,
  104. Console: log,
  105. fileSet: token.NewFileSet(),
  106. proto: proto,
  107. }
  108. }
  109. func (a *astParser) Parse() (*PbAst, error) {
  110. var pbAst PbAst
  111. pbAst.ContainsAny = a.proto.ContainsAny
  112. pbAst.Proto = a.proto
  113. pbAst.Structure = make(map[string]*Struct)
  114. pbAst.Imports = make(map[string]string)
  115. structure, imports, services, err := a.parse(a.proto.PbSrc)
  116. if err != nil {
  117. return nil, err
  118. }
  119. dependencyStructure, err := a.parseExternalDependency()
  120. if err != nil {
  121. return nil, err
  122. }
  123. for k, v := range structure {
  124. pbAst.Structure[k] = v
  125. }
  126. for k, v := range dependencyStructure {
  127. pbAst.Structure[k] = v
  128. }
  129. for key, path := range imports {
  130. pbAst.Imports[key] = path
  131. }
  132. pbAst.Service = append(pbAst.Service, services...)
  133. return &pbAst, nil
  134. }
  135. func (a *astParser) parse(pbSrc string) (structure map[string]*Struct, imports map[string]string, services []*RpcService, retErr error) {
  136. structure = make(map[string]*Struct)
  137. imports = make(map[string]string)
  138. data, err := ioutil.ReadFile(pbSrc)
  139. if err != nil {
  140. retErr = err
  141. return
  142. }
  143. fSet := a.fileSet
  144. f, err := parser.ParseFile(fSet, "", data, parser.ParseComments)
  145. if err != nil {
  146. retErr = err
  147. return
  148. }
  149. commentMap := ast.NewCommentMap(fSet, f, f.Comments)
  150. f.Comments = commentMap.Filter(f).Comments()
  151. strucs, function := a.mustScope(f.Scope, a.mustGetIndentName(f.Name))
  152. for k, v := range strucs {
  153. if v == nil {
  154. continue
  155. }
  156. structure[k] = v
  157. }
  158. importList := f.Imports
  159. for _, item := range importList {
  160. name := a.mustGetIndentName(item.Name)
  161. if item.Path != nil {
  162. imports[name] = item.Path.Value
  163. }
  164. }
  165. services = append(services, function...)
  166. return
  167. }
  168. func (a *astParser) parseExternalDependency() (map[string]*Struct, error) {
  169. m := make(map[string]*Struct)
  170. for _, impo := range a.proto.Import {
  171. ret, _, _, err := a.parse(impo.OriginalPbPath)
  172. if err != nil {
  173. return nil, err
  174. }
  175. for k, v := range ret {
  176. m[k] = v
  177. }
  178. }
  179. return m, nil
  180. }
  181. func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) {
  182. if scope == nil {
  183. return nil, nil
  184. }
  185. objects := scope.Objects
  186. structs := make(map[string]*Struct)
  187. serviceList := make([]*RpcService, 0)
  188. for name, obj := range objects {
  189. decl := obj.Decl
  190. if decl == nil {
  191. continue
  192. }
  193. typeSpec, ok := decl.(*ast.TypeSpec)
  194. if !ok {
  195. continue
  196. }
  197. tp := typeSpec.Type
  198. switch v := tp.(type) {
  199. case *ast.StructType:
  200. st, err := a.parseObject(name, v, sourcePackage)
  201. a.Must(err)
  202. structs[st.Name.Lower()] = st
  203. case *ast.InterfaceType:
  204. if !strings.HasSuffix(name, suffixServer) {
  205. continue
  206. }
  207. list := a.mustServerFunctions(v, sourcePackage)
  208. serviceList = append(serviceList, &RpcService{
  209. Name: stringx.From(strings.TrimSuffix(name, suffixServer)),
  210. Funcs: list,
  211. })
  212. }
  213. }
  214. targetStruct := make(map[string]*Struct)
  215. for st := range a.filterStruct {
  216. lower := strings.ToLower(st)
  217. targetStruct[lower] = structs[lower]
  218. }
  219. return targetStruct, serviceList
  220. }
  221. func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func {
  222. funcs := make([]*Func, 0)
  223. methodObject := v.Methods
  224. if methodObject == nil {
  225. return nil
  226. }
  227. for _, method := range methodObject.List {
  228. var item Func
  229. name := a.mustGetIndentName(method.Names[0])
  230. doc := a.parseCommentOrDoc(method.Doc)
  231. item.Name = stringx.From(name)
  232. item.Document = doc
  233. types := method.Type
  234. if types == nil {
  235. funcs = append(funcs, &item)
  236. continue
  237. }
  238. v, ok := types.(*ast.FuncType)
  239. if !ok {
  240. continue
  241. }
  242. params := v.Params
  243. if params != nil {
  244. inList, err := a.parseFields(params.List, true, sourcePackage)
  245. a.Must(err)
  246. for _, data := range inList {
  247. if data.Type.Package == referenceContext {
  248. continue
  249. }
  250. item.ParameterIn = data.Type
  251. break
  252. }
  253. }
  254. results := v.Results
  255. if results != nil {
  256. outList, err := a.parseFields(results.List, true, sourcePackage)
  257. a.Must(err)
  258. for _, data := range outList {
  259. if data.Type.Package == referenceContext {
  260. continue
  261. }
  262. item.ParameterOut = data.Type
  263. break
  264. }
  265. }
  266. funcs = append(funcs, &item)
  267. }
  268. return funcs
  269. }
  270. func (a *astParser) getFieldType(v string, sourcePackage string) Type {
  271. var pkg, name, expression, starExpression, invokeTypeExpression string
  272. if strings.Contains(v, ".") {
  273. starExpression = v
  274. if strings.Contains(v, "*") {
  275. leftIndex := strings.Index(v, "*")
  276. rightIndex := strings.Index(v, ".")
  277. if leftIndex >= 0 {
  278. invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
  279. } else {
  280. invokeTypeExpression = v[rightIndex+1:]
  281. }
  282. } else {
  283. if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
  284. leftIndex := strings.Index(v, "]")
  285. rightIndex := strings.Index(v, ".")
  286. invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
  287. } else {
  288. rightIndex := strings.Index(v, ".")
  289. invokeTypeExpression = v[rightIndex+1:]
  290. }
  291. }
  292. } else {
  293. expression = strings.TrimPrefix(v, flagStar)
  294. switch v {
  295. case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64",
  296. "bool", "string", "bytes":
  297. invokeTypeExpression = v
  298. break
  299. default:
  300. name = expression
  301. invokeTypeExpression = v
  302. if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
  303. starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".")
  304. } else {
  305. starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name)
  306. invokeTypeExpression = v
  307. }
  308. }
  309. }
  310. expression = strings.TrimPrefix(starExpression, flagStar)
  311. index := strings.LastIndex(expression, flagDot)
  312. if index > 0 {
  313. pkg = expression[0:index]
  314. name = expression[index+1:]
  315. } else {
  316. pkg = sourcePackage
  317. }
  318. return Type{
  319. Expression: expression,
  320. StarExpression: starExpression,
  321. InvokeTypeExpression: invokeTypeExpression,
  322. Package: pkg,
  323. Name: name,
  324. }
  325. }
  326. func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) {
  327. if data, ok := objectM[structName]; ok {
  328. return data, nil
  329. }
  330. var st Struct
  331. st.Name = stringx.From(structName)
  332. if tp == nil {
  333. return &st, nil
  334. }
  335. fields := tp.Fields
  336. if fields == nil {
  337. objectM[structName] = &st
  338. return &st, nil
  339. }
  340. fieldList := fields.List
  341. members, err := a.parseFields(fieldList, false, sourcePackage)
  342. if err != nil {
  343. return nil, err
  344. }
  345. for _, m := range members {
  346. var field Field
  347. field.Name = m.Name
  348. field.Type = m.Type
  349. field.JsonTag = m.JsonTag
  350. field.Document = m.Document
  351. field.Comment = m.Comment
  352. st.Field = append(st.Field, &field)
  353. }
  354. objectM[structName] = &st
  355. return &st, nil
  356. }
  357. func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) {
  358. ret := make([]*Field, 0)
  359. for _, field := range fields {
  360. var item Field
  361. tag := a.parseTag(field.Tag)
  362. if tag == "" && !onlyType {
  363. continue
  364. }
  365. if tag == ignoreJsonTagExpression {
  366. continue
  367. }
  368. item.JsonTag = tag
  369. name := a.parseName(field.Names)
  370. if strings.HasPrefix(name, unknownPrefix) {
  371. continue
  372. }
  373. item.Name = stringx.From(name)
  374. typeName, err := a.parseType(field.Type)
  375. if err != nil {
  376. return nil, err
  377. }
  378. item.Type = a.getFieldType(typeName, sourcePackage)
  379. if onlyType {
  380. ret = append(ret, &item)
  381. continue
  382. }
  383. docs := a.parseCommentOrDoc(field.Doc)
  384. comments := a.parseCommentOrDoc(field.Comment)
  385. item.Document = docs
  386. item.Comment = comments
  387. isInline := name == ""
  388. if isInline {
  389. return nil, a.wrapError(field.Pos(), "unexpected inline type:%s", name)
  390. }
  391. ret = append(ret, &item)
  392. }
  393. return ret, nil
  394. }
  395. func (a *astParser) parseTag(basicLit *ast.BasicLit) string {
  396. if basicLit == nil {
  397. return ""
  398. }
  399. value := basicLit.Value
  400. splits := strings.Split(value, " ")
  401. if len(splits) == 1 {
  402. return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[0], "`", ""))
  403. } else {
  404. return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[1], "`", ""))
  405. }
  406. }
  407. // returns
  408. // resp1:type's string expression,like int、string、[]int64、map[string]User、*User
  409. // resp2:error
  410. func (a *astParser) parseType(expr ast.Expr) (string, error) {
  411. if expr == nil {
  412. return "", errorParseError
  413. }
  414. switch v := expr.(type) {
  415. case *ast.StarExpr:
  416. stringExpr, err := a.parseType(v.X)
  417. if err != nil {
  418. return "", err
  419. }
  420. e := fmt.Sprintf("*%s", stringExpr)
  421. return e, nil
  422. case *ast.Ident:
  423. return a.mustGetIndentName(v), nil
  424. case *ast.MapType:
  425. keyStringExpr, err := a.parseType(v.Key)
  426. if err != nil {
  427. return "", err
  428. }
  429. valueStringExpr, err := a.parseType(v.Value)
  430. if err != nil {
  431. return "", err
  432. }
  433. e := fmt.Sprintf("map[%s]%s", keyStringExpr, valueStringExpr)
  434. return e, nil
  435. case *ast.ArrayType:
  436. stringExpr, err := a.parseType(v.Elt)
  437. if err != nil {
  438. return "", err
  439. }
  440. e := fmt.Sprintf("[]%s", stringExpr)
  441. return e, nil
  442. case *ast.InterfaceType:
  443. return "interface{}", nil
  444. case *ast.SelectorExpr:
  445. join := make([]string, 0)
  446. xIdent, ok := v.X.(*ast.Ident)
  447. xIndentName := a.mustGetIndentName(xIdent)
  448. if ok {
  449. join = append(join, xIndentName)
  450. }
  451. sel := v.Sel
  452. join = append(join, a.mustGetIndentName(sel))
  453. return strings.Join(join, "."), nil
  454. case *ast.ChanType:
  455. return "", a.wrapError(v.Pos(), "unexpected type 'chan'")
  456. case *ast.FuncType:
  457. return "", a.wrapError(v.Pos(), "unexpected type 'func'")
  458. case *ast.StructType:
  459. return "", a.wrapError(v.Pos(), "unexpected inline struct type")
  460. default:
  461. return "", a.wrapError(v.Pos(), "unexpected type '%v'", v)
  462. }
  463. }
  464. func (a *astParser) parseName(names []*ast.Ident) string {
  465. if len(names) == 0 {
  466. return ""
  467. }
  468. name := names[0]
  469. return a.mustGetIndentName(name)
  470. }
  471. func (a *astParser) parseCommentOrDoc(cg *ast.CommentGroup) []string {
  472. if cg == nil {
  473. return nil
  474. }
  475. comments := make([]string, 0)
  476. for _, comment := range cg.List {
  477. if comment == nil {
  478. continue
  479. }
  480. text := strings.TrimSpace(comment.Text)
  481. if text == "" {
  482. continue
  483. }
  484. comments = append(comments, text)
  485. }
  486. return comments
  487. }
  488. func (a *astParser) mustGetIndentName(ident *ast.Ident) string {
  489. if ident == nil {
  490. return ""
  491. }
  492. return ident.Name
  493. }
  494. func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{}) error {
  495. file := a.fileSet.Position(pos)
  496. return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...))
  497. }
  498. func (f *Func) GetDoc() string {
  499. return strings.Join(f.Document, util.NL)
  500. }
  501. func (f *Func) HaveDoc() bool {
  502. return len(f.Document) > 0
  503. }
  504. func (a *PbAst) GenEnumCode() (string, error) {
  505. var element []string
  506. for _, item := range a.Enum {
  507. code, err := item.GenEnumCode()
  508. if err != nil {
  509. return "", err
  510. }
  511. element = append(element, code)
  512. }
  513. return strings.Join(element, util.NL), nil
  514. }
  515. func (a *PbAst) GenTypesCode() (string, error) {
  516. types := make([]string, 0)
  517. sts := make([]*Struct, 0)
  518. for _, item := range a.Structure {
  519. sts = append(sts, item)
  520. }
  521. sort.Slice(sts, func(i, j int) bool {
  522. return sts[i].Name.Source() < sts[j].Name.Source()
  523. })
  524. for _, s := range sts {
  525. structCode, err := s.genCode(false)
  526. if err != nil {
  527. return "", err
  528. }
  529. if structCode == "" {
  530. continue
  531. }
  532. types = append(types, structCode)
  533. }
  534. types = append(types, a.genAnyCode())
  535. for _, item := range a.Enum {
  536. typeCode, err := item.GenEnumTypeCode()
  537. if err != nil {
  538. return "", err
  539. }
  540. types = append(types, typeCode)
  541. }
  542. buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
  543. "types": strings.Join(types, util.NL+util.NL),
  544. })
  545. if err != nil {
  546. return "", err
  547. }
  548. return buffer.String(), nil
  549. }
  550. func (a *PbAst) genAnyCode() string {
  551. if !a.ContainsAny {
  552. return ""
  553. }
  554. return anyTypeTemplate
  555. }
  556. func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
  557. fields := make([]string, 0)
  558. for _, f := range s.Field {
  559. var comment, doc string
  560. if len(f.Comment) > 0 {
  561. comment = f.Comment[0]
  562. }
  563. doc = strings.Join(f.Document, util.NL)
  564. buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
  565. "name": f.Name.Title(),
  566. "type": f.Type.InvokeTypeExpression,
  567. "tag": f.JsonTag,
  568. "hasDoc": len(f.Document) > 0,
  569. "doc": doc,
  570. "hasComment": len(f.Comment) > 0,
  571. "comment": comment,
  572. })
  573. if err != nil {
  574. return "", err
  575. }
  576. fields = append(fields, buffer.String())
  577. }
  578. buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
  579. "type": containsTypeStatement,
  580. "name": s.Name.Title(),
  581. "fields": strings.Join(fields, util.NL),
  582. })
  583. if err != nil {
  584. return "", err
  585. }
  586. return buffer.String(), nil
  587. }