From ddf301fc24112f586cdd97bce30ba88823e1dc17 Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 9 Mar 2026 08:29:58 +0000 Subject: [PATCH] fix: add Commit() for safe config writes, iterator-based All() Co-Authored-By: Claude Opus 4.6 --- config.go | 65 ++++++++++++++++++++++++++++++++++---------------- config_test.go | 6 ++++- env.go | 62 ++++++++++++++++++++++++++++------------------- service.go | 8 +++++++ 4 files changed, 95 insertions(+), 46 deletions(-) diff --git a/config.go b/config.go index 5d13349..ef367e6 100644 --- a/config.go +++ b/config.go @@ -12,6 +12,8 @@ package config import ( "fmt" + "iter" + "maps" "os" "path/filepath" "strings" @@ -28,7 +30,8 @@ import ( // It uses viper as the underlying configuration engine. type Config struct { mu sync.RWMutex - v *viper.Viper + v *viper.Viper // Full configuration (file + env + defaults) + f *viper.Viper // File-backed configuration only (for persistence) medium coreio.Medium path string } @@ -63,6 +66,7 @@ func WithEnvPrefix(prefix string) Option { func New(opts ...Option) (*Config, error) { c := &Config{ v: viper.New(), + f: viper.New(), } // Configure viper defaults @@ -105,20 +109,27 @@ func (c *Config) LoadFile(m coreio.Medium, path string) error { content, err := m.Read(path) if err != nil { - return coreerr.E("config.LoadFile", "failed to read config file: "+path, err) + return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err) } ext := filepath.Ext(path) + configType := "yaml" if ext == "" && filepath.Base(path) == ".env" { - c.v.SetConfigType("env") + configType = "env" } else if ext != "" { - c.v.SetConfigType(strings.TrimPrefix(ext, ".")) - } else { - c.v.SetConfigType("yaml") + configType = strings.TrimPrefix(ext, ".") } + // Load into file-backed viper + c.f.SetConfigType(configType) + if err := c.f.MergeConfig(strings.NewReader(content)); err != nil { + return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (f): %s", path), err) + } + + // Load into full viper + c.v.SetConfigType(configType) if err := c.v.MergeConfig(strings.NewReader(content)); err != nil { - return coreerr.E("config.LoadFile", "failed to parse config file: "+path, err) + return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (v): %s", path), err) } return nil @@ -132,37 +143,51 @@ func (c *Config) Get(key string, out any) error { defer c.mu.RUnlock() if key == "" { - return c.v.Unmarshal(out) + if err := c.v.Unmarshal(out); err != nil { + return coreerr.E("config.Get", "failed to unmarshal full config", err) + } + return nil } if !c.v.IsSet(key) { return coreerr.E("config.Get", fmt.Sprintf("key not found: %s", key), nil) } - return c.v.UnmarshalKey(key, out) + if err := c.v.UnmarshalKey(key, out); err != nil { + return coreerr.E("config.Get", fmt.Sprintf("failed to unmarshal key: %s", key), err) + } + return nil } -// Set stores a configuration value by dot-notation key and persists to disk. +// Set stores a configuration value in memory. +// Call Commit() to persist changes to disk. func (c *Config) Set(key string, v any) error { c.mu.Lock() defer c.mu.Unlock() + c.f.Set(key, v) c.v.Set(key, v) - - // Persist to disk - if err := Save(c.medium, c.path, c.v.AllSettings()); err != nil { - return coreerr.E("config.Set", "failed to save config", err) - } - return nil } -// All returns a deep copy of all configuration values. -func (c *Config) All() map[string]any { +// Commit persists any changes made via Set() to the configuration file on disk. +// This will only save the configuration that was loaded from the file or explicitly Set(), +// preventing environment variable leakage. +func (c *Config) Commit() error { + c.mu.Lock() + defer c.mu.Unlock() + + if err := Save(c.medium, c.path, c.f.AllSettings()); err != nil { + return coreerr.E("config.Commit", "failed to save config", err) + } + return nil +} + +// All returns an iterator over all configuration values (including environment variables). +func (c *Config) All() iter.Seq2[string, any] { c.mu.RLock() defer c.mu.RUnlock() - - return c.v.AllSettings() + return maps.All(c.v.AllSettings()) } // Path returns the path to the configuration file. diff --git a/config_test.go b/config_test.go index f899b72..6e75cef 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package config import ( + "maps" "os" "testing" @@ -44,6 +45,9 @@ func TestConfig_Set_Good(t *testing.T) { err = cfg.Set("dev.editor", "vim") assert.NoError(t, err) + err = cfg.Commit() + assert.NoError(t, err) + // Verify the value was saved to the medium content, readErr := m.Read("/tmp/test/config.yaml") assert.NoError(t, readErr) @@ -80,7 +84,7 @@ func TestConfig_All_Good(t *testing.T) { _ = cfg.Set("key1", "val1") _ = cfg.Set("key2", "val2") - all := cfg.All() + all := maps.Collect(cfg.All()) assert.Equal(t, "val1", all["key1"]) assert.Equal(t, "val2", all["key2"]) } diff --git a/env.go b/env.go index 711e3ec..64c0372 100644 --- a/env.go +++ b/env.go @@ -1,40 +1,52 @@ package config import ( + "iter" "os" "strings" ) -// LoadEnv parses environment variables with the given prefix and returns -// them as a flat map with dot-notation keys. +// Env returns an iterator over environment variables with the given prefix, +// providing them as dot-notation keys and values. // // For example, with prefix "CORE_CONFIG_": // -// CORE_CONFIG_FOO_BAR=baz -> {"foo.bar": "baz"} -// CORE_CONFIG_EDITOR=vim -> {"editor": "vim"} +// CORE_CONFIG_FOO_BAR=baz -> yields ("foo.bar", "baz") +func Env(prefix string) iter.Seq2[string, any] { + return func(yield func(string, any) bool) { + for _, env := range os.Environ() { + if !strings.HasPrefix(env, prefix) { + continue + } + + parts := strings.SplitN(env, "=", 2) + if len(parts) != 2 { + continue + } + + name := parts[0] + value := parts[1] + + // Strip prefix and convert to dot notation + key := strings.TrimPrefix(name, prefix) + key = strings.ToLower(key) + key = strings.ReplaceAll(key, "_", ".") + + if !yield(key, value) { + return + } + } + } +} + +// LoadEnv parses environment variables with the given prefix and returns +// them as a flat map with dot-notation keys. +// +// Deprecated: Use Env for iterative access or collect into a map manually. func LoadEnv(prefix string) map[string]any { result := make(map[string]any) - - for _, env := range os.Environ() { - if !strings.HasPrefix(env, prefix) { - continue - } - - parts := strings.SplitN(env, "=", 2) - if len(parts) != 2 { - continue - } - - name := parts[0] - value := parts[1] - - // Strip prefix and convert to dot notation - key := strings.TrimPrefix(name, prefix) - key = strings.ToLower(key) - key = strings.ReplaceAll(key, "_", ".") - - result[key] = value + for k, v := range Env(prefix) { + result[k] = v } - return result } diff --git a/service.go b/service.go index 5996789..8c7acf8 100644 --- a/service.go +++ b/service.go @@ -68,6 +68,14 @@ func (s *Service) Set(key string, v any) error { return s.config.Set(key, v) } +// Commit persists any configuration changes to disk. +func (s *Service) Commit() error { + if s.config == nil { + return coreerr.E("config.Service.Commit", "config not loaded", nil) + } + return s.config.Commit() +} + // LoadFile merges a configuration file into the central configuration. func (s *Service) LoadFile(m io.Medium, path string) error { if s.config == nil {