|
|
@@ -20,11 +20,14 @@ import (
|
|
|
"io/ioutil"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
+ "sort"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
|
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
|
|
|
pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin"
|
|
|
+ "golang.org/x/tools/go/ast/astutil"
|
|
|
)
|
|
|
|
|
|
// Run executes a function as a protoc plugin.
|
|
|
@@ -168,7 +171,7 @@ func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse {
|
|
|
}
|
|
|
}
|
|
|
resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
|
|
|
- Name: proto.String(gf.path),
|
|
|
+ Name: proto.String(gf.filename),
|
|
|
Content: proto.String(string(content)),
|
|
|
})
|
|
|
}
|
|
|
@@ -185,16 +188,17 @@ func (gen *Plugin) FileByName(name string) (f *File, ok bool) {
|
|
|
type File struct {
|
|
|
Desc *descpb.FileDescriptorProto // TODO: protoreflect.FileDescriptor
|
|
|
|
|
|
- Messages []*Message // top-level message declartions
|
|
|
- Generate bool // true if we should generate code for this file
|
|
|
+ GoImportPath GoImportPath // import path of this file's Go package
|
|
|
+ Messages []*Message // top-level message declarations
|
|
|
+ Generate bool // true if we should generate code for this file
|
|
|
}
|
|
|
|
|
|
func newFile(gen *Plugin, p *descpb.FileDescriptorProto) *File {
|
|
|
f := &File{
|
|
|
Desc: p,
|
|
|
}
|
|
|
- for _, d := range p.MessageType {
|
|
|
- f.Messages = append(f.Messages, newMessage(gen, nil, d))
|
|
|
+ for i, mdesc := range p.MessageType {
|
|
|
+ f.Messages = append(f.Messages, newMessage(gen, f, nil, mdesc, i))
|
|
|
}
|
|
|
return f
|
|
|
}
|
|
|
@@ -207,30 +211,40 @@ type Message struct {
|
|
|
Messages []*Message // nested message declarations
|
|
|
}
|
|
|
|
|
|
-func newMessage(gen *Plugin, parent *Message, p *descpb.DescriptorProto) *Message {
|
|
|
+func newMessage(gen *Plugin, f *File, parent *Message, p *descpb.DescriptorProto, index int) *Message {
|
|
|
m := &Message{
|
|
|
- Desc: p,
|
|
|
- GoIdent: camelCase(p.GetName()),
|
|
|
+ Desc: p,
|
|
|
+ GoIdent: GoIdent{
|
|
|
+ GoName: camelCase(p.GetName()),
|
|
|
+ GoImportPath: f.GoImportPath,
|
|
|
+ },
|
|
|
}
|
|
|
if parent != nil {
|
|
|
- m.GoIdent = parent.GoIdent + "_" + m.GoIdent
|
|
|
+ m.GoIdent.GoName = parent.GoIdent.GoName + "_" + m.GoIdent.GoName
|
|
|
}
|
|
|
- for _, nested := range p.GetNestedType() {
|
|
|
- m.Messages = append(m.Messages, newMessage(gen, m, nested))
|
|
|
+ for i, nested := range p.GetNestedType() {
|
|
|
+ m.Messages = append(m.Messages, newMessage(gen, f, m, nested, i))
|
|
|
}
|
|
|
return m
|
|
|
}
|
|
|
|
|
|
// A GeneratedFile is a generated file.
|
|
|
type GeneratedFile struct {
|
|
|
- path string
|
|
|
- buf bytes.Buffer
|
|
|
+ filename string
|
|
|
+ goImportPath GoImportPath
|
|
|
+ buf bytes.Buffer
|
|
|
+ packageNames map[GoImportPath]GoPackageName
|
|
|
+ usedPackageNames map[GoPackageName]bool
|
|
|
}
|
|
|
|
|
|
-// NewGeneratedFile creates a new generated file with the given path.
|
|
|
-func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile {
|
|
|
+// NewGeneratedFile creates a new generated file with the given filename
|
|
|
+// and import path.
|
|
|
+func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
|
|
|
g := &GeneratedFile{
|
|
|
- path: path,
|
|
|
+ filename: filename,
|
|
|
+ goImportPath: goImportPath,
|
|
|
+ packageNames: make(map[GoImportPath]GoPackageName),
|
|
|
+ usedPackageNames: make(map[GoPackageName]bool),
|
|
|
}
|
|
|
gen.genFiles = append(gen.genFiles, g)
|
|
|
return g
|
|
|
@@ -243,11 +257,33 @@ func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile {
|
|
|
// TODO: .meta file annotations.
|
|
|
func (g *GeneratedFile) P(v ...interface{}) {
|
|
|
for _, x := range v {
|
|
|
- fmt.Fprint(&g.buf, x)
|
|
|
+ switch x := x.(type) {
|
|
|
+ case GoIdent:
|
|
|
+ if x.GoImportPath != g.goImportPath {
|
|
|
+ fmt.Fprint(&g.buf, g.goPackageName(x.GoImportPath))
|
|
|
+ fmt.Fprint(&g.buf, ".")
|
|
|
+ }
|
|
|
+ fmt.Fprint(&g.buf, x.GoName)
|
|
|
+ default:
|
|
|
+ fmt.Fprint(&g.buf, x)
|
|
|
+ }
|
|
|
}
|
|
|
fmt.Fprintln(&g.buf)
|
|
|
}
|
|
|
|
|
|
+func (g *GeneratedFile) goPackageName(importPath GoImportPath) GoPackageName {
|
|
|
+ if name, ok := g.packageNames[importPath]; ok {
|
|
|
+ return name
|
|
|
+ }
|
|
|
+ name := cleanPackageName(baseName(string(importPath)))
|
|
|
+ for i, orig := 1, name; g.usedPackageNames[name]; i++ {
|
|
|
+ name = orig + GoPackageName(strconv.Itoa(i))
|
|
|
+ }
|
|
|
+ g.packageNames[importPath] = name
|
|
|
+ g.usedPackageNames[name] = true
|
|
|
+ return name
|
|
|
+}
|
|
|
+
|
|
|
// Write implements io.Writer.
|
|
|
func (g *GeneratedFile) Write(p []byte) (n int, err error) {
|
|
|
return g.buf.Write(p)
|
|
|
@@ -255,7 +291,7 @@ func (g *GeneratedFile) Write(p []byte) (n int, err error) {
|
|
|
|
|
|
// Content returns the contents of the generated file.
|
|
|
func (g *GeneratedFile) Content() ([]byte, error) {
|
|
|
- if !strings.HasSuffix(g.path, ".go") {
|
|
|
+ if !strings.HasSuffix(g.filename, ".go") {
|
|
|
return g.buf.Bytes(), nil
|
|
|
}
|
|
|
|
|
|
@@ -272,13 +308,24 @@ func (g *GeneratedFile) Content() ([]byte, error) {
|
|
|
for line := 1; s.Scan(); line++ {
|
|
|
fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
|
|
|
}
|
|
|
- return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.path, err, src.String())
|
|
|
+ return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add imports.
|
|
|
+ var importPaths []string
|
|
|
+ for importPath := range g.packageNames {
|
|
|
+ importPaths = append(importPaths, string(importPath))
|
|
|
+ }
|
|
|
+ sort.Strings(importPaths)
|
|
|
+ for _, importPath := range importPaths {
|
|
|
+ astutil.AddNamedImport(fset, ast, string(g.packageNames[GoImportPath(importPath)]), importPath)
|
|
|
}
|
|
|
+
|
|
|
var out bytes.Buffer
|
|
|
if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, ast); err != nil {
|
|
|
- return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.path, err)
|
|
|
+ return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
|
|
|
}
|
|
|
- // TODO: Patch annotation locations.
|
|
|
+ // TODO: Annotations.
|
|
|
return out.Bytes(), nil
|
|
|
|
|
|
}
|