fix(config): validate file types before read

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-31 18:18:01 +00:00
parent d6f7c05838
commit 25559c4913
2 changed files with 44 additions and 28 deletions

View file

@ -30,8 +30,8 @@ import (
// It uses viper as the underlying configuration engine.
type Config struct {
mu sync.RWMutex
v *viper.Viper // Full configuration (file + env + defaults)
f *viper.Viper // File-backed configuration only (for persistence)
full *viper.Viper // Full configuration (file + env + defaults)
file *viper.Viper // File-backed configuration only (for persistence)
medium coreio.Medium
path string
}
@ -56,7 +56,7 @@ func WithPath(path string) Option {
// WithEnvPrefix sets the prefix for environment variables.
func WithEnvPrefix(prefix string) Option {
return func(c *Config) {
c.v.SetEnvPrefix(prefix)
c.full.SetEnvPrefix(prefix)
}
}
@ -65,13 +65,13 @@ func WithEnvPrefix(prefix string) Option {
// If no path is provided, it defaults to ~/.core/config.yaml.
func New(opts ...Option) (*Config, error) {
c := &Config{
v: viper.New(),
f: viper.New(),
full: viper.New(),
file: viper.New(),
}
// Configure viper defaults
c.v.SetEnvPrefix("CORE_CONFIG")
c.v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
c.full.SetEnvPrefix("CORE_CONFIG")
c.full.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
for _, opt := range opts {
opt(c)
@ -89,7 +89,7 @@ func New(opts ...Option) (*Config, error) {
c.path = filepath.Join(home, ".core", "config.yaml")
}
c.v.AutomaticEnv()
c.full.AutomaticEnv()
// Load existing config file if it exists
if c.medium.Exists(c.path) {
@ -127,26 +127,31 @@ func (c *Config) LoadFile(m coreio.Medium, path string) error {
c.mu.Lock()
defer c.mu.Unlock()
content, err := m.Read(path)
if err != nil {
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err)
}
configType, err := configTypeForPath(path)
if err != nil {
return coreerr.E("config.LoadFile", "failed to determine config file type: "+path, err)
}
// Load into file-backed viper
c.f.SetConfigType(configType)
if err := c.f.MergeConfig(strings.NewReader(content)); err != nil {
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (f): %s", path), err)
content, err := m.Read(path)
if err != nil {
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err)
}
// Load into full viper
c.v.SetConfigType(configType)
if err := c.v.MergeConfig(strings.NewReader(content)); err != nil {
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (v): %s", path), err)
parsed := viper.New()
parsed.SetConfigType(configType)
if err := parsed.MergeConfig(strings.NewReader(content)); err != nil {
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file: %s", path), err)
}
settings := parsed.AllSettings()
// Keep the persisted and runtime views aligned with the same parsed data.
if err := c.file.MergeConfigMap(settings); err != nil {
return coreerr.E("config.LoadFile", "failed to merge config into file settings", err)
}
if err := c.full.MergeConfigMap(settings); err != nil {
return coreerr.E("config.LoadFile", "failed to merge config into full settings", err)
}
return nil
@ -160,17 +165,17 @@ func (c *Config) Get(key string, out any) error {
defer c.mu.RUnlock()
if key == "" {
if err := c.v.Unmarshal(out); err != nil {
if err := c.full.Unmarshal(out); err != nil {
return coreerr.E("config.Get", "failed to unmarshal full config", err)
}
return nil
}
if !c.v.IsSet(key) {
if !c.full.IsSet(key) {
return coreerr.E("config.Get", fmt.Sprintf("key not found: %s", key), nil)
}
if err := c.v.UnmarshalKey(key, out); err != nil {
if err := c.full.UnmarshalKey(key, out); err != nil {
return coreerr.E("config.Get", fmt.Sprintf("failed to unmarshal key: %s", key), err)
}
return nil
@ -182,8 +187,8 @@ func (c *Config) Set(key string, v any) error {
c.mu.Lock()
defer c.mu.Unlock()
c.f.Set(key, v)
c.v.Set(key, v)
c.file.Set(key, v)
c.full.Set(key, v)
return nil
}
@ -194,7 +199,7 @@ func (c *Config) Commit() error {
c.mu.Lock()
defer c.mu.Unlock()
if err := Save(c.medium, c.path, c.f.AllSettings()); err != nil {
if err := Save(c.medium, c.path, c.file.AllSettings()); err != nil {
return coreerr.E("config.Commit", "failed to save config", err)
}
return nil
@ -206,7 +211,7 @@ func (c *Config) All() iter.Seq2[string, any] {
c.mu.RLock()
defer c.mu.RUnlock()
settings := c.v.AllSettings()
settings := c.full.AllSettings()
keys := make([]string, 0, len(settings))
for key := range settings {
keys = append(keys, key)

View file

@ -286,6 +286,17 @@ func TestConfig_LoadFile_Unsupported_Bad(t *testing.T) {
assert.Contains(t, err.Error(), "unsupported config file type")
}
func TestConfig_LoadFile_Unsupported_NoRead_Bad(t *testing.T) {
m := coreio.NewMockMedium()
cfg, err := New(WithMedium(m), WithPath("/tmp/test/config.txt"))
assert.NoError(t, err)
err = cfg.LoadFile(m, "/tmp/test/config.txt")
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported config file type")
}
func TestSave_Good(t *testing.T) {
m := coreio.NewMockMedium()