diff options
Diffstat (limited to 'vendor/github.com/spf13/cobra/flag_groups.go')
-rw-r--r-- | vendor/github.com/spf13/cobra/flag_groups.go | 68 |
1 files changed, 67 insertions, 1 deletions
diff --git a/vendor/github.com/spf13/cobra/flag_groups.go b/vendor/github.com/spf13/cobra/flag_groups.go index b35fde155..0671ec5f2 100644 --- a/vendor/github.com/spf13/cobra/flag_groups.go +++ b/vendor/github.com/spf13/cobra/flag_groups.go @@ -24,6 +24,7 @@ import ( const ( requiredAsGroup = "cobra_annotation_required_if_others_set" + oneRequired = "cobra_annotation_one_required" mutuallyExclusive = "cobra_annotation_mutually_exclusive" ) @@ -43,6 +44,22 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { } } +// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors +// if the command is invoked without at least one flag from the given set of flags. +func (c *Command) MarkFlagsOneRequired(flagNames ...string) { + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) + } + if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil { + // Only errs if the flag isn't found. + panic(err) + } + } +} + // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors // if the command is invoked with more than one flag from the given set of flags. func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { @@ -59,7 +76,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { } } -// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the +// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the // first error encountered. func (c *Command) ValidateFlagGroups() error { if c.DisableFlagParsing { @@ -71,15 +88,20 @@ func (c *Command) ValidateFlagGroups() error { // groupStatus format is the list of flags as a unique ID, // then a map of each flag name and whether it is set or not. groupStatus := map[string]map[string]bool{} + oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) + processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { return err } + if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil { + return err + } if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } @@ -142,6 +164,27 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { return nil } +func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { + keys := sortedKeys(data) + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + var set []string + for flagname, isSet := range flagnameAndStatus { + if isSet { + set = append(set, flagname) + } + } + if len(set) >= 1 { + continue + } + + // Sort values, so they can be tested/scripted against consistently. + sort.Strings(set) + return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList) + } + return nil +} + func validateExclusiveFlagGroups(data map[string]map[string]bool) error { keys := sortedKeys(data) for _, flagList := range keys { @@ -176,6 +219,7 @@ func sortedKeys(m map[string]map[string]bool) []string { // enforceFlagGroupsForCompletion will do the following: // - when a flag in a group is present, other flags in the group will be marked required +// - when none of the flags in a one-required group are present, all flags in the group will be marked required // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden // This allows the standard completion logic to behave appropriately for flag groups func (c *Command) enforceFlagGroupsForCompletion() { @@ -185,9 +229,11 @@ func (c *Command) enforceFlagGroupsForCompletion() { flags := c.Flags() groupStatus := map[string]map[string]bool{} + oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) + processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) }) @@ -204,6 +250,26 @@ func (c *Command) enforceFlagGroupsForCompletion() { } } + // If none of the flags of a one-required group are present, we make all the flags + // of that group required so that the shell completion suggests them automatically + for flagList, flagnameAndStatus := range oneRequiredGroupStatus { + set := 0 + + for _, isSet := range flagnameAndStatus { + if isSet { + set++ + } + } + + // None of the flags of the group are set, mark all flags in the group + // as required + if set == 0 { + for _, fName := range strings.Split(flagList, " ") { + _ = c.MarkFlagRequired(fName) + } + } + } + // If a flag that is mutually exclusive to others is present, we hide the other // flags of that group so the shell completion does not suggest them for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { |