123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- // Copyright 2013-2023 The Cobra Authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package cobra
- import (
- "fmt"
- "sort"
- "strings"
- flag "github.com/spf13/pflag"
- )
// Annotation keys under which flag-group membership is stored on each flag.
// The value kept under each key is a list of group identifiers, where one
// identifier is the group's flag names joined by single spaces.
const (
	// requiredAsGroupAnnotation is written by MarkFlagsRequiredTogether.
	requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
	// oneRequiredAnnotation is written by MarkFlagsOneRequired.
	oneRequiredAnnotation = "cobra_annotation_one_required"
	// mutuallyExclusiveAnnotation is written by MarkFlagsMutuallyExclusive.
	mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
- // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
- // if the command is invoked with a subset (but not all) of the given flags.
- func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
- c.mergePersistentFlags()
- for _, v := range flagNames {
- f := c.Flags().Lookup(v)
- if f == nil {
- panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
- }
- if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
- // Only errs if the flag isn't found.
- panic(err)
- }
- }
- }
- // MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
- // if the command is invoked without at least one flag from the given set of flags.
- func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
- c.mergePersistentFlags()
- for _, v := range flagNames {
- f := c.Flags().Lookup(v)
- if f == nil {
- panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
- }
- if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
- // Only errs if the flag isn't found.
- panic(err)
- }
- }
- }
- // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
- // if the command is invoked with more than one flag from the given set of flags.
- func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
- c.mergePersistentFlags()
- for _, v := range flagNames {
- f := c.Flags().Lookup(v)
- if f == nil {
- panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
- }
- // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
- if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
- panic(err)
- }
- }
- }
- // ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
- // first error encountered.
- func (c *Command) ValidateFlagGroups() error {
- if c.DisableFlagParsing {
- return nil
- }
- flags := c.Flags()
- // groupStatus format is the list of flags as a unique ID,
- // then a map of each flag name and whether it is set or not.
- groupStatus := map[string]map[string]bool{}
- oneRequiredGroupStatus := map[string]map[string]bool{}
- mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
- flags.VisitAll(func(pflag *flag.Flag) {
- processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
- processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
- processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
- })
- if err := validateRequiredFlagGroups(groupStatus); err != nil {
- return err
- }
- if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
- return err
- }
- if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
- return err
- }
- return nil
- }
// hasAllFlags reports whether every name in flagnames resolves to a flag in
// fs. It returns true when no names are given.
func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
	for i := range flagnames {
		if fs.Lookup(flagnames[i]) == nil {
			return false
		}
	}
	return true
}
- func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
- groupInfo, found := pflag.Annotations[annotation]
- if found {
- for _, group := range groupInfo {
- if groupStatus[group] == nil {
- flagnames := strings.Split(group, " ")
- // Only consider this flag group at all if all the flags are defined.
- if !hasAllFlags(flags, flagnames...) {
- continue
- }
- groupStatus[group] = make(map[string]bool, len(flagnames))
- for _, name := range flagnames {
- groupStatus[group][name] = false
- }
- }
- groupStatus[group][pflag.Name] = pflag.Changed
- }
- }
- }
- func validateRequiredFlagGroups(data map[string]map[string]bool) error {
- keys := sortedKeys(data)
- for _, flagList := range keys {
- flagnameAndStatus := data[flagList]
- unset := []string{}
- for flagname, isSet := range flagnameAndStatus {
- if !isSet {
- unset = append(unset, flagname)
- }
- }
- if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
- continue
- }
- // Sort values, so they can be tested/scripted against consistently.
- sort.Strings(unset)
- return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
- }
- return nil
- }
- func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
- keys := sortedKeys(data)
- for _, flagList := range keys {
- flagnameAndStatus := data[flagList]
- var set []string
- for flagname, isSet := range flagnameAndStatus {
- if isSet {
- set = append(set, flagname)
- }
- }
- if len(set) >= 1 {
- continue
- }
- // Sort values, so they can be tested/scripted against consistently.
- sort.Strings(set)
- return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
- }
- return nil
- }
- func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
- keys := sortedKeys(data)
- for _, flagList := range keys {
- flagnameAndStatus := data[flagList]
- var set []string
- for flagname, isSet := range flagnameAndStatus {
- if isSet {
- set = append(set, flagname)
- }
- }
- if len(set) == 0 || len(set) == 1 {
- continue
- }
- // Sort values, so they can be tested/scripted against consistently.
- sort.Strings(set)
- return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
- }
- return nil
- }
// sortedKeys returns the keys of m in ascending string order. The result is
// an empty, non-nil slice when m has no entries.
func sortedKeys(m map[string]map[string]bool) []string {
	out := make([]string, 0, len(m))
	for key := range m {
		out = append(out, key)
	}
	sort.Strings(out)
	return out
}
- // enforceFlagGroupsForCompletion will do the following:
- // - when a flag in a group is present, other flags in the group will be marked required
- // - when none of the flags in a one-required group are present, all flags in the group will be marked required
- // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
- // This allows the standard completion logic to behave appropriately for flag groups
- func (c *Command) enforceFlagGroupsForCompletion() {
- if c.DisableFlagParsing {
- return
- }
- flags := c.Flags()
- groupStatus := map[string]map[string]bool{}
- oneRequiredGroupStatus := map[string]map[string]bool{}
- mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
- c.Flags().VisitAll(func(pflag *flag.Flag) {
- processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
- processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
- processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
- })
- // If a flag that is part of a group is present, we make all the other flags
- // of that group required so that the shell completion suggests them automatically
- for flagList, flagnameAndStatus := range groupStatus {
- for _, isSet := range flagnameAndStatus {
- if isSet {
- // One of the flags of the group is set, mark the other ones as required
- for _, fName := range strings.Split(flagList, " ") {
- _ = c.MarkFlagRequired(fName)
- }
- }
- }
- }
- // If none of the flags of a one-required group are present, we make all the flags
- // of that group required so that the shell completion suggests them automatically
- for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
- isSet := false
- for _, isSet = range flagnameAndStatus {
- if isSet {
- break
- }
- }
- // None of the flags of the group are set, mark all flags in the group
- // as required
- if !isSet {
- for _, fName := range strings.Split(flagList, " ") {
- _ = c.MarkFlagRequired(fName)
- }
- }
- }
- // If a flag that is mutually exclusive to others is present, we hide the other
- // flags of that group so the shell completion does not suggest them
- for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
- for flagName, isSet := range flagnameAndStatus {
- if isSet {
- // One of the flags of the mutually exclusive group is set, mark the other ones as hidden
- // Don't mark the flag that is already set as hidden because it may be an
- // array or slice flag and therefore must continue being suggested
- for _, fName := range strings.Split(flagList, " ") {
- if fName != flagName {
- flag := c.Flags().Lookup(fName)
- flag.Hidden = true
- }
- }
- }
- }
- }
- }
|