From d43c8ee4c1e1e05664208def94311ab91a198150 Mon Sep 17 00:00:00 2001 From: Virgil Date: Sun, 5 Apr 2026 01:19:50 +0000 Subject: [PATCH] fix: clear stale upstream state on disconnect --- pool/impl.go | 22 +++++++++++-- pool/strategy_disconnect_test.go | 56 ++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/pool/impl.go b/pool/impl.go index 9dd20de..ee475d4 100644 --- a/pool/impl.go +++ b/pool/impl.go @@ -164,8 +164,9 @@ func (c *StratumClient) Disconnect() { return } c.closedOnce.Do(func() { - if c.conn != nil { - _ = c.conn.Close() + conn := c.resetConnectionState() + if conn != nil { + _ = conn.Close() } if c.listener != nil { c.listener.OnDisconnect() @@ -175,12 +176,28 @@ func (c *StratumClient) Disconnect() { func (c *StratumClient) notifyDisconnect() { c.closedOnce.Do(func() { + c.resetConnectionState() if c.listener != nil { c.listener.OnDisconnect() } }) } +func (c *StratumClient) resetConnectionState() net.Conn { + if c == nil { + return nil + } + c.mu.Lock() + defer c.mu.Unlock() + conn := c.conn + c.conn = nil + c.tlsConn = nil + c.sessionID = "" + c.active = false + c.pending = make(map[int64]struct{}) + return conn +} + func (c *StratumClient) writeJSON(payload any) error { c.sendMu.Lock() defer c.sendMu.Unlock() @@ -472,6 +489,7 @@ func (s *FailoverStrategy) OnDisconnect() { return } s.mu.Lock() + s.client = nil closing := s.closing if closing { s.closing = false diff --git a/pool/strategy_disconnect_test.go b/pool/strategy_disconnect_test.go index de9a173..d3247dd 100644 --- a/pool/strategy_disconnect_test.go +++ b/pool/strategy_disconnect_test.go @@ -1,6 +1,7 @@ package pool import ( + "net" "sync/atomic" "testing" "time" @@ -63,3 +64,58 @@ func TestFailoverStrategy_Disconnect_Ugly(t *testing.T) { t.Fatalf("expected repeated intentional disconnects to remain silent, got %d listener calls", got) } } + +func TestStratumClient_NotifyDisconnect_ClearsState_Good(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + + spy := &disconnectSpy{} + client := &StratumClient{ + conn: clientConn, + listener: spy, + sessionID: "session-1", + active: true, + pending: map[int64]struct{}{ + 7: {}, + }, + } + + client.notifyDisconnect() + + if got := spy.disconnects.Load(); got != 1 { + t.Fatalf("expected one disconnect notification, got %d", got) + } + if client.conn != nil { + t.Fatalf("expected pooled connection to be cleared") + } + if client.sessionID != "" { + t.Fatalf("expected session id to be cleared, got %q", client.sessionID) + } + if client.IsActive() { + t.Fatalf("expected client to stop reporting active after disconnect") + } + if len(client.pending) != 0 { + t.Fatalf("expected pending submit state to be cleared, got %d entries", len(client.pending)) + } +} + +func TestFailoverStrategy_OnDisconnect_ClearsClient_Bad(t *testing.T) { + spy := &disconnectSpy{} + strategy := &FailoverStrategy{ + listener: spy, + client: &StratumClient{active: true, pending: make(map[int64]struct{})}, + } + + strategy.OnDisconnect() + time.Sleep(10 * time.Millisecond) + + if strategy.client != nil { + t.Fatalf("expected strategy to drop the stale client before reconnect") + } + if strategy.IsActive() { + t.Fatalf("expected strategy to report inactive while reconnect is pending") + } + if got := spy.disconnects.Load(); got != 1 { + t.Fatalf("expected one disconnect notification, got %d", got) + } +}