diff --git a/configwatcher_test.go b/configwatcher_test.go index fecfbcb..a79a8fb 100644 --- a/configwatcher_test.go +++ b/configwatcher_test.go @@ -4,12 +4,13 @@ import ( "os" "path/filepath" "testing" + "time" ) func TestConfigWatcher_New_Good(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "config.json") - if err := os.WriteFile(path, []byte(`{"mode":"nicehash"}`), 0o644); err != nil { + if err := os.WriteFile(path, []byte(`{"mode":"nicehash","workers":"false","bind":[{"host":"127.0.0.1","port":3333}],"pools":[{"url":"pool.example:3333","enabled":true}]}`), 0o644); err != nil { t.Fatalf("write config file: %v", err) } @@ -21,3 +22,42 @@ func TestConfigWatcher_New_Good(t *testing.T) { t.Fatal("expected last modification time to be initialised from the file") } } + +func TestConfigWatcher_Start_Good(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + initial := []byte(`{"mode":"nicehash","workers":"false","bind":[{"host":"127.0.0.1","port":3333}],"pools":[{"url":"pool.example:3333","enabled":true}]}`) + if err := os.WriteFile(path, initial, 0o644); err != nil { + t.Fatalf("write initial config file: %v", err) + } + + updates := make(chan *Config, 1) + watcher := NewConfigWatcher(path, func(cfg *Config) { + select { + case updates <- cfg: + default: + } + }) + if watcher == nil { + t.Fatal("expected watcher") + } + watcher.Start() + defer watcher.Stop() + + updated := []byte(`{"mode":"simple","workers":"user","bind":[{"host":"127.0.0.1","port":3333}],"pools":[{"url":"pool.example:3333","enabled":true}]}`) + if err := os.WriteFile(path, updated, 0o644); err != nil { + t.Fatalf("write updated config file: %v", err) + } + + select { + case cfg := <-updates: + if cfg == nil { + t.Fatal("expected config update") + } + if got := cfg.Mode; got != "simple" { + t.Fatalf("expected updated mode, got %q", got) + } + case <-time.After(5 * time.Second): + t.Fatal("expected watcher to reload updated config") + } +} diff --git a/core_impl.go b/core_impl.go index f6ce517..bfe04ef 100644 --- a/core_impl.go +++ b/core_impl.go @@ -11,10 +11,13 @@ import ( "math" "net" "os" + "path/filepath" "strconv" "strings" "sync" "time" + + "github.com/fsnotify/fsnotify" ) // Result is the success/error carrier used by constructors and loaders. @@ -375,23 +378,65 @@ func (w *ConfigWatcher) Start() { if w == nil || w.path == "" || w.onChange == nil { return } + w.mu.Lock() + if w.watcher != nil { + w.mu.Unlock() + return + } + fsWatcher, err := fsnotify.NewWatcher() + if err != nil { + w.mu.Unlock() + return + } + w.watcher = fsWatcher + w.mu.Unlock() + + watchPath := filepath.Clean(w.path) + watchDir := filepath.Dir(watchPath) + if watchDir == "" { + watchDir = "." + } + if err := fsWatcher.Add(watchDir); err != nil { + _ = fsWatcher.Close() + w.mu.Lock() + if w.watcher == fsWatcher { + w.watcher = nil + } + w.mu.Unlock() + return + } + go func() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() + defer func() { + _ = fsWatcher.Close() + w.mu.Lock() + if w.watcher == fsWatcher { + w.watcher = nil + } + w.mu.Unlock() + }() for { select { - case <-ticker.C: - info, err := os.Stat(w.path) - if err != nil { + case event, ok := <-fsWatcher.Events: + if !ok { + return + } + if filepath.Clean(event.Name) != watchPath { continue } - mod := info.ModTime() - if mod.After(w.lastMod) { - w.lastMod = mod - config, result := LoadConfig(w.path) - if result.OK && config != nil { - w.onChange(config) + if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename|fsnotify.Remove|fsnotify.Chmod) == 0 { + continue + } + config, result := LoadConfig(w.path) + if result.OK && config != nil { + if info, err := os.Stat(w.path); err == nil { + w.lastMod = info.ModTime() } + w.onChange(config) + } + case _, ok := <-fsWatcher.Errors: + if !ok { + return } case <-w.done: return @@ -405,6 +450,12 @@ func (w *ConfigWatcher) Stop() { if w == nil { return } + w.mu.Lock() + if w.watcher != nil { + _ = w.watcher.Close() + w.watcher = nil + } + w.mu.Unlock() select { case <-w.done: default: diff --git a/go.mod b/go.mod index b954a44..f14ce7b 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,7 @@ module dappco.re/go/proxy go 1.26.0 + +require github.com/fsnotify/fsnotify v1.7.0 + +require golang.org/x/sys v0.4.0 // indirect diff --git a/go.sum b/go.sum index e69de29..ccd7ce9 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/proxy.go b/proxy.go index aab5894..16733d4 100644 --- a/proxy.go +++ b/proxy.go @@ -8,6 +8,7 @@ package proxy import ( + "github.com/fsnotify/fsnotify" "net/http" "sync" "sync/atomic" @@ -120,6 +121,8 @@ type ConfigWatcher struct { onChange func(*Config) lastMod time.Time done chan struct{} + mu sync.Mutex + watcher *fsnotify.Watcher } // RateLimiter throttles new connections per source IP.