fix: add Commit() for safe config writes, iterator-based All()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-03-09 08:29:58 +00:00
parent c81bf25431
commit ddf301fc24
4 changed files with 95 additions and 46 deletions

View file

@ -12,6 +12,8 @@ package config
import (
"fmt"
"iter"
"maps"
"os"
"path/filepath"
"strings"
@ -28,7 +30,8 @@ import (
// It uses viper as the underlying configuration engine.
type Config struct {
mu sync.RWMutex
v *viper.Viper
v *viper.Viper // Full configuration (file + env + defaults)
f *viper.Viper // File-backed configuration only (for persistence)
medium coreio.Medium
path string
}
@ -63,6 +66,7 @@ func WithEnvPrefix(prefix string) Option {
func New(opts ...Option) (*Config, error) {
c := &Config{
v: viper.New(),
f: viper.New(),
}
// Configure viper defaults
@ -105,20 +109,27 @@ func (c *Config) LoadFile(m coreio.Medium, path string) error {
content, err := m.Read(path)
if err != nil {
return coreerr.E("config.LoadFile", "failed to read config file: "+path, err)
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err)
}
ext := filepath.Ext(path)
configType := "yaml"
if ext == "" && filepath.Base(path) == ".env" {
c.v.SetConfigType("env")
configType = "env"
} else if ext != "" {
c.v.SetConfigType(strings.TrimPrefix(ext, "."))
} else {
c.v.SetConfigType("yaml")
configType = strings.TrimPrefix(ext, ".")
}
// Load into file-backed viper
c.f.SetConfigType(configType)
if err := c.f.MergeConfig(strings.NewReader(content)); err != nil {
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (f): %s", path), err)
}
// Load into full viper
c.v.SetConfigType(configType)
if err := c.v.MergeConfig(strings.NewReader(content)); err != nil {
return coreerr.E("config.LoadFile", "failed to parse config file: "+path, err)
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (v): %s", path), err)
}
return nil
@ -132,37 +143,51 @@ func (c *Config) Get(key string, out any) error {
defer c.mu.RUnlock()
if key == "" {
return c.v.Unmarshal(out)
if err := c.v.Unmarshal(out); err != nil {
return coreerr.E("config.Get", "failed to unmarshal full config", err)
}
return nil
}
if !c.v.IsSet(key) {
return coreerr.E("config.Get", fmt.Sprintf("key not found: %s", key), nil)
}
return c.v.UnmarshalKey(key, out)
if err := c.v.UnmarshalKey(key, out); err != nil {
return coreerr.E("config.Get", fmt.Sprintf("failed to unmarshal key: %s", key), err)
}
return nil
}
// Set stores a configuration value by dot-notation key and persists to disk.
// Set stores a configuration value in memory.
// Call Commit() to persist changes to disk.
func (c *Config) Set(key string, v any) error {
c.mu.Lock()
defer c.mu.Unlock()
c.f.Set(key, v)
c.v.Set(key, v)
// Persist to disk
if err := Save(c.medium, c.path, c.v.AllSettings()); err != nil {
return coreerr.E("config.Set", "failed to save config", err)
}
return nil
}
// All returns a deep copy of all configuration values.
func (c *Config) All() map[string]any {
// Commit persists any changes made via Set() to the configuration file on disk.
// This will only save the configuration that was loaded from the file or explicitly Set(),
// preventing environment variable leakage.
func (c *Config) Commit() error {
c.mu.Lock()
defer c.mu.Unlock()
if err := Save(c.medium, c.path, c.f.AllSettings()); err != nil {
return coreerr.E("config.Commit", "failed to save config", err)
}
return nil
}
// All returns an iterator over all configuration values (including environment variables).
func (c *Config) All() iter.Seq2[string, any] {
c.mu.RLock()
defer c.mu.RUnlock()
return c.v.AllSettings()
return maps.All(c.v.AllSettings())
}
// Path returns the path to the configuration file.

View file

@ -1,6 +1,7 @@
package config
import (
"maps"
"os"
"testing"
@ -44,6 +45,9 @@ func TestConfig_Set_Good(t *testing.T) {
err = cfg.Set("dev.editor", "vim")
assert.NoError(t, err)
err = cfg.Commit()
assert.NoError(t, err)
// Verify the value was saved to the medium
content, readErr := m.Read("/tmp/test/config.yaml")
assert.NoError(t, readErr)
@ -80,7 +84,7 @@ func TestConfig_All_Good(t *testing.T) {
_ = cfg.Set("key1", "val1")
_ = cfg.Set("key2", "val2")
all := cfg.All()
all := maps.Collect(cfg.All())
assert.Equal(t, "val1", all["key1"])
assert.Equal(t, "val2", all["key2"])
}

62
env.go
View file

@ -1,40 +1,52 @@
package config
import (
"iter"
"os"
"strings"
)
// LoadEnv parses environment variables with the given prefix and returns
// them as a flat map with dot-notation keys.
// Env returns an iterator over environment variables with the given prefix,
// providing them as dot-notation keys and values.
//
// For example, with prefix "CORE_CONFIG_":
//
// CORE_CONFIG_FOO_BAR=baz -> {"foo.bar": "baz"}
// CORE_CONFIG_EDITOR=vim -> {"editor": "vim"}
// CORE_CONFIG_FOO_BAR=baz -> yields ("foo.bar", "baz")
func Env(prefix string) iter.Seq2[string, any] {
return func(yield func(string, any) bool) {
for _, env := range os.Environ() {
if !strings.HasPrefix(env, prefix) {
continue
}
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {
continue
}
name := parts[0]
value := parts[1]
// Strip prefix and convert to dot notation
key := strings.TrimPrefix(name, prefix)
key = strings.ToLower(key)
key = strings.ReplaceAll(key, "_", ".")
if !yield(key, value) {
return
}
}
}
}
// LoadEnv parses environment variables with the given prefix and returns
// them as a flat map with dot-notation keys.
//
// Deprecated: Use Env for iterative access or collect into a map manually.
func LoadEnv(prefix string) map[string]any {
result := make(map[string]any)
for _, env := range os.Environ() {
if !strings.HasPrefix(env, prefix) {
continue
}
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {
continue
}
name := parts[0]
value := parts[1]
// Strip prefix and convert to dot notation
key := strings.TrimPrefix(name, prefix)
key = strings.ToLower(key)
key = strings.ReplaceAll(key, "_", ".")
result[key] = value
for k, v := range Env(prefix) {
result[k] = v
}
return result
}

View file

@ -68,6 +68,14 @@ func (s *Service) Set(key string, v any) error {
return s.config.Set(key, v)
}
// Commit persists any configuration changes to disk.
func (s *Service) Commit() error {
if s.config == nil {
return coreerr.E("config.Service.Commit", "config not loaded", nil)
}
return s.config.Commit()
}
// LoadFile merges a configuration file into the central configuration.
func (s *Service) LoadFile(m io.Medium, path string) error {
if s.config == nil {