diff --git a/pool/client.go b/pool/client.go index eaf0f15..1e51853 100644 --- a/pool/client.go +++ b/pool/client.go @@ -284,7 +284,5 @@ func (c *StratumClient) handleMessage(response jsonRPCResponse) { if response.Error != nil { errorMessage = response.Error.Message } - if !accepted || len(response.Result) > 0 { - c.listener.OnResultAccepted(response.ID, accepted, errorMessage) - } + c.listener.OnResultAccepted(response.ID, accepted, errorMessage) } diff --git a/pool/client_test.go b/pool/client_test.go index db64219..e5e0750 100644 --- a/pool/client_test.go +++ b/pool/client_test.go @@ -145,6 +145,26 @@ func TestStratumClient_HandleMessage_Bad(t *testing.T) { } } +func TestStratumClient_HandleMessage_Good(t *testing.T) { + listener := &resultCapturingListener{} + client := &StratumClient{ + listener: listener, + sessionID: "session-1", + } + + client.handleMessage(jsonRPCResponse{ + ID: 7, + }) + + sequence, accepted, errorMessage, results, disconnects := listener.Snapshot() + if sequence != 7 || !accepted || errorMessage != "" || results != 1 { + t.Fatalf("expected accepted submit callback, got sequence=%d accepted=%v error=%q results=%d", sequence, accepted, errorMessage, results) + } + if disconnects != 0 { + t.Fatalf("expected no disconnect on submit accept, got %d", disconnects) + } +} + func TestStratumClient_HandleMessage_Ugly(t *testing.T) { serverConn, clientConn := net.Pipe() defer clientConn.Close() diff --git a/pool/strategy.go b/pool/strategy.go index 5d881ef..3dcf770 100644 --- a/pool/strategy.go +++ b/pool/strategy.go @@ -179,6 +179,10 @@ func (s *FailoverStrategy) OnDisconnect() { client := s.client s.client = nil closed := s.closed + nextStart := 0 + if len(s.pools) > 0 { + nextStart = (s.current + 1) % len(s.pools) + } s.mu.Unlock() if s.listener != nil { @@ -192,6 +196,6 @@ func (s *FailoverStrategy) OnDisconnect() { } go func() { time.Sleep(10 * time.Millisecond) - s.connectFrom(0) + s.connectFrom(nextStart) }() } diff --git a/pool/strategy_test.go b/pool/strategy_test.go index 0076c5c..e65c003 100644 --- a/pool/strategy_test.go +++ b/pool/strategy_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net" "sync" + "sync/atomic" "testing" "time" @@ -108,3 +109,94 @@ func TestFailoverStrategy_Connect_Ugly(t *testing.T) { t.Fatal("expected disconnect callback before failover reconnect") } } + +func TestFailoverStrategy_OnDisconnect_Good(t *testing.T) { + primaryListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer primaryListener.Close() + + backupListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer backupListener.Close() + + var primaryConnections atomic.Int32 + + go func() { + for { + conn, acceptErr := primaryListener.Accept() + if acceptErr != nil { + return + } + primaryConnections.Add(1) + go func(connection net.Conn) { + defer connection.Close() + reader := bufio.NewReader(connection) + if _, readErr := reader.ReadBytes('\n'); readErr != nil { + return + } + _ = json.NewEncoder(connection).Encode(map[string]interface{}{ + "id": 1, + "jsonrpc": "2.0", + "error": map[string]interface{}{ + "code": -1, + "message": "Unauthenticated", + }, + }) + }(conn) + } + }() + + go func() { + conn, acceptErr := backupListener.Accept() + if acceptErr != nil { + return + } + defer conn.Close() + + reader := bufio.NewReader(conn) + if _, readErr := reader.ReadBytes('\n'); readErr != nil { + return + } + _ = json.NewEncoder(conn).Encode(map[string]interface{}{ + "id": 1, + "jsonrpc": "2.0", + "error": nil, + "result": map[string]interface{}{ + "id": "session-1", + "job": map[string]interface{}{ + "blob": "abcd", + "job_id": "job-1", + "target": "b88d0600", + }, + }, + }) + }() + + listener := &strategyTestListener{ + jobCh: make(chan proxy.Job, 1), + } + strategy := NewFailoverStrategy([]proxy.PoolConfig{ + {URL: primaryListener.Addr().String(), Enabled: true}, + {URL: backupListener.Addr().String(), Enabled: true}, + }, listener, &proxy.Config{Retries: 1}) + + strategy.Connect() + defer strategy.Disconnect() + + select { + case job := <-listener.jobCh: + if job.JobID != "job-1" { + t.Fatalf("expected backup job, got %+v", job) + } + case <-time.After(3 * time.Second): + t.Fatalf("expected backup job after primary disconnect, primary connections=%d", primaryConnections.Load()) + } + + if listener.Disconnects() == 0 { + t.Fatal("expected disconnect callback before failover reconnect") + } +}