diff options
Diffstat (limited to 'internal/config/state.go')
| -rw-r--r-- | internal/config/state.go | 109 |
1 files changed, 60 insertions, 49 deletions
diff --git a/internal/config/state.go b/internal/config/state.go index 90e8a98f2..eeff866b3 100644 --- a/internal/config/state.go +++ b/internal/config/state.go @@ -18,10 +18,11 @@ package config import ( + "os" + "path" "strings" "sync" - "github.com/go-viper/mapstructure/v2" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -46,34 +47,25 @@ func NewState() *ConfigState { // and will reload the current Configuration back into viper settings. func (st *ConfigState) Config(fn func(*Configuration)) { st.mutex.Lock() - defer func() { - st.reloadToViper() - st.mutex.Unlock() - }() + defer st.mutex.Unlock() fn(&st.config) + st.reloadToViper() } // Viper provides safe access to the ConfigState's contained viper instance, // and will reload the current viper setting state back into Configuration. func (st *ConfigState) Viper(fn func(*viper.Viper)) { st.mutex.Lock() - defer func() { - st.reloadFromViper() - st.mutex.Unlock() - }() + defer st.mutex.Unlock() fn(st.viper) + st.reloadFromViper() } -// LoadEarlyFlags will bind specific flags from given Cobra command to ConfigState's viper -// instance, and load the current configuration values. This is useful for flags like -// .ConfigPath which have to parsed first in order to perform early configuration load. -func (st *ConfigState) LoadEarlyFlags(cmd *cobra.Command) (err error) { - name := ConfigPathFlag() - flag := cmd.Flags().Lookup(name) - st.Viper(func(v *viper.Viper) { - err = v.BindPFlag(name, flag) - }) - return +// RegisterGlobalFlags ... +func (st *ConfigState) RegisterGlobalFlags(root *cobra.Command) { + st.mutex.RLock() + st.config.RegisterFlags(root.PersistentFlags()) + st.mutex.RUnlock() } // BindFlags will bind given Cobra command's pflags to this ConfigState's viper instance. @@ -84,15 +76,21 @@ func (st *ConfigState) BindFlags(cmd *cobra.Command) (err error) { return } -// Reload will reload the Configuration values from ConfigState's viper instance, and from file if set. -func (st *ConfigState) Reload() (err error) { +// LoadConfigFile loads the currently set configuration file into this ConfigState's viper instance. +func (st *ConfigState) LoadConfigFile() (err error) { st.Viper(func(v *viper.Viper) { - if st.config.ConfigPath != "" { - // Ensure configuration path is set - v.SetConfigFile(st.config.ConfigPath) + if path := st.config.ConfigPath; path != "" { + var cfgmap map[string]any - // Read in configuration from file - if err = v.ReadInConfig(); err != nil { + // Read config map into memory. + cfgmap, err := readConfigMap(path) + if err != nil { + return + } + + // Merge the parsed config into viper. + err = st.viper.MergeConfigMap(cfgmap) + if err != nil { return } } @@ -108,18 +106,17 @@ func (st *ConfigState) Reset() { defer st.mutex.Unlock() // Create new viper. - viper := viper.New() + st.viper = viper.New() // Flag 'some-flag-name' becomes env var 'GTS_SOME_FLAG_NAME' - viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) - viper.SetEnvPrefix("gts") + st.viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + st.viper.SetEnvPrefix("gts") // Load appropriate // named vals from env. - viper.AutomaticEnv() + st.viper.AutomaticEnv() - // Reset variables. - st.viper = viper + // Set default config. st.config = Defaults // Load into viper. @@ -128,31 +125,45 @@ func (st *ConfigState) Reset() { // reloadToViper will reload Configuration{} values into viper. func (st *ConfigState) reloadToViper() { - raw, err := st.config.MarshalMap() - if err != nil { - panic(err) - } - if err := st.viper.MergeConfigMap(raw); err != nil { + if err := st.viper.MergeConfigMap(st.config.MarshalMap()); err != nil { panic(err) } } // reloadFromViper will reload Configuration{} values from viper. func (st *ConfigState) reloadFromViper() { - if err := st.viper.Unmarshal(&st.config, func(c *mapstructure.DecoderConfig) { - c.TagName = "name" + if err := st.config.UnmarshalMap(st.viper.AllSettings()); err != nil { + panic(err) + } +} - // empty config before marshaling - c.ZeroFields = true +// readConfigMap reads given configuration file into memory, +// using viper's codec registry to handle decoding into a map, +// flattening the result for standardization, returning this. +// this ensures the stored config map in viper always has the +// same level of nesting, given we support varying levels. +func readConfigMap(file string) (map[string]any, error) { + ext := path.Ext(file) + ext = strings.TrimPrefix(ext, ".") + + registry := viper.NewCodecRegistry() + dec, err := registry.Decoder(ext) + if err != nil { + return nil, err + } - oldhook := c.DecodeHook + data, err := os.ReadFile(file) + if err != nil { + return nil, err + } - // Use the TextUnmarshaler interface when decoding. - c.DecodeHook = mapstructure.ComposeDecodeHookFunc( - mapstructure.TextUnmarshallerHookFunc(), - oldhook, - ) - }); err != nil { - panic(err) + cfgmap := make(map[string]any) + + if err := dec.Decode(data, cfgmap); err != nil { + return nil, err } + + flattenConfigMap(cfgmap) + + return cfgmap, nil } |
