From 25559c49137fcbbcafb1881cdd8e2cb8ebae693b Mon Sep 17 00:00:00 2001 From: Virgil Date: Tue, 31 Mar 2026 18:18:01 +0000 Subject: [PATCH] fix(config): validate file types before read Co-Authored-By: Virgil --- config.go | 61 +++++++++++++++++++++++++++----------------------- config_test.go | 11 +++++++++ 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/config.go b/config.go index 9952b40..35055dc 100644 --- a/config.go +++ b/config.go @@ -30,8 +30,8 @@ import ( // It uses viper as the underlying configuration engine. type Config struct { mu sync.RWMutex - v *viper.Viper // Full configuration (file + env + defaults) - f *viper.Viper // File-backed configuration only (for persistence) + full *viper.Viper // Full configuration (file + env + defaults) + file *viper.Viper // File-backed configuration only (for persistence) medium coreio.Medium path string } @@ -56,7 +56,7 @@ func WithPath(path string) Option { // WithEnvPrefix sets the prefix for environment variables. func WithEnvPrefix(prefix string) Option { return func(c *Config) { - c.v.SetEnvPrefix(prefix) + c.full.SetEnvPrefix(prefix) } } @@ -65,13 +65,13 @@ func WithEnvPrefix(prefix string) Option { // If no path is provided, it defaults to ~/.core/config.yaml. func New(opts ...Option) (*Config, error) { c := &Config{ - v: viper.New(), - f: viper.New(), + full: viper.New(), + file: viper.New(), } // Configure viper defaults - c.v.SetEnvPrefix("CORE_CONFIG") - c.v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + c.full.SetEnvPrefix("CORE_CONFIG") + c.full.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) for _, opt := range opts { opt(c) @@ -89,7 +89,7 @@ func New(opts ...Option) (*Config, error) { c.path = filepath.Join(home, ".core", "config.yaml") } - c.v.AutomaticEnv() + c.full.AutomaticEnv() // Load existing config file if it exists if c.medium.Exists(c.path) { @@ -127,26 +127,31 @@ func (c *Config) LoadFile(m coreio.Medium, path string) error { c.mu.Lock() defer c.mu.Unlock() - content, err := m.Read(path) - if err != nil { - return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err) - } - configType, err := configTypeForPath(path) if err != nil { return coreerr.E("config.LoadFile", "failed to determine config file type: "+path, err) } - // Load into file-backed viper - c.f.SetConfigType(configType) - if err := c.f.MergeConfig(strings.NewReader(content)); err != nil { - return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (f): %s", path), err) + content, err := m.Read(path) + if err != nil { + return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err) } - // Load into full viper - c.v.SetConfigType(configType) - if err := c.v.MergeConfig(strings.NewReader(content)); err != nil { - return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (v): %s", path), err) + parsed := viper.New() + parsed.SetConfigType(configType) + if err := parsed.MergeConfig(strings.NewReader(content)); err != nil { + return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file: %s", path), err) + } + + settings := parsed.AllSettings() + + // Keep the persisted and runtime views aligned with the same parsed data. + if err := c.file.MergeConfigMap(settings); err != nil { + return coreerr.E("config.LoadFile", "failed to merge config into file settings", err) + } + + if err := c.full.MergeConfigMap(settings); err != nil { + return coreerr.E("config.LoadFile", "failed to merge config into full settings", err) } return nil @@ -160,17 +165,17 @@ func (c *Config) Get(key string, out any) error { defer c.mu.RUnlock() if key == "" { - if err := c.v.Unmarshal(out); err != nil { + if err := c.full.Unmarshal(out); err != nil { return coreerr.E("config.Get", "failed to unmarshal full config", err) } return nil } - if !c.v.IsSet(key) { + if !c.full.IsSet(key) { return coreerr.E("config.Get", fmt.Sprintf("key not found: %s", key), nil) } - if err := c.v.UnmarshalKey(key, out); err != nil { + if err := c.full.UnmarshalKey(key, out); err != nil { return coreerr.E("config.Get", fmt.Sprintf("failed to unmarshal key: %s", key), err) } return nil @@ -182,8 +187,8 @@ func (c *Config) Set(key string, v any) error { c.mu.Lock() defer c.mu.Unlock() - c.f.Set(key, v) - c.v.Set(key, v) + c.file.Set(key, v) + c.full.Set(key, v) return nil } @@ -194,7 +199,7 @@ func (c *Config) Commit() error { c.mu.Lock() defer c.mu.Unlock() - if err := Save(c.medium, c.path, c.f.AllSettings()); err != nil { + if err := Save(c.medium, c.path, c.file.AllSettings()); err != nil { return coreerr.E("config.Commit", "failed to save config", err) } return nil @@ -206,7 +211,7 @@ func (c *Config) All() iter.Seq2[string, any] { c.mu.RLock() defer c.mu.RUnlock() - settings := c.v.AllSettings() + settings := c.full.AllSettings() keys := make([]string, 0, len(settings)) for key := range settings { keys = append(keys, key) diff --git a/config_test.go b/config_test.go index f53ff3e..317c80b 100644 --- a/config_test.go +++ b/config_test.go @@ -286,6 +286,17 @@ func TestConfig_LoadFile_Unsupported_Bad(t *testing.T) { assert.Contains(t, err.Error(), "unsupported config file type") } +func TestConfig_LoadFile_Unsupported_NoRead_Bad(t *testing.T) { + m := coreio.NewMockMedium() + + cfg, err := New(WithMedium(m), WithPath("/tmp/test/config.txt")) + assert.NoError(t, err) + + err = cfg.LoadFile(m, "/tmp/test/config.txt") + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported config file type") +} + func TestSave_Good(t *testing.T) { m := coreio.NewMockMedium()