fix: add Commit() for safe config writes, iterator-based All()
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c81bf25431
commit
ddf301fc24
4 changed files with 95 additions and 46 deletions
65
config.go
65
config.go
|
|
@ -12,6 +12,8 @@ package config
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
|
@ -28,7 +30,8 @@ import (
|
|||
// It uses viper as the underlying configuration engine.
|
||||
type Config struct {
|
||||
mu sync.RWMutex
|
||||
v *viper.Viper
|
||||
v *viper.Viper // Full configuration (file + env + defaults)
|
||||
f *viper.Viper // File-backed configuration only (for persistence)
|
||||
medium coreio.Medium
|
||||
path string
|
||||
}
|
||||
|
|
@ -63,6 +66,7 @@ func WithEnvPrefix(prefix string) Option {
|
|||
func New(opts ...Option) (*Config, error) {
|
||||
c := &Config{
|
||||
v: viper.New(),
|
||||
f: viper.New(),
|
||||
}
|
||||
|
||||
// Configure viper defaults
|
||||
|
|
@ -105,20 +109,27 @@ func (c *Config) LoadFile(m coreio.Medium, path string) error {
|
|||
|
||||
content, err := m.Read(path)
|
||||
if err != nil {
|
||||
return coreerr.E("config.LoadFile", "failed to read config file: "+path, err)
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to read config file: %s", path), err)
|
||||
}
|
||||
|
||||
ext := filepath.Ext(path)
|
||||
configType := "yaml"
|
||||
if ext == "" && filepath.Base(path) == ".env" {
|
||||
c.v.SetConfigType("env")
|
||||
configType = "env"
|
||||
} else if ext != "" {
|
||||
c.v.SetConfigType(strings.TrimPrefix(ext, "."))
|
||||
} else {
|
||||
c.v.SetConfigType("yaml")
|
||||
configType = strings.TrimPrefix(ext, ".")
|
||||
}
|
||||
|
||||
// Load into file-backed viper
|
||||
c.f.SetConfigType(configType)
|
||||
if err := c.f.MergeConfig(strings.NewReader(content)); err != nil {
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (f): %s", path), err)
|
||||
}
|
||||
|
||||
// Load into full viper
|
||||
c.v.SetConfigType(configType)
|
||||
if err := c.v.MergeConfig(strings.NewReader(content)); err != nil {
|
||||
return coreerr.E("config.LoadFile", "failed to parse config file: "+path, err)
|
||||
return coreerr.E("config.LoadFile", fmt.Sprintf("failed to parse config file (v): %s", path), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -132,37 +143,51 @@ func (c *Config) Get(key string, out any) error {
|
|||
defer c.mu.RUnlock()
|
||||
|
||||
if key == "" {
|
||||
return c.v.Unmarshal(out)
|
||||
if err := c.v.Unmarshal(out); err != nil {
|
||||
return coreerr.E("config.Get", "failed to unmarshal full config", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.v.IsSet(key) {
|
||||
return coreerr.E("config.Get", fmt.Sprintf("key not found: %s", key), nil)
|
||||
}
|
||||
|
||||
return c.v.UnmarshalKey(key, out)
|
||||
if err := c.v.UnmarshalKey(key, out); err != nil {
|
||||
return coreerr.E("config.Get", fmt.Sprintf("failed to unmarshal key: %s", key), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set stores a configuration value by dot-notation key and persists to disk.
|
||||
// Set stores a configuration value in memory.
|
||||
// Call Commit() to persist changes to disk.
|
||||
func (c *Config) Set(key string, v any) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.f.Set(key, v)
|
||||
c.v.Set(key, v)
|
||||
|
||||
// Persist to disk
|
||||
if err := Save(c.medium, c.path, c.v.AllSettings()); err != nil {
|
||||
return coreerr.E("config.Set", "failed to save config", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// All returns a deep copy of all configuration values.
|
||||
func (c *Config) All() map[string]any {
|
||||
// Commit persists any changes made via Set() to the configuration file on disk.
|
||||
// This will only save the configuration that was loaded from the file or explicitly Set(),
|
||||
// preventing environment variable leakage.
|
||||
func (c *Config) Commit() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if err := Save(c.medium, c.path, c.f.AllSettings()); err != nil {
|
||||
return coreerr.E("config.Commit", "failed to save config", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// All returns an iterator over all configuration values (including environment variables).
|
||||
func (c *Config) All() iter.Seq2[string, any] {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return c.v.AllSettings()
|
||||
return maps.All(c.v.AllSettings())
|
||||
}
|
||||
|
||||
// Path returns the path to the configuration file.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
|
|
@ -44,6 +45,9 @@ func TestConfig_Set_Good(t *testing.T) {
|
|||
err = cfg.Set("dev.editor", "vim")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = cfg.Commit()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify the value was saved to the medium
|
||||
content, readErr := m.Read("/tmp/test/config.yaml")
|
||||
assert.NoError(t, readErr)
|
||||
|
|
@ -80,7 +84,7 @@ func TestConfig_All_Good(t *testing.T) {
|
|||
_ = cfg.Set("key1", "val1")
|
||||
_ = cfg.Set("key2", "val2")
|
||||
|
||||
all := cfg.All()
|
||||
all := maps.Collect(cfg.All())
|
||||
assert.Equal(t, "val1", all["key1"])
|
||||
assert.Equal(t, "val2", all["key2"])
|
||||
}
|
||||
|
|
|
|||
62
env.go
62
env.go
|
|
@ -1,40 +1,52 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LoadEnv parses environment variables with the given prefix and returns
|
||||
// them as a flat map with dot-notation keys.
|
||||
// Env returns an iterator over environment variables with the given prefix,
|
||||
// providing them as dot-notation keys and values.
|
||||
//
|
||||
// For example, with prefix "CORE_CONFIG_":
|
||||
//
|
||||
// CORE_CONFIG_FOO_BAR=baz -> {"foo.bar": "baz"}
|
||||
// CORE_CONFIG_EDITOR=vim -> {"editor": "vim"}
|
||||
// CORE_CONFIG_FOO_BAR=baz -> yields ("foo.bar", "baz")
|
||||
func Env(prefix string) iter.Seq2[string, any] {
|
||||
return func(yield func(string, any) bool) {
|
||||
for _, env := range os.Environ() {
|
||||
if !strings.HasPrefix(env, prefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
name := parts[0]
|
||||
value := parts[1]
|
||||
|
||||
// Strip prefix and convert to dot notation
|
||||
key := strings.TrimPrefix(name, prefix)
|
||||
key = strings.ToLower(key)
|
||||
key = strings.ReplaceAll(key, "_", ".")
|
||||
|
||||
if !yield(key, value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadEnv parses environment variables with the given prefix and returns
|
||||
// them as a flat map with dot-notation keys.
|
||||
//
|
||||
// Deprecated: Use Env for iterative access or collect into a map manually.
|
||||
func LoadEnv(prefix string) map[string]any {
|
||||
result := make(map[string]any)
|
||||
|
||||
for _, env := range os.Environ() {
|
||||
if !strings.HasPrefix(env, prefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
name := parts[0]
|
||||
value := parts[1]
|
||||
|
||||
// Strip prefix and convert to dot notation
|
||||
key := strings.TrimPrefix(name, prefix)
|
||||
key = strings.ToLower(key)
|
||||
key = strings.ReplaceAll(key, "_", ".")
|
||||
|
||||
result[key] = value
|
||||
for k, v := range Env(prefix) {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,6 +68,14 @@ func (s *Service) Set(key string, v any) error {
|
|||
return s.config.Set(key, v)
|
||||
}
|
||||
|
||||
// Commit persists any configuration changes to disk.
|
||||
func (s *Service) Commit() error {
|
||||
if s.config == nil {
|
||||
return coreerr.E("config.Service.Commit", "config not loaded", nil)
|
||||
}
|
||||
return s.config.Commit()
|
||||
}
|
||||
|
||||
// LoadFile merges a configuration file into the central configuration.
|
||||
func (s *Service) LoadFile(m io.Medium, path string) error {
|
||||
if s.config == nil {
|
||||
|
|
|
|||
Reference in a new issue