flag_groups.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. // Copyright 2013-2023 The Cobra Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package cobra
  15. import (
  16. "fmt"
  17. "sort"
  18. "strings"
  19. flag "github.com/spf13/pflag"
  20. )
  21. const (
  22. requiredAsGroup = "cobra_annotation_required_if_others_set"
  23. oneRequired = "cobra_annotation_one_required"
  24. mutuallyExclusive = "cobra_annotation_mutually_exclusive"
  25. )
  26. // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
  27. // if the command is invoked with a subset (but not all) of the given flags.
  28. func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
  29. c.mergePersistentFlags()
  30. for _, v := range flagNames {
  31. f := c.Flags().Lookup(v)
  32. if f == nil {
  33. panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
  34. }
  35. if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
  36. // Only errs if the flag isn't found.
  37. panic(err)
  38. }
  39. }
  40. }
  41. // MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
  42. // if the command is invoked without at least one flag from the given set of flags.
  43. func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
  44. c.mergePersistentFlags()
  45. for _, v := range flagNames {
  46. f := c.Flags().Lookup(v)
  47. if f == nil {
  48. panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
  49. }
  50. if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil {
  51. // Only errs if the flag isn't found.
  52. panic(err)
  53. }
  54. }
  55. }
  56. // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
  57. // if the command is invoked with more than one flag from the given set of flags.
  58. func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
  59. c.mergePersistentFlags()
  60. for _, v := range flagNames {
  61. f := c.Flags().Lookup(v)
  62. if f == nil {
  63. panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
  64. }
  65. // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
  66. if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
  67. panic(err)
  68. }
  69. }
  70. }
  71. // ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
  72. // first error encountered.
  73. func (c *Command) ValidateFlagGroups() error {
  74. if c.DisableFlagParsing {
  75. return nil
  76. }
  77. flags := c.Flags()
  78. // groupStatus format is the list of flags as a unique ID,
  79. // then a map of each flag name and whether it is set or not.
  80. groupStatus := map[string]map[string]bool{}
  81. oneRequiredGroupStatus := map[string]map[string]bool{}
  82. mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
  83. flags.VisitAll(func(pflag *flag.Flag) {
  84. processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
  85. processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
  86. processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
  87. })
  88. if err := validateRequiredFlagGroups(groupStatus); err != nil {
  89. return err
  90. }
  91. if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
  92. return err
  93. }
  94. if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
  95. return err
  96. }
  97. return nil
  98. }
  99. func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
  100. for _, fname := range flagnames {
  101. f := fs.Lookup(fname)
  102. if f == nil {
  103. return false
  104. }
  105. }
  106. return true
  107. }
  108. func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
  109. groupInfo, found := pflag.Annotations[annotation]
  110. if found {
  111. for _, group := range groupInfo {
  112. if groupStatus[group] == nil {
  113. flagnames := strings.Split(group, " ")
  114. // Only consider this flag group at all if all the flags are defined.
  115. if !hasAllFlags(flags, flagnames...) {
  116. continue
  117. }
  118. groupStatus[group] = map[string]bool{}
  119. for _, name := range flagnames {
  120. groupStatus[group][name] = false
  121. }
  122. }
  123. groupStatus[group][pflag.Name] = pflag.Changed
  124. }
  125. }
  126. }
  127. func validateRequiredFlagGroups(data map[string]map[string]bool) error {
  128. keys := sortedKeys(data)
  129. for _, flagList := range keys {
  130. flagnameAndStatus := data[flagList]
  131. unset := []string{}
  132. for flagname, isSet := range flagnameAndStatus {
  133. if !isSet {
  134. unset = append(unset, flagname)
  135. }
  136. }
  137. if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
  138. continue
  139. }
  140. // Sort values, so they can be tested/scripted against consistently.
  141. sort.Strings(unset)
  142. return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
  143. }
  144. return nil
  145. }
  146. func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
  147. keys := sortedKeys(data)
  148. for _, flagList := range keys {
  149. flagnameAndStatus := data[flagList]
  150. var set []string
  151. for flagname, isSet := range flagnameAndStatus {
  152. if isSet {
  153. set = append(set, flagname)
  154. }
  155. }
  156. if len(set) >= 1 {
  157. continue
  158. }
  159. // Sort values, so they can be tested/scripted against consistently.
  160. sort.Strings(set)
  161. return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
  162. }
  163. return nil
  164. }
  165. func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
  166. keys := sortedKeys(data)
  167. for _, flagList := range keys {
  168. flagnameAndStatus := data[flagList]
  169. var set []string
  170. for flagname, isSet := range flagnameAndStatus {
  171. if isSet {
  172. set = append(set, flagname)
  173. }
  174. }
  175. if len(set) == 0 || len(set) == 1 {
  176. continue
  177. }
  178. // Sort values, so they can be tested/scripted against consistently.
  179. sort.Strings(set)
  180. return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
  181. }
  182. return nil
  183. }
  184. func sortedKeys(m map[string]map[string]bool) []string {
  185. keys := make([]string, len(m))
  186. i := 0
  187. for k := range m {
  188. keys[i] = k
  189. i++
  190. }
  191. sort.Strings(keys)
  192. return keys
  193. }
  194. // enforceFlagGroupsForCompletion will do the following:
  195. // - when a flag in a group is present, other flags in the group will be marked required
  196. // - when none of the flags in a one-required group are present, all flags in the group will be marked required
  197. // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
  198. // This allows the standard completion logic to behave appropriately for flag groups
  199. func (c *Command) enforceFlagGroupsForCompletion() {
  200. if c.DisableFlagParsing {
  201. return
  202. }
  203. flags := c.Flags()
  204. groupStatus := map[string]map[string]bool{}
  205. oneRequiredGroupStatus := map[string]map[string]bool{}
  206. mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
  207. c.Flags().VisitAll(func(pflag *flag.Flag) {
  208. processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
  209. processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
  210. processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
  211. })
  212. // If a flag that is part of a group is present, we make all the other flags
  213. // of that group required so that the shell completion suggests them automatically
  214. for flagList, flagnameAndStatus := range groupStatus {
  215. for _, isSet := range flagnameAndStatus {
  216. if isSet {
  217. // One of the flags of the group is set, mark the other ones as required
  218. for _, fName := range strings.Split(flagList, " ") {
  219. _ = c.MarkFlagRequired(fName)
  220. }
  221. }
  222. }
  223. }
  224. // If none of the flags of a one-required group are present, we make all the flags
  225. // of that group required so that the shell completion suggests them automatically
  226. for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
  227. set := 0
  228. for _, isSet := range flagnameAndStatus {
  229. if isSet {
  230. set++
  231. }
  232. }
  233. // None of the flags of the group are set, mark all flags in the group
  234. // as required
  235. if set == 0 {
  236. for _, fName := range strings.Split(flagList, " ") {
  237. _ = c.MarkFlagRequired(fName)
  238. }
  239. }
  240. }
  241. // If a flag that is mutually exclusive to others is present, we hide the other
  242. // flags of that group so the shell completion does not suggest them
  243. for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
  244. for flagName, isSet := range flagnameAndStatus {
  245. if isSet {
  246. // One of the flags of the mutually exclusive group is set, mark the other ones as hidden
  247. // Don't mark the flag that is already set as hidden because it may be an
  248. // array or slice flag and therefore must continue being suggested
  249. for _, fName := range strings.Split(flagList, " ") {
  250. if fName != flagName {
  251. flag := c.Flags().Lookup(fName)
  252. flag.Hidden = true
  253. }
  254. }
  255. }
  256. }
  257. }
  258. }