From c7d688ccfabbed35d29a1444dd972cb921794100 Mon Sep 17 00:00:00 2001 From: Virgil Date: Sat, 4 Apr 2026 20:51:14 +0000 Subject: [PATCH] fix(proxy): drain pending submits on stop Co-Authored-By: Virgil --- state_impl.go | 22 ++++++++++++++++++++-- state_submit_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 state_submit_test.go diff --git a/state_impl.go b/state_impl.go index 5977ba5..fc56bd8 100644 --- a/state_impl.go +++ b/state_impl.go @@ -75,6 +75,8 @@ func New(cfg *Config) (*Proxy, Result) { p.events.Subscribe(EventAccept, p.customDiffBuckets.OnAccept) p.events.Subscribe(EventReject, p.customDiffBuckets.OnReject) } + p.events.Subscribe(EventAccept, p.onShareSettled) + p.events.Subscribe(EventReject, p.onShareSettled) if cfg.Watch && cfg.sourcePath != "" { p.watcher = NewConfigWatcher(cfg.sourcePath, p.Reload) } @@ -339,6 +341,21 @@ func (p *Proxy) Reload(cfg *Config) { } } +func (p *Proxy) onShareSettled(Event) { + if p == nil { + return + } + for { + current := p.submitCount.Load() + if current == 0 { + return + } + if p.submitCount.CompareAndSwap(current, current-1) { + return + } + } +} + func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) { if p == nil { _ = conn.Close() @@ -361,9 +378,10 @@ func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) { } } miner.onSubmit = func(m *Miner, event *SubmitEvent) { - p.submitCount.Add(1) - defer p.submitCount.Add(-1) if p.splitter != nil { + if _, ok := p.splitter.(*noopSplitter); !ok { + p.submitCount.Add(1) + } p.splitter.OnSubmit(event) } } diff --git a/state_submit_test.go b/state_submit_test.go new file mode 100644 index 0000000..28175aa --- /dev/null +++ b/state_submit_test.go @@ -0,0 +1,33 @@ +package proxy + +import ( + "testing" + "time" +) + +func TestProxy_Stop_WaitsForSubmitDrain(t *testing.T) { + p := &Proxy{ + done: make(chan struct{}), + } + p.submitCount.Store(1) + + stopped := make(chan struct{}) + go func() { + p.Stop() + close(stopped) + }() + + select { + case <-stopped: + t.Fatalf("expected Stop to wait for pending submits") + case <-time.After(50 * time.Millisecond): + } + + p.submitCount.Store(0) + + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatalf("expected Stop to finish after pending submits drain") + } +}