zsh_completions.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package cobra
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "os"
  7. "strings"
  8. )
  9. // GenZshCompletionFile generates zsh completion file.
  10. func (c *Command) GenZshCompletionFile(filename string) error {
  11. outFile, err := os.Create(filename)
  12. if err != nil {
  13. return err
  14. }
  15. defer outFile.Close()
  16. return c.GenZshCompletion(outFile)
  17. }
  18. // GenZshCompletion generates a zsh completion file and writes to the passed writer.
  19. func (c *Command) GenZshCompletion(w io.Writer) error {
  20. buf := new(bytes.Buffer)
  21. writeHeader(buf, c)
  22. maxDepth := maxDepth(c)
  23. writeLevelMapping(buf, maxDepth)
  24. writeLevelCases(buf, maxDepth, c)
  25. _, err := buf.WriteTo(w)
  26. return err
  27. }
  28. func writeHeader(w io.Writer, cmd *Command) {
  29. fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
  30. }
  31. func maxDepth(c *Command) int {
  32. if len(c.Commands()) == 0 {
  33. return 0
  34. }
  35. maxDepthSub := 0
  36. for _, s := range c.Commands() {
  37. subDepth := maxDepth(s)
  38. if subDepth > maxDepthSub {
  39. maxDepthSub = subDepth
  40. }
  41. }
  42. return 1 + maxDepthSub
  43. }
  44. func writeLevelMapping(w io.Writer, numLevels int) {
  45. fmt.Fprintln(w, `_arguments \`)
  46. for i := 1; i <= numLevels; i++ {
  47. fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i)
  48. fmt.Fprintln(w)
  49. }
  50. fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files")
  51. fmt.Fprintln(w)
  52. }
  53. func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
  54. fmt.Fprintln(w, "case $state in")
  55. defer fmt.Fprintln(w, "esac")
  56. for i := 1; i <= maxDepth; i++ {
  57. fmt.Fprintf(w, " level%d)\n", i)
  58. writeLevel(w, root, i)
  59. fmt.Fprintln(w, " ;;")
  60. }
  61. fmt.Fprintln(w, " *)")
  62. fmt.Fprintln(w, " _arguments '*: :_files'")
  63. fmt.Fprintln(w, " ;;")
  64. }
  65. func writeLevel(w io.Writer, root *Command, i int) {
  66. fmt.Fprintf(w, " case $words[%d] in\n", i)
  67. defer fmt.Fprintln(w, " esac")
  68. commands := filterByLevel(root, i)
  69. byParent := groupByParent(commands)
  70. for p, c := range byParent {
  71. names := names(c)
  72. fmt.Fprintf(w, " %s)\n", p)
  73. fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
  74. fmt.Fprintln(w, " ;;")
  75. }
  76. fmt.Fprintln(w, " *)")
  77. fmt.Fprintln(w, " _arguments '*: :_files'")
  78. fmt.Fprintln(w, " ;;")
  79. }
  80. func filterByLevel(c *Command, l int) []*Command {
  81. cs := make([]*Command, 0)
  82. if l == 0 {
  83. cs = append(cs, c)
  84. return cs
  85. }
  86. for _, s := range c.Commands() {
  87. cs = append(cs, filterByLevel(s, l-1)...)
  88. }
  89. return cs
  90. }
  91. func groupByParent(commands []*Command) map[string][]*Command {
  92. m := make(map[string][]*Command)
  93. for _, c := range commands {
  94. parent := c.Parent()
  95. if parent == nil {
  96. continue
  97. }
  98. m[parent.Name()] = append(m[parent.Name()], c)
  99. }
  100. return m
  101. }
  102. func names(commands []*Command) []string {
  103. ns := make([]string, len(commands))
  104. for i, c := range commands {
  105. ns[i] = c.Name()
  106. }
  107. return ns
  108. }