bash_completions.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. package cobra
  2. import (
  3. "bytes"
  4. "fmt"
  5. "os"
  6. "sort"
  7. "strings"
  8. "github.com/spf13/pflag"
  9. )
  // Annotation keys understood by the bash completion generator. They are
// attached to individual flags via pflag's SetAnnotation; see the Mark*
// helpers in this package.
const (
	// BashCompFilenameExt marks a flag whose value is a file name. When the
	// annotation carries values, completion is limited to those extensions.
	// NOTE: the key is spelled "extentions"; keep it as-is, because annotation
	// keys are matched by exact string comparison.
	BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extentions"

	// BashCompOneRequiredFlag marks a flag that the completion script should
	// keep suggesting until it appears on the command line.
	BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"

	// BashCompSubdirsInDir marks a flag whose value is a subdirectory of the
	// single directory given as the annotation's value.
	BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
)
  // preamble writes the fixed head of the generated bash completion script:
// the shebang line and the shared shell helpers (__debug, __index_of_word,
// __contains_word, __handle_reply, __handle_flag, __handle_noun,
// __handle_command, __handle_word, ...). The __start_<name> entry point that
// postscript emits drives completion by calling __handle_word.
//
// The script is written verbatim with WriteString rather than through a
// printf-style formatter. It contains the parameter expansion
// ${flagname%%=*}, and a formatter would collapse "%%" to "%". That would
// turn "strip from the first '='" into "strip from the last '='", so a word
// like --set=a=b would no longer be recognized as --set=.
//
// __handle_subdirs_in_dir_flag always runs popd once pushd has succeeded,
// whatever _filedir returns. Completion functions run in the user's
// interactive shell, so skipping popd would leave that shell in another
// directory.
func preamble(out *bytes.Buffer) {
	out.WriteString(`#!/bin/bash
__debug()
{
if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
fi
}
# Homebrew on Macs have version 1.3 of bash-completion which doesn't include
# _init_completion. This is a very minimal version of that function.
__my_init_completion()
{
COMPREPLY=()
_get_comp_words_by_ref cur prev words cword
}
__index_of_word()
{
local w word=$1
shift
index=0
for w in "$@"; do
[[ $w = "$word" ]] && return
index=$((index+1))
done
index=-1
}
__contains_word()
{
local w word=$1; shift
for w in "$@"; do
[[ $w = "$word" ]] && return
done
return 1
}
__handle_reply()
{
__debug "${FUNCNAME}"
case $cur in
-*)
if [[ $(type -t compopt) = "builtin" ]]; then
compopt -o nospace
fi
local allflags
if [ ${#must_have_one_flag[@]} -ne 0 ]; then
allflags=("${must_have_one_flag[@]}")
else
allflags=("${flags[*]} ${two_word_flags[*]}")
fi
COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") )
if [[ $(type -t compopt) = "builtin" ]]; then
[[ $COMPREPLY == *= ]] || compopt +o nospace
fi
return 0;
;;
esac
# check if we are handling a flag with special work handling
local index
__index_of_word "${prev}" "${flags_with_completion[@]}"
if [[ ${index} -ge 0 ]]; then
${flags_completion[${index}]}
return
fi
# we are parsing a flag and don't have a special handler, no completion
if [[ ${cur} != "${words[cword]}" ]]; then
return
fi
local completions
if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
completions=("${must_have_one_flag[@]}")
elif [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
completions=("${must_have_one_noun[@]}")
else
completions=("${commands[@]}")
fi
COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )
if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
declare -F __custom_func >/dev/null && __custom_func
fi
}
# The arguments should be in the form "ext1|ext2|extn"
__handle_filename_extension_flag()
{
local ext="$1"
_filedir "@(${ext})"
}
__handle_subdirs_in_dir_flag()
{
local dir="$1"
pushd "${dir}" >/dev/null 2>&1 || return
_filedir -d
popd >/dev/null 2>&1
}
__handle_flag()
{
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
# if a command required a flag, and we found it, unset must_have_one_flag()
local flagname=${words[c]}
# if the word contained an =
if [[ ${words[c]} == *"="* ]]; then
flagname=${flagname%%=*} # strip everything after the =
flagname="${flagname}=" # but put the = back
fi
__debug "${FUNCNAME}: looking for ${flagname}"
if __contains_word "${flagname}" "${must_have_one_flag[@]}"; then
must_have_one_flag=()
fi
# skip the argument to a two word flag
if __contains_word "${words[c]}" "${two_word_flags[@]}"; then
c=$((c+1))
# if we are looking for a flags value, don't show commands
if [[ $c -eq $cword ]]; then
commands=()
fi
fi
# skip the flag itself
c=$((c+1))
}
__handle_noun()
{
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
if __contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
must_have_one_noun=()
fi
nouns+=("${words[c]}")
c=$((c+1))
}
__handle_command()
{
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
local next_command
if [[ -n ${last_command} ]]; then
next_command="_${last_command}_${words[c]}"
else
next_command="_${words[c]}"
fi
c=$((c+1))
__debug "${FUNCNAME}: looking for ${next_command}"
declare -F $next_command >/dev/null && $next_command
}
__handle_word()
{
if [[ $c -ge $cword ]]; then
__handle_reply
return
fi
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
if [[ "${words[c]}" == -* ]]; then
__handle_flag
elif __contains_word "${words[c]}" "${commands[@]}"; then
__handle_command
else
__handle_noun
fi
__handle_word
}
`)
}
  // postscript writes the tail of the completion script for the root command
// called name. It emits three things, in order:
//   - the __start_<name> function, which initializes cur/prev/words/cword
//     (via _init_completion when available, else __my_init_completion),
//     declares the per-invocation state arrays with commands=(name), and
//     hands off to __handle_word;
//   - the `complete` registration of that function for name, adding
//     -o nospace only when compopt is not a builtin;
//   - a trailing vim modeline.
func postscript(out *bytes.Buffer, name string) {
	const entryPoint = `__start_%[1]s()
{
local cur prev words cword
if declare -F _init_completion >/dev/null 2>&1; then
_init_completion -s || return
else
__my_init_completion || return
fi
local c=0
local flags=()
local two_word_flags=()
local flags_with_completion=()
local flags_completion=()
local commands=("%[1]s")
local must_have_one_flag=()
local must_have_one_noun=()
local last_command
local nouns=()
__handle_word
}
`
	const registration = `if [[ $(type -t compopt) = "builtin" ]]; then
complete -F __start_%[1]s %[1]s
else
complete -o nospace -F __start_%[1]s %[1]s
fi
`
	fmt.Fprintf(out, entryPoint, name)
	fmt.Fprintf(out, registration, name)
	out.WriteString("# ex: ts=4 sw=4 et filetype=sh\n")
}
  200. func writeCommands(cmd *Command, out *bytes.Buffer) {
  201. fmt.Fprintf(out, " commands=()\n")
  202. for _, c := range cmd.Commands() {
  203. if !c.IsAvailableCommand() || c == cmd.helpCommand {
  204. continue
  205. }
  206. fmt.Fprintf(out, " commands+=(%q)\n", c.Name())
  207. }
  208. fmt.Fprintf(out, "\n")
  209. }
  210. func writeFlagHandler(name string, annotations map[string][]string, out *bytes.Buffer) {
  211. for key, value := range annotations {
  212. switch key {
  213. case BashCompFilenameExt:
  214. fmt.Fprintf(out, " flags_with_completion+=(%q)\n", name)
  215. if len(value) > 0 {
  216. ext := "__handle_filename_extension_flag " + strings.Join(value, "|")
  217. fmt.Fprintf(out, " flags_completion+=(%q)\n", ext)
  218. } else {
  219. ext := "_filedir"
  220. fmt.Fprintf(out, " flags_completion+=(%q)\n", ext)
  221. }
  222. case BashCompSubdirsInDir:
  223. fmt.Fprintf(out, " flags_with_completion+=(%q)\n", name)
  224. if len(value) == 1 {
  225. ext := "__handle_subdirs_in_dir_flag " + value[0]
  226. fmt.Fprintf(out, " flags_completion+=(%q)\n", ext)
  227. } else {
  228. ext := "_filedir -d"
  229. fmt.Fprintf(out, " flags_completion+=(%q)\n", ext)
  230. }
  231. }
  232. }
  233. }
  234. func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) {
  235. b := (flag.Value.Type() == "bool")
  236. name := flag.Shorthand
  237. format := " "
  238. if !b {
  239. format += "two_word_"
  240. }
  241. format += "flags+=(\"-%s\")\n"
  242. fmt.Fprintf(out, format, name)
  243. writeFlagHandler("-"+name, flag.Annotations, out)
  244. }
  245. func writeFlag(flag *pflag.Flag, out *bytes.Buffer) {
  246. b := (flag.Value.Type() == "bool")
  247. name := flag.Name
  248. format := " flags+=(\"--%s"
  249. if !b {
  250. format += "="
  251. }
  252. format += "\")\n"
  253. fmt.Fprintf(out, format, name)
  254. writeFlagHandler("--"+name, flag.Annotations, out)
  255. }
  256. func writeFlags(cmd *Command, out *bytes.Buffer) {
  257. fmt.Fprintf(out, ` flags=()
  258. two_word_flags=()
  259. flags_with_completion=()
  260. flags_completion=()
  261. `)
  262. cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
  263. writeFlag(flag, out)
  264. if len(flag.Shorthand) > 0 {
  265. writeShortFlag(flag, out)
  266. }
  267. })
  268. cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
  269. writeFlag(flag, out)
  270. if len(flag.Shorthand) > 0 {
  271. writeShortFlag(flag, out)
  272. }
  273. })
  274. fmt.Fprintf(out, "\n")
  275. }
  276. func writeRequiredFlag(cmd *Command, out *bytes.Buffer) {
  277. fmt.Fprintf(out, " must_have_one_flag=()\n")
  278. flags := cmd.NonInheritedFlags()
  279. flags.VisitAll(func(flag *pflag.Flag) {
  280. for key := range flag.Annotations {
  281. switch key {
  282. case BashCompOneRequiredFlag:
  283. format := " must_have_one_flag+=(\"--%s"
  284. b := (flag.Value.Type() == "bool")
  285. if !b {
  286. format += "="
  287. }
  288. format += "\")\n"
  289. fmt.Fprintf(out, format, flag.Name)
  290. if len(flag.Shorthand) > 0 {
  291. fmt.Fprintf(out, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)
  292. }
  293. }
  294. }
  295. })
  296. }
  297. func writeRequiredNoun(cmd *Command, out *bytes.Buffer) {
  298. fmt.Fprintf(out, " must_have_one_noun=()\n")
  299. sort.Sort(sort.StringSlice(cmd.ValidArgs))
  300. for _, value := range cmd.ValidArgs {
  301. fmt.Fprintf(out, " must_have_one_noun+=(%q)\n", value)
  302. }
  303. }
  304. func gen(cmd *Command, out *bytes.Buffer) {
  305. for _, c := range cmd.Commands() {
  306. if !c.IsAvailableCommand() || c == cmd.helpCommand {
  307. continue
  308. }
  309. gen(c, out)
  310. }
  311. commandName := cmd.CommandPath()
  312. commandName = strings.Replace(commandName, " ", "_", -1)
  313. fmt.Fprintf(out, "_%s()\n{\n", commandName)
  314. fmt.Fprintf(out, " last_command=%q\n", commandName)
  315. writeCommands(cmd, out)
  316. writeFlags(cmd, out)
  317. writeRequiredFlag(cmd, out)
  318. writeRequiredNoun(cmd, out)
  319. fmt.Fprintf(out, "}\n\n")
  320. }
  321. func (cmd *Command) GenBashCompletion(out *bytes.Buffer) {
  322. preamble(out)
  323. if len(cmd.BashCompletionFunction) > 0 {
  324. fmt.Fprintf(out, "%s\n", cmd.BashCompletionFunction)
  325. }
  326. gen(cmd, out)
  327. postscript(out, cmd.Name())
  328. }
  329. func (cmd *Command) GenBashCompletionFile(filename string) error {
  330. out := new(bytes.Buffer)
  331. cmd.GenBashCompletion(out)
  332. outFile, err := os.Create(filename)
  333. if err != nil {
  334. return err
  335. }
  336. defer outFile.Close()
  337. _, err = outFile.Write(out.Bytes())
  338. if err != nil {
  339. return err
  340. }
  341. return nil
  342. }
  343. // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists.
  344. func (cmd *Command) MarkFlagRequired(name string) error {
  345. return MarkFlagRequired(cmd.Flags(), name)
  346. }
  347. // MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag, if it exists.
  348. func (cmd *Command) MarkPersistentFlagRequired(name string) error {
  349. return MarkFlagRequired(cmd.PersistentFlags(), name)
  350. }
  351. // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag in the flag set, if it exists.
  352. func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
  353. return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
  354. }
  355. // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
  356. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
  357. func (cmd *Command) MarkFlagFilename(name string, extensions ...string) error {
  358. return MarkFlagFilename(cmd.Flags(), name, extensions...)
  359. }
  360. // MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists.
  361. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
  362. func (cmd *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
  363. return MarkFlagFilename(cmd.PersistentFlags(), name, extensions...)
  364. }
  365. // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists.
  366. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
  367. func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
  368. return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
  369. }