diff options
Diffstat (limited to 'internal/config/gen/gen.go')
| -rw-r--r-- | internal/config/gen/gen.go | 485 |
1 files changed, 440 insertions, 45 deletions
diff --git a/internal/config/gen/gen.go b/internal/config/gen/gen.go index dda321e7c..b3532caf8 100644 --- a/internal/config/gen/gen.go +++ b/internal/config/gen/gen.go @@ -25,6 +25,7 @@ import ( "os/exec" "reflect" "strings" + "time" "code.superseriousbusiness.org/gotosocial/internal/config" ) @@ -48,6 +49,11 @@ const license = `// GoToSocial ` +var durationType = reflect.TypeOf(time.Duration(0)) +var stringerType = reflect.TypeOf((*interface{ String() string })(nil)).Elem() +var stringersType = reflect.TypeOf((*interface{ Strings() []string })(nil)).Elem() +var flagSetType = reflect.TypeOf((*interface{ Set(string) error })(nil)).Elem() + func main() { var out string @@ -61,41 +67,392 @@ func main() { panic(err) } - fmt.Fprint(output, "// THIS IS A GENERATED FILE, DO NOT EDIT BY HAND\n") - fmt.Fprint(output, license) - fmt.Fprint(output, "package config\n\n") - fmt.Fprint(output, "import (\n") - fmt.Fprint(output, "\t\"time\"\n\n") - fmt.Fprint(output, "\t\"codeberg.org/gruf/go-bytesize\"\n") - fmt.Fprint(output, "\t\"code.superseriousbusiness.org/gotosocial/internal/language\"\n") - fmt.Fprint(output, ")\n\n") - generateFields(output, nil, reflect.TypeOf(config.Configuration{})) - _ = output.Close() - _ = exec.Command("gofumpt", "-w", out).Run() - - // The plan here is that eventually we might be able - // to generate an example configuration from struct tags + configType := reflect.TypeOf(config.Configuration{}) + + // Parse our config type for usable fields. + fields := loadConfigFields(nil, nil, configType) + + fprintf(output, "// THIS IS A GENERATED FILE, DO NOT EDIT BY HAND\n") + fprintf(output, license) + fprintf(output, "package config\n\n") + fprintf(output, "import (\n") + fprintf(output, "\t\"fmt\"\n") + fprintf(output, "\t\"time\"\n\n") + fprintf(output, "\t\"codeberg.org/gruf/go-bytesize\"\n") + fprintf(output, "\t\"code.superseriousbusiness.org/gotosocial/internal/language\"\n") + fprintf(output, "\t\"github.com/spf13/pflag\"\n") + fprintf(output, "\t\"github.com/spf13/cast\"\n") + fprintf(output, ")\n") + fprintf(output, "\n") + generateFlagRegistering(output, fields) + generateMapMarshaler(output, fields) + generateMapUnmarshaler(output, fields) + generateGetSetters(output, fields) + generateMapFlattener(output, fields) + must(output.Close()) + must(exec.Command("gofumpt", "-w", out).Run()) +} + +type ConfigField struct { + // Any CLI flag prefixes, + // i.e. with nested fields. + Prefixes []string + + // The base CLI flag + // name of the field. + Name string + + // Path to struct field + // in dot-separated form. + Path string + + // Usage string. + Usage string + + // The underlying Go type + // of the config field. + Type reflect.Type + + // i.e. is this found in the configuration file? + // or just used in specific CLI commands? in the + // future we'll remove these from config struct. + Ephemeral bool +} + +// Flag returns the combined "prefixes-name" CLI flag for config field. +func (f ConfigField) Flag() string { + flag := strings.Join(append(f.Prefixes, f.Name), "-") + flag = strings.ToLower(flag) + return flag +} + +// PossibleKeys returns a list of possible map key combinations +// that this config field may be found under. The combined "prefixes-name" +// will always be in the list, but also separates them out to account for +// possible nesting. This allows us to support both nested and un-nested +// configuration files, always prioritizing "prefixes-name" as its the CLI flag. +func (f ConfigField) PossibleKeys() [][]string { + if len(f.Prefixes) == 0 { + return [][]string{{f.Name}} + } + + var keys [][]string + + combined := f.Flag() + keys = append(keys, []string{combined}) + + basePrefix := strings.TrimSuffix(combined, "-"+f.Name) + keys = append(keys, []string{basePrefix, f.Name}) + + for i := len(f.Prefixes) - 1; i >= 0; i-- { + prefix := f.Prefixes[i] + + basePrefix = strings.TrimSuffix(basePrefix, prefix) + basePrefix = strings.TrimSuffix(basePrefix, "-") + if len(basePrefix) == 0 { + break + } + + var key []string + key = append(key, basePrefix) + key = append(key, f.Prefixes[i:]...) + key = append(key, f.Name) + keys = append(keys, key) + } + + return keys } -func generateFields(output io.Writer, prefixes []string, t reflect.Type) { +func loadConfigFields(pathPrefixes, flagPrefixes []string, t reflect.Type) []ConfigField { + var out []ConfigField for i := 0; i < t.NumField(); i++ { + // Struct field at index. field := t.Field(i) + // Get field's tagged name. + name := field.Tag.Get("name") + if name == "" || name == "-" { + continue + } + if ft := field.Type; ft.Kind() == reflect.Struct { - // This is a struct field containing further nested config vars. - generateFields(output, append(prefixes, field.Name), ft) + // This is a nested struct, load nested fields. + pathPrefixes := append(pathPrefixes, field.Name) + flagPrefixes := append(flagPrefixes, name) + out = append(out, loadConfigFields(pathPrefixes, flagPrefixes, ft)...) + continue + } + + // Get prefixed, period-separated, config variable struct "path". + fieldPath := strings.Join(append(pathPrefixes, field.Name), ".") + + // Append prepared ConfigField. + out = append(out, ConfigField{ + Prefixes: flagPrefixes, + Name: name, + Path: fieldPath, + Usage: field.Tag.Get("usage"), + Ephemeral: field.Tag.Get("ephemeral") == "yes", + Type: field.Type, + }) + } + return out +} + +// func generateFlagConsts(out io.Writer, fields []ConfigField) { +// fprintf(out, "const (\n") +// for _, field := range fields { +// name := strings.ReplaceAll(field.Path, ".", "") +// fprintf(out, "\t%sFlag = \"%s\"\n", name, field.Flag()) +// } +// fprintf(out, ")\n\n") +// } + +func generateFlagRegistering(out io.Writer, fields []ConfigField) { + fprintf(out, "func (cfg *Configuration) RegisterFlags(flags *pflag.FlagSet) {\n") + for _, field := range fields { + if field.Ephemeral { + // Skip registering + // ephemeral flags. + continue + } + + // Check for easy cases of just regular primitive types. + if field.Type.Kind().String() == field.Type.String() { + typeName := field.Type.String() + typeName = strings.ToUpper(typeName[:1]) + typeName[1:] + fprintf(out, "\tflags.%s(\"%s\", cfg.%s, \"%s\")\n", typeName, field.Flag(), field.Path, field.Usage) + continue + } + + // Check for easy cases of just + // regular primitive slice types. + if field.Type.Kind() == reflect.Slice { + elem := field.Type.Elem() + if elem.Kind().String() == elem.String() { + typeName := elem.String() + typeName = strings.ToUpper(typeName[:1]) + typeName[1:] + fprintf(out, "\tflags.%sSlice(\"%s\", cfg.%s, \"%s\")\n", typeName, field.Flag(), field.Path, field.Usage) + continue + } + } + + // Durations should get set directly + // as their types as viper knows how + // to deal with this type directly. + if field.Type == durationType { + fprintf(out, "\tflags.Duration(\"%s\", cfg.%s, \"%s\")\n", field.Flag(), field.Path, field.Usage) + continue + } + + if field.Type.Kind() == reflect.Slice { + // Check if the field supports Stringers{}. + if field.Type.Implements(stringersType) { + fprintf(out, "\tflags.StringSlice(\"%s\", cfg.%s.Strings(), \"%s\")\n", field.Flag(), field.Path, field.Usage) + continue + } + + // Or the pointer type of the field value supports Stringers{}. + if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringersType) { + fprintf(out, "\tflags.StringSlice(\"%s\", cfg.%s.Strings(), \"%s\")\n", field.Flag(), field.Path, field.Usage) + continue + } + + fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringersType) + } else { + // Check if the field supports Stringer{}. + if field.Type.Implements(stringerType) { + fprintf(out, "\tflags.String(\"%s\", cfg.%s.String(), \"%s\")\n", field.Flag(), field.Path, field.Usage) + continue + } + + // Or the pointer type of the field value supports Stringer{}. + if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringerType) { + fprintf(out, "\tflags.String(\"%s\", cfg.%s.String(), \"%s\")\n", field.Flag(), field.Path, field.Usage) + continue + } + + fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringerType) + } + } + fprintf(out, "}\n\n") +} + +func generateMapMarshaler(out io.Writer, fields []ConfigField) { + fprintf(out, "func (cfg *Configuration) MarshalMap() map[string]any {\n") + fprintf(out, "\tcfgmap := make(map[string]any, %d)\n", len(fields)) + for _, field := range fields { + // Check for easy cases of just regular primitive types. + if field.Type.Kind().String() == field.Type.String() { + fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s\n", field.Flag(), field.Path) + continue + } + + // Check for easy cases of just + // regular primitive slice types. + if field.Type.Kind() == reflect.Slice { + elem := field.Type.Elem() + if elem.Kind().String() == elem.String() { + fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s\n", field.Flag(), field.Path) + continue + } + } + + // Durations should get set directly + // as their types as viper knows how + // to deal with this type directly. + if field.Type == durationType { + fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s\n", field.Flag(), field.Path) + continue + } + + if field.Type.Kind() == reflect.Slice { + // Either the field must support Stringers{}. + if field.Type.Implements(stringersType) { + fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.Strings()\n", field.Flag(), field.Path) + continue + } + + // Or the pointer type of the field value must support Stringers{}. + if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringersType) { + fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.Strings()\n", field.Flag(), field.Path) + continue + } + + fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringersType) + } else { + // Either the field must support Stringer{}. + if field.Type.Implements(stringerType) { + fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.String()\n", field.Flag(), field.Path) + continue + } + + // Or the pointer type of the field value must support Stringer{}. + if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringerType) { + fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.String()\n", field.Flag(), field.Path) + continue + } + + fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringerType) + } + } + fprintf(out, "\treturn cfgmap") + fprintf(out, "}\n\n") +} + +func generateMapUnmarshaler(out io.Writer, fields []ConfigField) { + fprintf(out, "func (cfg *Configuration) UnmarshalMap(cfgmap map[string]any) error {\n") + fprintf(out, "// VERY IMPORTANT FIRST STEP!\n") + fprintf(out, "// flatten to normalize map to\n") + fprintf(out, "// entirely un-nested key values\n") + fprintf(out, "flattenConfigMap(cfgmap)\n") + fprintf(out, "\n") + for _, field := range fields { + // Check for easy cases of just regular primitive types. + if field.Type.Kind().String() == field.Type.String() { + generateUnmarshalerPrimitive(out, field) + continue + } + + // Check for easy cases of just + // regular primitive slice types. + if field.Type.Kind() == reflect.Slice { + elem := field.Type.Elem() + if elem.Kind().String() == elem.String() { + generateUnmarshalerPrimitive(out, field) + continue + } + } + + // Durations should get set directly + // as their types as viper knows how + // to deal with this type directly. + if field.Type == durationType { + generateUnmarshalerPrimitive(out, field) + continue + } + + // Either the field must support flag.Value{}. + if field.Type.Implements(flagSetType) { + generateUnmarshalerFlagType(out, field) continue } - // Get prefixed config variable name - name := strings.Join(prefixes, "") + field.Name + // Or the pointer type of the field value must support flag.Value{}. + if ptr := reflect.PointerTo(field.Type); ptr.Implements(flagSetType) { + generateUnmarshalerFlagType(out, field) + continue + } - // Get period-separated (if nested) config variable "path" - fieldPath := strings.Join(append(prefixes, field.Name), ".") + fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, flagSetType) + } + fprintf(out, "\treturn nil\n") + fprintf(out, "}\n\n") +} - // Get dash-separated config variable CLI flag "path" - flagPath := strings.Join(append(prefixes, field.Tag.Get("name")), "-") - flagPath = strings.ToLower(flagPath) +func generateUnmarshalerPrimitive(out io.Writer, field ConfigField) { + fprintf(out, "\t\tif ival, ok := cfgmap[\"%s\"]; ok {\n", field.Flag()) + if field.Type.Kind() == reflect.Slice { + elem := field.Type.Elem() + typeName := elem.String() + if i := strings.IndexRune(typeName, '.'); i >= 0 { + typeName = typeName[i+1:] + } + typeName = strings.ToUpper(typeName[:1]) + typeName[1:] + fprintf(out, "\t\t\tvar err error\n") + // note we specifically handle slice types ourselves to split by comma + fprintf(out, "\t\t\tcfg.%s, err = to%sSlice(ival)\n", field.Path, typeName) + fprintf(out, "\t\t\tif err != nil {\n") + fprintf(out, "\t\t\t\treturn fmt.Errorf(\"error casting %%#v -> []%s for '%s': %%w\", ival, err)\n", elem.String(), field.Flag()) + fprintf(out, "\t\t\t}\n") + } else { + typeName := field.Type.String() + if i := strings.IndexRune(typeName, '.'); i >= 0 { + typeName = typeName[i+1:] + } + typeName = strings.ToUpper(typeName[:1]) + typeName[1:] + fprintf(out, "\t\t\tvar err error\n") + fprintf(out, "\t\t\tcfg.%s, err = cast.To%sE(ival)\n", field.Path, typeName) + fprintf(out, "\t\t\tif err != nil {\n") + fprintf(out, "\t\t\t\treturn fmt.Errorf(\"error casting %%#v -> %s for '%s': %%w\", ival, err)\n", field.Type.String(), field.Flag()) + fprintf(out, "\t\t\t}\n") + } + fprintf(out, "\t}\n") + fprintf(out, "\n") +} + +func generateUnmarshalerFlagType(out io.Writer, field ConfigField) { + fprintf(out, "\t\tif ival, ok := cfgmap[\"%s\"]; ok {\n", field.Flag()) + if field.Type.Kind() == reflect.Slice { + // same as above re: slice types and splitting on comma + fprintf(out, "\t\tt, err := toStringSlice(ival)\n") + fprintf(out, "\t\tif err != nil {\n") + fprintf(out, "\t\t\treturn fmt.Errorf(\"error casting %%#v -> []string for '%s': %%w\", ival, err)\n", field.Flag()) + fprintf(out, "\t\t}\n") + fprintf(out, "\t\tcfg.%s = %s{}\n", field.Path, strings.TrimPrefix(field.Type.String(), "config.")) + fprintf(out, "\t\tfor _, in := range t {\n") + fprintf(out, "\t\t\tif err := cfg.%s.Set(in); err != nil {\n", field.Path) + fprintf(out, "\t\t\t\treturn fmt.Errorf(\"error parsing %%#v for '%s': %%w\", ival, err)\n", field.Flag()) + fprintf(out, "\t\t\t}\n") + fprintf(out, "\t\t}\n") + } else { + fprintf(out, "\t\tt, err := cast.ToStringE(ival)\n") + fprintf(out, "\t\tif err != nil {\n") + fprintf(out, "\t\t\treturn fmt.Errorf(\"error casting %%#v -> string for '%s': %%w\", ival, err)\n", field.Flag()) + fprintf(out, "\t\t}\n") + fprintf(out, "\t\tcfg.%s = %#v\n", field.Path, reflect.New(field.Type).Elem().Interface()) + fprintf(out, "\t\tif err := cfg.%s.Set(t); err != nil {\n", field.Path) + fprintf(out, "\t\t\treturn fmt.Errorf(\"error parsing %%#v for '%s': %%w\", ival, err)\n", field.Flag()) + fprintf(out, "\t\t}\n") + } + fprintf(out, "\t}\n") + fprintf(out, "\n") +} + +func generateGetSetters(out io.Writer, fields []ConfigField) { + for _, field := range fields { + // Get name from struct path, without periods. + name := strings.ReplaceAll(field.Path, ".", "") // Get type without "config." prefix. fieldType := strings.ReplaceAll( @@ -103,29 +460,67 @@ func generateFields(output io.Writer, prefixes []string, t reflect.Type) { "config.", "", ) + fprintf(out, "// %sFlag returns the flag name for the '%s' field\n", name, field.Path) + fprintf(out, "func %sFlag() string { return \"%s\" }\n\n", name, field.Flag()) + // ConfigState structure helper methods - fmt.Fprintf(output, "// Get%s safely fetches the Configuration value for state's '%s' field\n", name, fieldPath) - fmt.Fprintf(output, "func (st *ConfigState) Get%s() (v %s) {\n", name, fieldType) - fmt.Fprintf(output, "\tst.mutex.RLock()\n") - fmt.Fprintf(output, "\tv = st.config.%s\n", fieldPath) - fmt.Fprintf(output, "\tst.mutex.RUnlock()\n") - fmt.Fprintf(output, "\treturn\n") - fmt.Fprintf(output, "}\n\n") - fmt.Fprintf(output, "// Set%s safely sets the Configuration value for state's '%s' field\n", name, fieldPath) - fmt.Fprintf(output, "func (st *ConfigState) Set%s(v %s) {\n", name, fieldType) - fmt.Fprintf(output, "\tst.mutex.Lock()\n") - fmt.Fprintf(output, "\tdefer st.mutex.Unlock()\n") - fmt.Fprintf(output, "\tst.config.%s = v\n", fieldPath) - fmt.Fprintf(output, "\tst.reloadToViper()\n") - fmt.Fprintf(output, "}\n\n") + fprintf(out, "// Get%s safely fetches the Configuration value for state's '%s' field\n", name, field.Path) + fprintf(out, "func (st *ConfigState) Get%s() (v %s) {\n", name, fieldType) + fprintf(out, "\tst.mutex.RLock()\n") + fprintf(out, "\tv = st.config.%s\n", field.Path) + fprintf(out, "\tst.mutex.RUnlock()\n") + fprintf(out, "\treturn\n") + fprintf(out, "}\n\n") + fprintf(out, "// Set%s safely sets the Configuration value for state's '%s' field\n", name, field.Path) + fprintf(out, "func (st *ConfigState) Set%s(v %s) {\n", name, fieldType) + fprintf(out, "\tst.mutex.Lock()\n") + fprintf(out, "\tdefer st.mutex.Unlock()\n") + fprintf(out, "\tst.config.%s = v\n", field.Path) + fprintf(out, "\tst.reloadToViper()\n") + fprintf(out, "}\n\n") // Global ConfigState helper methods - // TODO: remove when we pass around a ConfigState{} - fmt.Fprintf(output, "// %sFlag returns the flag name for the '%s' field\n", name, fieldPath) - fmt.Fprintf(output, "func %sFlag() string { return \"%s\" }\n\n", name, flagPath) - fmt.Fprintf(output, "// Get%s safely fetches the value for global configuration '%s' field\n", name, fieldPath) - fmt.Fprintf(output, "func Get%[1]s() %[2]s { return global.Get%[1]s() }\n\n", name, fieldType) - fmt.Fprintf(output, "// Set%s safely sets the value for global configuration '%s' field\n", name, fieldPath) - fmt.Fprintf(output, "func Set%[1]s(v %[2]s) { global.Set%[1]s(v) }\n\n", name, fieldType) + fprintf(out, "// Get%s safely fetches the value for global configuration '%s' field\n", name, field.Path) + fprintf(out, "func Get%[1]s() %[2]s { return global.Get%[1]s() }\n\n", name, fieldType) + fprintf(out, "// Set%s safely sets the value for global configuration '%s' field\n", name, field.Path) + fprintf(out, "func Set%[1]s(v %[2]s) { global.Set%[1]s(v) }\n\n", name, fieldType) + } +} + +func generateMapFlattener(out io.Writer, fields []ConfigField) { + fprintf(out, "func flattenConfigMap(cfgmap map[string]any) {\n") + fprintf(out, "\tnestedKeys := make(map[string]struct{})\n") + for _, field := range fields { + keys := field.PossibleKeys() + if len(keys) <= 1 { + continue + } + fprintf(out, "\tfor _, key := range [][]string{\n") + for _, key := range keys[1:] { + fprintf(out, "\t\t{\"%s\"},\n", strings.Join(key, "\", \"")) + } + fprintf(out, "\t} {\n") + fprintf(out, "\t\tival, ok := mapGet(cfgmap, key...)\n") + fprintf(out, "\t\tif ok {\n") + fprintf(out, "\t\t\tcfgmap[\"%s\"] = ival\n", field.Flag()) + fprintf(out, "\t\t\tnestedKeys[key[0]] = struct{}{}\n") + fprintf(out, "\t\t\tbreak\n") + fprintf(out, "\t\t}\n") + fprintf(out, "\t}\n\n") + } + fprintf(out, "\tfor key := range nestedKeys {\n") + fprintf(out, "\t\tdelete(cfgmap, key)\n") + fprintf(out, "\t}\n") + fprintf(out, "}\n\n") +} + +func fprintf(out io.Writer, format string, args ...any) { + _, err := fmt.Fprintf(out, format, args...) + must(err) +} + +func must(err error) { + if err != nil { + panic(err) } } |
