fix(config): validate file types before read
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
d6f7c05838
commit
25559c4913
2 changed files with 44 additions and 28 deletions
61
config.go
61
config.go
|
|
@ -30,8 +30,8 @@ import (
|
|||
// It uses viper as the underlying configuration engine.
|
||||
type Config struct {
|
||||
mu sync.RWMutex
|
||||
v *viper.Viper // Full configuration (file + env + defaults)
|
||||
f *viper.Viper // File-backed configuration only (for persistence)
|
||||
full *viper.Viper // Full configuration (file + env + defaults)
|
||||
file *viper.Viper // File-backed configuration only (for persistence)
|
||||
medium coreio.Medium
|
||||
path string
|
||||
}
|
||||
|
|
@ -56,7 +56,7 @@ func WithPath(path string) Option {
|
|||
// WithEnvPrefix sets the prefix for environment variables.
|
||||
func WithEnvPrefix(prefix string) Option {
|
||||
return func(c *Config) {
|
||||
c.v.SetEnvPrefix(prefix)
|
||||
c.full.SetEnvPrefix(prefix)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -65,13 +65,13 @@ func WithEnvPrefix(prefix string) Option {
|
|||
// If no path is provided, it defaults to ~/.core/config.yaml.
|
||||
func New(opts ...Option) (*Config, error) {
|
||||
c := &Config{
|
||||
v: viper.New(),
|
||||
f: viper.New(),
|
||||
full: viper.New(),
|
||||
file: viper.New(),
|
||||
}
|
||||
|
||||
// Configure viper defaults
|
||||
c.v.SetEnvPrefix("CORE_CONFIG")
|
||||
c.v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
c.full.SetEnvPrefix("CORE_CONFIG")
|
||||
c.full.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
|
|
@ -89,7 +89,7 @@ func New(opts ...Option) (*Config, error) {
|
|||
c.path = filepath.Join(home, ".core", "config.yaml")
|
||||
}
|
||||
|
||||
c.v.AutomaticEnv()
|
||||
c.full.AutomaticEnv()
|
||||
|
||||
// Load existing config file if it exists
|
||||
if c.medium.Exists(c.path) {
|
||||
|
|
@ -127,26 +127,31 @@ func (c *Config) LoadFile(m coreio.Medium, path string) error {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
content, err := m.Read(path)
|
||||
if err != nil {
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err)
|
||||
}
|
||||
|
||||
configType, err := configTypeForPath(path)
|
||||
if err != nil {
|
||||
return coreerr.E("config.LoadFile", "failed to determine config file type: "+path, err)
|
||||
}
|
||||
|
||||
// Load into file-backed viper
|
||||
c.f.SetConfigType(configType)
|
||||
if err := c.f.MergeConfig(strings.NewReader(content)); err != nil {
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (f): %s", path), err)
|
||||
content, err := m.Read(path)
|
||||
if err != nil {
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err)
|
||||
}
|
||||
|
||||
// Load into full viper
|
||||
c.v.SetConfigType(configType)
|
||||
if err := c.v.MergeConfig(strings.NewReader(content)); err != nil {
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (v): %s", path), err)
|
||||
parsed := viper.New()
|
||||
parsed.SetConfigType(configType)
|
||||
if err := parsed.MergeConfig(strings.NewReader(content)); err != nil {
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file: %s", path), err)
|
||||
}
|
||||
|
||||
settings := parsed.AllSettings()
|
||||
|
||||
// Keep the persisted and runtime views aligned with the same parsed data.
|
||||
if err := c.file.MergeConfigMap(settings); err != nil {
|
||||
return coreerr.E("config.LoadFile", "failed to merge config into file settings", err)
|
||||
}
|
||||
|
||||
if err := c.full.MergeConfigMap(settings); err != nil {
|
||||
return coreerr.E("config.LoadFile", "failed to merge config into full settings", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -160,17 +165,17 @@ func (c *Config) Get(key string, out any) error {
|
|||
defer c.mu.RUnlock()
|
||||
|
||||
if key == "" {
|
||||
if err := c.v.Unmarshal(out); err != nil {
|
||||
if err := c.full.Unmarshal(out); err != nil {
|
||||
return coreerr.E("config.Get", "failed to unmarshal full config", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.v.IsSet(key) {
|
||||
if !c.full.IsSet(key) {
|
||||
return coreerr.E("config.Get", fmt.Sprintf("key not found: %s", key), nil)
|
||||
}
|
||||
|
||||
if err := c.v.UnmarshalKey(key, out); err != nil {
|
||||
if err := c.full.UnmarshalKey(key, out); err != nil {
|
||||
return coreerr.E("config.Get", fmt.Sprintf("failed to unmarshal key: %s", key), err)
|
||||
}
|
||||
return nil
|
||||
|
|
@ -182,8 +187,8 @@ func (c *Config) Set(key string, v any) error {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.f.Set(key, v)
|
||||
c.v.Set(key, v)
|
||||
c.file.Set(key, v)
|
||||
c.full.Set(key, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -194,7 +199,7 @@ func (c *Config) Commit() error {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if err := Save(c.medium, c.path, c.f.AllSettings()); err != nil {
|
||||
if err := Save(c.medium, c.path, c.file.AllSettings()); err != nil {
|
||||
return coreerr.E("config.Commit", "failed to save config", err)
|
||||
}
|
||||
return nil
|
||||
|
|
@ -206,7 +211,7 @@ func (c *Config) All() iter.Seq2[string, any] {
|
|||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
settings := c.v.AllSettings()
|
||||
settings := c.full.AllSettings()
|
||||
keys := make([]string, 0, len(settings))
|
||||
for key := range settings {
|
||||
keys = append(keys, key)
|
||||
|
|
|
|||
|
|
@ -286,6 +286,17 @@ func TestConfig_LoadFile_Unsupported_Bad(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "unsupported config file type")
|
||||
}
|
||||
|
||||
func TestConfig_LoadFile_Unsupported_NoRead_Bad(t *testing.T) {
|
||||
m := coreio.NewMockMedium()
|
||||
|
||||
cfg, err := New(WithMedium(m), WithPath("/tmp/test/config.txt"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = cfg.LoadFile(m, "/tmp/test/config.txt")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported config file type")
|
||||
}
|
||||
|
||||
func TestSave_Good(t *testing.T) {
|
||||
m := coreio.NewMockMedium()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue