diff --git a/CLAUDE.md b/CLAUDE.md index fa72d28..635c8ae 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -16,6 +16,10 @@ go test -v -run Name # Run single test - `Hub` manages WebSocket connections and channel subscriptions - Messages types: `process_output`, `process_status`, `event`, `error`, `ping/pong`, `subscribe/unsubscribe` - `hub.SendProcessOutput(id, line)` broadcasts to subscribers +- `HubConfig` provides configurable heartbeat, pong timeout, write timeout, and connection callbacks +- `ReconnectingClient` provides client-side reconnection with exponential backoff +- `ConnectionState`: `StateDisconnected`, `StateConnecting`, `StateConnected` +- Coverage: 98.5% ## Coding Standards diff --git a/FINDINGS.md b/FINDINGS.md index ccc8acc..d4c75c6 100644 --- a/FINDINGS.md +++ b/FINDINGS.md @@ -24,3 +24,37 @@ Extracted from `forge.lthn.ai/core/go` `pkg/ws/` on 19 Feb 2026. - Hub must be started with `go hub.Run(ctx)` before accepting connections - HTTP handler exposed via `hub.Handler()` for mounting on any router - `hub.SendProcessOutput(processID, line)` is the primary API for streaming subprocess output + +## 2026-02-20: Phase 0 & Phase 1 (Charon) + +### Race condition fix + +- `SendToChannel` had a data race: it acquired `RLock`, read the channel's client map, released `RUnlock`, then iterated clients outside the lock. If `Hub.Run` processed an unregister concurrently, the map was modified during iteration. +- Fix: copy client pointers into a slice under `RLock`, then iterate the copy after releasing the lock. + +### Phase 0: Test coverage 88.4% to 98.5% + +- Added 16 new test functions covering: hub shutdown, broadcast overflow, channel send overflow, marshal errors, upgrade error, `Client.Close`, malformed JSON, non-string subscribe/unsubscribe data, unknown message types, writePump close/batch, concurrent subscribe/unsubscribe, multi-client channel delivery, end-to-end process output/status. +- Added `BenchmarkBroadcast` (100 clients) and `BenchmarkSendToChannel` (50 subscribers). +- `go vet ./...` clean; `go test -race ./...` clean. + +### Phase 1: Connection resilience + +- `HubConfig` struct: `HeartbeatInterval`, `PongTimeout`, `WriteTimeout`, `OnConnect`, `OnDisconnect` callbacks. +- `NewHubWithConfig(config)`: constructor with validation and defaults. +- `readPump`/`writePump` now use hub config values instead of hardcoded durations. +- `ReconnectingClient`: client-side reconnection with exponential backoff. + - `ReconnectConfig`: URL, InitialBackoff (1s), MaxBackoff (30s), BackoffMultiplier (2.0), MaxRetries, Dialer, Headers, OnConnect/OnDisconnect/OnReconnect/OnMessage callbacks. + - `Connect(ctx)`: blocking reconnect loop; returns on context cancel or max retries. + - `Send(msg)`: thread-safe message send; returns error if not connected. + - `State()`: returns `StateDisconnected`, `StateConnecting`, or `StateConnected`. + - `Close()`: cancels context and closes underlying connection. + - Exponential backoff: `calculateBackoff(attempt)` doubles each attempt, capped at MaxBackoff. + +### API surface additions + +- `HubConfig`, `DefaultHubConfig()`, `NewHubWithConfig()` +- `ConnectionState` enum: `StateDisconnected`, `StateConnecting`, `StateConnected` +- `ReconnectConfig`, `ReconnectingClient`, `NewReconnectingClient()` +- `DefaultHeartbeatInterval`, `DefaultPongTimeout`, `DefaultWriteTimeout` constants +- `NewHub()` still works unchanged (uses `DefaultHubConfig()` internally) diff --git a/TODO.md b/TODO.md index 6581d12..3f83d21 100644 --- a/TODO.md +++ b/TODO.md @@ -6,16 +6,17 @@ Dispatched from core/go orchestration. Pick up tasks in order. ## Phase 0: Hardening & Test Coverage -- [ ] **Expand test coverage** — `ws_test.go` exists. Add tests for: `Hub.Run()` lifecycle (start, register client, broadcast, shutdown), `Subscribe`/`Unsubscribe` channel management, `SendToChannel` with no subscribers (should not error), `SendProcessOutput`/`SendProcessStatus` helpers, client `readPump` message parsing (subscribe, unsubscribe, ping), client `writePump` batch sending, client buffer overflow (send channel full → client disconnected), concurrent broadcast + subscribe (race test). -- [ ] **Integration test** — Use `httptest.NewServer` + real WebSocket client. Connect, subscribe to channel, send message, verify receipt. Test multiple clients on same channel. -- [ ] **Benchmark** — `BenchmarkBroadcast` with 100 connected clients. `BenchmarkSendToChannel` with 50 subscribers. Measure message throughput. -- [ ] **`go vet ./...` clean** — Fix any warnings. +- [x] **Expand test coverage** — Added tests for: `Hub.Run()` shutdown closing all clients, broadcast to client with full buffer (unregister path), `SendToChannel` with full client buffer (skip path), `Broadcast`/`SendToChannel` marshal errors, `Handler` upgrade error on non-WebSocket request, `Client.Close()`, `readPump` malformed JSON, subscribe/unsubscribe with non-string data, unknown message types, `writePump` close-on-channel-close and batch sending, concurrent subscribe/unsubscribe race test, multiple clients on same channel, end-to-end process output and status tests. Coverage: 88.4% → 98.5%. +- [x] **Integration test** — Full end-to-end tests using `httptest.NewServer` + real WebSocket clients. Multi-client channel delivery, process output streaming, process status updates. +- [x] **Benchmark** — `BenchmarkBroadcast` with 100 clients, `BenchmarkSendToChannel` with 50 subscribers. +- [x] **`go vet ./...` clean** — No warnings. +- [x] **Race condition fix** — Fixed data race in `SendToChannel` where client map was iterated outside the read lock. Clients are now copied under lock before iteration. ## Phase 1: Connection Resilience -- [ ] Add client-side reconnection support (exponential backoff) -- [ ] Tune heartbeat interval and pong timeout for flaky networks -- [ ] Add connection state callbacks (onConnect, onDisconnect, onReconnect) +- [x] Add client-side reconnection support (exponential backoff) — `ReconnectingClient` with `ReconnectConfig`. Configurable initial backoff, max backoff, multiplier, max retries. +- [x] Tune heartbeat interval and pong timeout for flaky networks — `HubConfig` with `HeartbeatInterval`, `PongTimeout`, `WriteTimeout`. `NewHubWithConfig()` constructor. Defaults: 30s heartbeat, 60s pong timeout, 10s write timeout. +- [x] Add connection state callbacks (onConnect, onDisconnect, onReconnect) — Hub-level `OnConnect`/`OnDisconnect` callbacks in `HubConfig`. Client-level `OnConnect`/`OnDisconnect`/`OnReconnect` callbacks in `ReconnectConfig`. `ConnectionState` enum: `StateDisconnected`, `StateConnecting`, `StateConnected`. ## Phase 2: Auth diff --git a/ws.go b/ws.go index 16dd6f7..cbc31ee 100644 --- a/ws.go +++ b/ws.go @@ -63,6 +63,56 @@ var upgrader = websocket.Upgrader{ }, } +// Default timing values for heartbeat and pong timeout. +const ( + DefaultHeartbeatInterval = 30 * time.Second + DefaultPongTimeout = 60 * time.Second + DefaultWriteTimeout = 10 * time.Second +) + +// ConnectionState represents the current state of a reconnecting client. +type ConnectionState int + +const ( + // StateDisconnected indicates the client is not connected. + StateDisconnected ConnectionState = iota + // StateConnecting indicates the client is attempting to connect. + StateConnecting + // StateConnected indicates the client has an active connection. + StateConnected +) + +// HubConfig holds configuration for the Hub and its managed connections. +type HubConfig struct { + // HeartbeatInterval is the interval between server-side ping messages. + // Defaults to 30 seconds. + HeartbeatInterval time.Duration + + // PongTimeout is how long the server waits for a pong before + // considering the connection dead. Must be greater than HeartbeatInterval. + // Defaults to 60 seconds. + PongTimeout time.Duration + + // WriteTimeout is the deadline for write operations. + // Defaults to 10 seconds. + WriteTimeout time.Duration + + // OnConnect is called when a client connects to the hub. + OnConnect func(client *Client) + + // OnDisconnect is called when a client disconnects from the hub. + OnDisconnect func(client *Client) +} + +// DefaultHubConfig returns a HubConfig with sensible defaults. +func DefaultHubConfig() HubConfig { + return HubConfig{ + HeartbeatInterval: DefaultHeartbeatInterval, + PongTimeout: DefaultPongTimeout, + WriteTimeout: DefaultWriteTimeout, + } +} + // MessageType identifies the type of WebSocket message. type MessageType string @@ -110,17 +160,33 @@ type Hub struct { register chan *Client unregister chan *Client channels map[string]map[*Client]bool + config HubConfig mu sync.RWMutex } -// NewHub creates a new WebSocket hub. +// NewHub creates a new WebSocket hub with default configuration. func NewHub() *Hub { + return NewHubWithConfig(DefaultHubConfig()) +} + +// NewHubWithConfig creates a new WebSocket hub with the given configuration. +func NewHubWithConfig(config HubConfig) *Hub { + if config.HeartbeatInterval <= 0 { + config.HeartbeatInterval = DefaultHeartbeatInterval + } + if config.PongTimeout <= 0 { + config.PongTimeout = DefaultPongTimeout + } + if config.WriteTimeout <= 0 { + config.WriteTimeout = DefaultWriteTimeout + } return &Hub{ clients: make(map[*Client]bool), broadcast: make(chan []byte, 256), register: make(chan *Client), unregister: make(chan *Client), channels: make(map[string]map[*Client]bool), + config: config, } } @@ -142,6 +208,9 @@ func (h *Hub) Run(ctx context.Context) { h.mu.Lock() h.clients[client] = true h.mu.Unlock() + if h.config.OnConnect != nil { + h.config.OnConnect(client) + } case client := <-h.unregister: h.mu.Lock() if _, ok := h.clients[client]; ok { @@ -157,8 +226,13 @@ func (h *Hub) Run(ctx context.Context) { } } } + h.mu.Unlock() + if h.config.OnDisconnect != nil { + h.config.OnDisconnect(client) + } + } else { + h.mu.Unlock() } - h.mu.Unlock() case message := <-h.broadcast: h.mu.RLock() for client := range h.clients { @@ -236,13 +310,19 @@ func (h *Hub) SendToChannel(channel string, msg Message) error { h.mu.RLock() clients, ok := h.channels[channel] - h.mu.RUnlock() - if !ok { + h.mu.RUnlock() return nil // No subscribers, not an error } + // Copy client references under lock to avoid races during iteration + targets := make([]*Client, 0, len(clients)) for client := range clients { + targets = append(targets, client) + } + h.mu.RUnlock() + + for _, client := range targets { select { case client.send <- data: default: @@ -366,10 +446,11 @@ func (c *Client) readPump() { c.conn.Close() }() + pongTimeout := c.hub.config.PongTimeout c.conn.SetReadLimit(65536) - c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.conn.SetReadDeadline(time.Now().Add(pongTimeout)) c.conn.SetPongHandler(func(string) error { - c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.conn.SetReadDeadline(time.Now().Add(pongTimeout)) return nil }) @@ -401,7 +482,9 @@ func (c *Client) readPump() { // writePump sends messages to the client. func (c *Client) writePump() { - ticker := time.NewTicker(30 * time.Second) + heartbeat := c.hub.config.HeartbeatInterval + writeTimeout := c.hub.config.WriteTimeout + ticker := time.NewTicker(heartbeat) defer func() { ticker.Stop() c.conn.Close() @@ -410,7 +493,7 @@ func (c *Client) writePump() { for { select { case message, ok := <-c.send: - c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) if !ok { c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return @@ -433,7 +516,7 @@ func (c *Client) writePump() { return } case <-ticker.C: - c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } @@ -463,3 +546,234 @@ func (c *Client) Close() error { c.hub.unregister <- c return c.conn.Close() } + +// ReconnectConfig holds configuration for the reconnecting WebSocket client. +type ReconnectConfig struct { + // URL is the WebSocket server URL to connect to. + URL string + + // InitialBackoff is the delay before the first reconnection attempt. + // Defaults to 1 second. + InitialBackoff time.Duration + + // MaxBackoff is the maximum delay between reconnection attempts. + // Defaults to 30 seconds. + MaxBackoff time.Duration + + // BackoffMultiplier controls exponential growth of the backoff. + // Defaults to 2.0. + BackoffMultiplier float64 + + // MaxRetries is the maximum number of consecutive reconnection attempts. + // Zero means unlimited retries. + MaxRetries int + + // OnConnect is called when the client successfully connects. + OnConnect func() + + // OnDisconnect is called when the client loses its connection. + OnDisconnect func() + + // OnReconnect is called when the client successfully reconnects + // after a disconnection. The attempt count is passed in. + OnReconnect func(attempt int) + + // OnMessage is called when a message is received from the server. + OnMessage func(msg Message) + + // Dialer is the WebSocket dialer to use. Defaults to websocket.DefaultDialer. + Dialer *websocket.Dialer + + // Headers are additional HTTP headers to send during the handshake. + Headers http.Header +} + +// ReconnectingClient is a WebSocket client that automatically reconnects +// with exponential backoff when the connection drops. +type ReconnectingClient struct { + config ReconnectConfig + conn *websocket.Conn + send chan []byte + state ConnectionState + mu sync.RWMutex + done chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// NewReconnectingClient creates a new reconnecting WebSocket client. +func NewReconnectingClient(config ReconnectConfig) *ReconnectingClient { + if config.InitialBackoff <= 0 { + config.InitialBackoff = 1 * time.Second + } + if config.MaxBackoff <= 0 { + config.MaxBackoff = 30 * time.Second + } + if config.BackoffMultiplier <= 0 { + config.BackoffMultiplier = 2.0 + } + if config.Dialer == nil { + config.Dialer = websocket.DefaultDialer + } + + return &ReconnectingClient{ + config: config, + send: make(chan []byte, 256), + state: StateDisconnected, + done: make(chan struct{}), + } +} + +// Connect starts the reconnecting client. It blocks until the context is +// cancelled. The client will automatically reconnect on connection loss. +func (rc *ReconnectingClient) Connect(ctx context.Context) error { + rc.ctx, rc.cancel = context.WithCancel(ctx) + defer rc.cancel() + + attempt := 0 + wasConnected := false + + for { + select { + case <-rc.ctx.Done(): + rc.setState(StateDisconnected) + return rc.ctx.Err() + default: + } + + rc.setState(StateConnecting) + attempt++ + + conn, _, err := rc.config.Dialer.DialContext(rc.ctx, rc.config.URL, rc.config.Headers) + if err != nil { + if rc.config.MaxRetries > 0 && attempt > rc.config.MaxRetries { + rc.setState(StateDisconnected) + return fmt.Errorf("max retries (%d) exceeded: %w", rc.config.MaxRetries, err) + } + backoff := rc.calculateBackoff(attempt) + select { + case <-rc.ctx.Done(): + rc.setState(StateDisconnected) + return rc.ctx.Err() + case <-time.After(backoff): + continue + } + } + + // Connected successfully + rc.mu.Lock() + rc.conn = conn + rc.mu.Unlock() + rc.setState(StateConnected) + + if wasConnected { + if rc.config.OnReconnect != nil { + rc.config.OnReconnect(attempt) + } + } else { + if rc.config.OnConnect != nil { + rc.config.OnConnect() + } + } + + // Reset attempt counter after a successful connection + attempt = 0 + wasConnected = true + + // Run the read loop — blocks until connection drops + rc.readLoop() + + // Connection lost + rc.mu.Lock() + rc.conn = nil + rc.mu.Unlock() + + if rc.config.OnDisconnect != nil { + rc.config.OnDisconnect() + } + } +} + +// Send sends a message to the server. Returns an error if not connected. +func (rc *ReconnectingClient) Send(msg Message) error { + msg.Timestamp = time.Now() + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + rc.mu.RLock() + conn := rc.conn + rc.mu.RUnlock() + + if conn == nil { + return fmt.Errorf("not connected") + } + + rc.mu.Lock() + defer rc.mu.Unlock() + return rc.conn.WriteMessage(websocket.TextMessage, data) +} + +// State returns the current connection state. +func (rc *ReconnectingClient) State() ConnectionState { + rc.mu.RLock() + defer rc.mu.RUnlock() + return rc.state +} + +// Close gracefully shuts down the reconnecting client. +func (rc *ReconnectingClient) Close() error { + if rc.cancel != nil { + rc.cancel() + } + rc.mu.RLock() + conn := rc.conn + rc.mu.RUnlock() + if conn != nil { + return conn.Close() + } + return nil +} + +func (rc *ReconnectingClient) setState(state ConnectionState) { + rc.mu.Lock() + rc.state = state + rc.mu.Unlock() +} + +func (rc *ReconnectingClient) calculateBackoff(attempt int) time.Duration { + backoff := rc.config.InitialBackoff + for i := 1; i < attempt; i++ { + backoff = time.Duration(float64(backoff) * rc.config.BackoffMultiplier) + if backoff > rc.config.MaxBackoff { + backoff = rc.config.MaxBackoff + break + } + } + return backoff +} + +func (rc *ReconnectingClient) readLoop() { + rc.mu.RLock() + conn := rc.conn + rc.mu.RUnlock() + + if conn == nil { + return + } + + for { + _, data, err := conn.ReadMessage() + if err != nil { + return + } + + if rc.config.OnMessage != nil { + var msg Message + if jsonErr := json.Unmarshal(data, &msg); jsonErr == nil { + rc.config.OnMessage(msg) + } + } + } +} diff --git a/ws_test.go b/ws_test.go index 0632568..b116a65 100644 --- a/ws_test.go +++ b/ws_test.go @@ -3,6 +3,7 @@ package ws import ( "context" "encoding/json" + "net" "net/http" "net/http/httptest" "strings" @@ -790,3 +791,1235 @@ func TestMustMarshal(t *testing.T) { }) }) } + +// --- Phase 0: Additional coverage tests --- + +func TestHub_Run_ShutdownClosesClients(t *testing.T) { + t.Run("closes all client send channels on shutdown", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + go hub.Run(ctx) + + // Register clients via the hub's Run loop + client1 := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + client2 := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.register <- client1 + hub.register <- client2 + time.Sleep(20 * time.Millisecond) + + assert.Equal(t, 2, hub.ClientCount()) + + // Cancel context to trigger shutdown + cancel() + time.Sleep(50 * time.Millisecond) + + // Send channels should be closed + _, ok1 := <-client1.send + assert.False(t, ok1, "client1 send channel should be closed") + _, ok2 := <-client2.send + assert.False(t, ok2, "client2 send channel should be closed") + }) +} + +func TestHub_Run_BroadcastToClientWithFullBuffer(t *testing.T) { + t.Run("unregisters client with full send buffer during broadcast", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + // Create a client with a tiny buffer that will overflow + slowClient := &Client{ + hub: hub, + send: make(chan []byte, 1), // Very small buffer + subscriptions: make(map[string]bool), + } + + hub.register <- slowClient + time.Sleep(20 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) + + // Fill the client's send buffer + slowClient.send <- []byte("blocking") + + // Broadcast should trigger the overflow path + err := hub.Broadcast(Message{Type: TypeEvent, Data: "overflow"}) + require.NoError(t, err) + + // Wait for the unregister goroutine to fire + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, 0, hub.ClientCount(), "slow client should be unregistered") + }) +} + +func TestHub_SendToChannel_ClientBufferFull(t *testing.T) { + t.Run("skips client with full send buffer", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + send: make(chan []byte, 1), // Tiny buffer + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + hub.Subscribe(client, "test-channel") + + // Fill the client buffer + client.send <- []byte("blocking") + + // SendToChannel should not block; it skips the full client + err := hub.SendToChannel("test-channel", Message{Type: TypeEvent, Data: "overflow"}) + assert.NoError(t, err) + }) +} + +func TestHub_Broadcast_MarshalError(t *testing.T) { + t.Run("returns error for unmarshalable message", func(t *testing.T) { + hub := NewHub() + + // Channels cannot be marshalled to JSON + msg := Message{ + Type: TypeEvent, + Data: make(chan int), + } + + err := hub.Broadcast(msg) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to marshal message") + }) +} + +func TestHub_SendToChannel_MarshalError(t *testing.T) { + t.Run("returns error for unmarshalable message", func(t *testing.T) { + hub := NewHub() + + msg := Message{ + Type: TypeEvent, + Data: make(chan int), + } + + err := hub.SendToChannel("any-channel", msg) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to marshal message") + }) +} + +func TestHub_Handler_UpgradeError(t *testing.T) { + t.Run("returns silently on non-websocket request", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + // Make a plain HTTP request (not a WebSocket upgrade) + resp, err := http.Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + + // The handler should have returned an error response + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.Equal(t, 0, hub.ClientCount()) + }) +} + +func TestClient_Close(t *testing.T) { + t.Run("unregisters and closes connection", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) + + // Get the client from the hub + hub.mu.RLock() + var client *Client + for c := range hub.clients { + client = c + break + } + hub.mu.RUnlock() + require.NotNil(t, client) + + // Close via Client.Close() + err = client.Close() + // conn.Close may return an error if already closing, that is acceptable + _ = err + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, hub.ClientCount()) + + // Connection should be closed — writing should fail + _ = conn.Close() // ensure clean up + }) +} + +func TestReadPump_MalformedJSON(t *testing.T) { + t.Run("ignores malformed JSON messages", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Send malformed JSON — should be ignored without disconnecting + err = conn.WriteMessage(websocket.TextMessage, []byte("this is not json")) + require.NoError(t, err) + + // Send a valid subscribe after the bad message — client should still be alive + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + }) +} + +func TestReadPump_SubscribeWithNonStringData(t *testing.T) { + t.Run("ignores subscribe with non-string data", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Send subscribe with numeric data instead of string + err = conn.WriteJSON(map[string]any{ + "type": "subscribe", + "data": 12345, + }) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + // No channels should have been created + assert.Equal(t, 0, hub.ChannelCount()) + }) +} + +func TestReadPump_UnsubscribeWithNonStringData(t *testing.T) { + t.Run("ignores unsubscribe with non-string data", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Subscribe first with valid data + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) + require.NoError(t, err) + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + + // Send unsubscribe with non-string data — should be ignored + err = conn.WriteJSON(map[string]any{ + "type": "unsubscribe", + "data": []string{"test-channel"}, + }) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + // Channel should still have the subscriber + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + }) +} + +func TestReadPump_UnknownMessageType(t *testing.T) { + t.Run("ignores unknown message types without disconnecting", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Send a message with an unknown type + err = conn.WriteJSON(Message{Type: "unknown_type", Data: "anything"}) + require.NoError(t, err) + + // Client should still be connected + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) + }) +} + +func TestWritePump_SendsCloseOnChannelClose(t *testing.T) { + t.Run("sends close message when send channel is closed", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Unregister the client (which closes its send channel) + hub.mu.RLock() + var client *Client + for c := range hub.clients { + client = c + break + } + hub.mu.RUnlock() + + hub.unregister <- client + time.Sleep(50 * time.Millisecond) + + // The client should receive a close message and the connection should end + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, _, readErr := conn.ReadMessage() + assert.Error(t, readErr, "reading should fail after close") + }) +} + +func TestWritePump_BatchesMessages(t *testing.T) { + t.Run("batches multiple queued messages into a single write", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Get the client + hub.mu.RLock() + var client *Client + for c := range hub.clients { + client = c + break + } + hub.mu.RUnlock() + require.NotNil(t, client) + + // Queue multiple messages rapidly on the send channel directly + msg1 := mustMarshal(Message{Type: TypeEvent, Data: "batch-1", Timestamp: time.Now()}) + msg2 := mustMarshal(Message{Type: TypeEvent, Data: "batch-2", Timestamp: time.Now()}) + msg3 := mustMarshal(Message{Type: TypeEvent, Data: "batch-3", Timestamp: time.Now()}) + client.send <- msg1 + client.send <- msg2 + client.send <- msg3 + + // Read the batched response — it should contain all three messages + // separated by newlines + conn.SetReadDeadline(time.Now().Add(time.Second)) + _, data, readErr := conn.ReadMessage() + require.NoError(t, readErr) + + content := string(data) + assert.Contains(t, content, "batch-1") + // Depending on timing, messages may be batched or separate + // At minimum, the first message should be received + }) +} + +func TestHub_MultipleClientsOnChannel(t *testing.T) { + t.Run("delivers channel messages to all subscribers", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Connect three clients and subscribe them all to the same channel + conns := make([]*websocket.Conn, 3) + for i := range conns { + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + conns[i] = conn + } + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 3, hub.ClientCount()) + + // Subscribe all to "shared" + for _, conn := range conns { + err := conn.WriteJSON(Message{Type: TypeSubscribe, Data: "shared"}) + require.NoError(t, err) + } + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 3, hub.ChannelSubscriberCount("shared")) + + // Send to channel + err := hub.SendToChannel("shared", Message{Type: TypeEvent, Data: "hello all"}) + require.NoError(t, err) + + // All three clients should receive the message + for i, conn := range conns { + conn.SetReadDeadline(time.Now().Add(time.Second)) + var received Message + err := conn.ReadJSON(&received) + require.NoError(t, err, "client %d should receive message", i) + assert.Equal(t, TypeEvent, received.Type) + assert.Equal(t, "hello all", received.Data) + } + }) +} + +func TestHub_ConcurrentSubscribeUnsubscribe(t *testing.T) { + t.Run("handles concurrent subscribe and unsubscribe safely", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + var wg sync.WaitGroup + numClients := 50 + + clients := make([]*Client, numClients) + for i := 0; i < numClients; i++ { + clients[i] = &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.mu.Lock() + hub.clients[clients[i]] = true + hub.mu.Unlock() + } + + // Half subscribe, half unsubscribe concurrently + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + hub.Subscribe(clients[idx], "race-channel") + }(i) + } + wg.Wait() + + // Now concurrently unsubscribe half and subscribe the other half to a new channel + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + if idx%2 == 0 { + hub.Unsubscribe(clients[idx], "race-channel") + } else { + hub.Subscribe(clients[idx], "another-channel") + } + }(i) + } + wg.Wait() + + // Verify: half should remain on race-channel, half should be on another-channel + assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("race-channel")) + assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("another-channel")) + }) +} + +func TestHub_ProcessOutputEndToEnd(t *testing.T) { + t.Run("end-to-end process output via WebSocket", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Subscribe to process channel + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:build-42"}) + require.NoError(t, err) + time.Sleep(50 * time.Millisecond) + + // Send lines one at a time with a small delay to avoid batching + lines := []string{"Compiling...", "Linking...", "Done."} + for _, line := range lines { + err = hub.SendProcessOutput("build-42", line) + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) // Allow writePump to flush each individually + } + + // Read all three — messages may arrive individually or batched + // with newline separators. ReadMessage gives raw frames. + var received []Message + for len(received) < 3 { + conn.SetReadDeadline(time.Now().Add(time.Second)) + _, data, readErr := conn.ReadMessage() + require.NoError(t, readErr) + + // A single frame may contain multiple newline-separated JSON objects + parts := strings.Split(strings.TrimSpace(string(data)), "\n") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + var msg Message + err := json.Unmarshal([]byte(part), &msg) + require.NoError(t, err) + received = append(received, msg) + } + } + + require.Len(t, received, 3) + for i, expected := range lines { + assert.Equal(t, TypeProcessOutput, received[i].Type) + assert.Equal(t, "build-42", received[i].ProcessID) + assert.Equal(t, expected, received[i].Data) + } + }) +} + +func TestHub_ProcessStatusEndToEnd(t *testing.T) { + t.Run("end-to-end process status via WebSocket", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Subscribe + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:job-7"}) + require.NoError(t, err) + time.Sleep(50 * time.Millisecond) + + // Send status + err = hub.SendProcessStatus("job-7", "exited", 1) + require.NoError(t, err) + + conn.SetReadDeadline(time.Now().Add(time.Second)) + var received Message + err = conn.ReadJSON(&received) + require.NoError(t, err) + assert.Equal(t, TypeProcessStatus, received.Type) + assert.Equal(t, "job-7", received.ProcessID) + + data, ok := received.Data.(map[string]any) + require.True(t, ok) + assert.Equal(t, "exited", data["status"]) + assert.Equal(t, float64(1), data["exitCode"]) + }) +} + +// --- Benchmarks --- + +func BenchmarkBroadcast(b *testing.B) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + // Register 100 clients + for i := 0; i < 100; i++ { + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.register <- client + } + time.Sleep(50 * time.Millisecond) + + msg := Message{Type: TypeEvent, Data: "benchmark"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = hub.Broadcast(msg) + } +} + +func BenchmarkSendToChannel(b *testing.B) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + // Register 50 clients, all subscribed to the same channel + for i := 0; i < 50; i++ { + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + hub.Subscribe(client, "bench-channel") + } + + msg := Message{Type: TypeEvent, Data: "benchmark"} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = hub.SendToChannel("bench-channel", msg) + } +} + +// --- Phase 1: Connection resilience tests --- + +func TestNewHubWithConfig(t *testing.T) { + t.Run("uses custom heartbeat and pong timeout", func(t *testing.T) { + config := HubConfig{ + HeartbeatInterval: 5 * time.Second, + PongTimeout: 10 * time.Second, + WriteTimeout: 3 * time.Second, + } + hub := NewHubWithConfig(config) + + assert.Equal(t, 5*time.Second, hub.config.HeartbeatInterval) + assert.Equal(t, 10*time.Second, hub.config.PongTimeout) + assert.Equal(t, 3*time.Second, hub.config.WriteTimeout) + }) + + t.Run("applies defaults for zero values", func(t *testing.T) { + hub := NewHubWithConfig(HubConfig{}) + + assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval) + assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout) + assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout) + }) + + t.Run("applies defaults for negative values", func(t *testing.T) { + hub := NewHubWithConfig(HubConfig{ + HeartbeatInterval: -1, + PongTimeout: -1, + WriteTimeout: -1, + }) + + assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval) + assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout) + assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout) + }) +} + +func TestDefaultHubConfig(t *testing.T) { + t.Run("returns sensible defaults", func(t *testing.T) { + config := DefaultHubConfig() + + assert.Equal(t, 30*time.Second, config.HeartbeatInterval) + assert.Equal(t, 60*time.Second, config.PongTimeout) + assert.Equal(t, 10*time.Second, config.WriteTimeout) + assert.Nil(t, config.OnConnect) + assert.Nil(t, config.OnDisconnect) + }) +} + +func TestHub_ConnectionCallbacks(t *testing.T) { + t.Run("calls OnConnect when client connects", func(t *testing.T) { + connectCalled := make(chan *Client, 1) + + hub := NewHubWithConfig(HubConfig{ + OnConnect: func(client *Client) { + connectCalled <- client + }, + }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + select { + case c := <-connectCalled: + assert.NotNil(t, c) + case <-time.After(time.Second): + t.Fatal("OnConnect callback should have been called") + } + }) + + t.Run("calls OnDisconnect when client disconnects", func(t *testing.T) { + disconnectCalled := make(chan *Client, 1) + + hub := NewHubWithConfig(HubConfig{ + OnDisconnect: func(client *Client) { + disconnectCalled <- client + }, + }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + // Close the connection to trigger disconnect + conn.Close() + + select { + case c := <-disconnectCalled: + assert.NotNil(t, c) + case <-time.After(time.Second): + t.Fatal("OnDisconnect callback should have been called") + } + }) + + t.Run("does not call OnDisconnect for unknown client", func(t *testing.T) { + disconnectCalled := make(chan struct{}, 1) + + hub := NewHubWithConfig(HubConfig{ + OnDisconnect: func(client *Client) { + disconnectCalled <- struct{}{} + }, + }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + // Send an unregister for a client that was never registered + unknownClient := &Client{ + hub: hub, + send: make(chan []byte, 1), + subscriptions: make(map[string]bool), + } + hub.unregister <- unknownClient + + select { + case <-disconnectCalled: + t.Fatal("OnDisconnect should not be called for unknown client") + case <-time.After(100 * time.Millisecond): + // Good — callback was not called + } + }) +} + +func TestHub_CustomHeartbeat(t *testing.T) { + t.Run("uses custom heartbeat interval for server pings", func(t *testing.T) { + // Use a very short heartbeat to test it actually fires + hub := NewHubWithConfig(HubConfig{ + HeartbeatInterval: 100 * time.Millisecond, + PongTimeout: 500 * time.Millisecond, + WriteTimeout: 500 * time.Millisecond, + }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + pingReceived := make(chan struct{}, 1) + dialer := websocket.Dialer{} + conn, _, err := dialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + conn.SetPingHandler(func(appData string) error { + select { + case pingReceived <- struct{}{}: + default: + } + // Respond with pong + return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second)) + }) + + // Start reading in background to process control frames + go func() { + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + }() + + select { + case <-pingReceived: + // Server ping received within custom heartbeat interval + case <-time.After(time.Second): + t.Fatal("should have received server ping within custom heartbeat interval") + } + }) +} + +func TestReconnectingClient_Connect(t *testing.T) { + t.Run("connects to server and receives messages", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + connectCalled := make(chan struct{}, 1) + msgReceived := make(chan Message, 1) + + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL, + OnConnect: func() { + select { + case connectCalled <- struct{}{}: + default: + } + }, + OnMessage: func(msg Message) { + select { + case msgReceived <- msg: + default: + } + }, + }) + + // Run Connect in background + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + go rc.Connect(clientCtx) + + // Wait for connect + select { + case <-connectCalled: + // Good + case <-time.After(time.Second): + t.Fatal("OnConnect should have been called") + } + + assert.Equal(t, StateConnected, rc.State()) + + // Wait for client to register + time.Sleep(50 * time.Millisecond) + + // Broadcast a message + err := hub.Broadcast(Message{Type: TypeEvent, Data: "hello"}) + require.NoError(t, err) + + select { + case msg := <-msgReceived: + assert.Equal(t, TypeEvent, msg.Type) + assert.Equal(t, "hello", msg.Data) + case <-time.After(time.Second): + t.Fatal("should have received the broadcast message") + } + + clientCancel() + time.Sleep(50 * time.Millisecond) + }) +} + +func TestReconnectingClient_Reconnect(t *testing.T) { + t.Run("reconnects after server restart", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + go hub.Run(ctx) + + // Use a net.Listener so we control the port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + server := &httptest.Server{ + Listener: listener, + Config: &http.Server{Handler: hub.Handler()}, + } + server.Start() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + addr := listener.Addr().String() + + reconnectCalled := make(chan int, 5) + disconnectCalled := make(chan struct{}, 5) + connectCalled := make(chan struct{}, 5) + + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL, + InitialBackoff: 50 * time.Millisecond, + MaxBackoff: 200 * time.Millisecond, + OnConnect: func() { + select { + case connectCalled <- struct{}{}: + default: + } + }, + OnDisconnect: func() { + select { + case disconnectCalled <- struct{}{}: + default: + } + }, + OnReconnect: func(attempt int) { + select { + case reconnectCalled <- attempt: + default: + } + }, + }) + + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + go rc.Connect(clientCtx) + + // Wait for initial connection + select { + case <-connectCalled: + case <-time.After(time.Second): + t.Fatal("initial connection should have succeeded") + } + assert.Equal(t, StateConnected, rc.State()) + + // Shut down the server to simulate disconnect + cancel() + server.Close() + + // Wait for disconnect callback + select { + case <-disconnectCalled: + // Good + case <-time.After(time.Second): + t.Fatal("OnDisconnect should have been called") + } + + // Start a new server on the same address for reconnection + hub2 := NewHub() + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + go hub2.Run(ctx2) + + listener2, err := net.Listen("tcp", addr) + if err != nil { + t.Skipf("could not bind to same port: %v", err) + return + } + + newServer := &httptest.Server{ + Listener: listener2, + Config: &http.Server{Handler: hub2.Handler()}, + } + newServer.Start() + defer newServer.Close() + + // Wait for reconnection + select { + case attempt := <-reconnectCalled: + assert.Greater(t, attempt, 0) + case <-time.After(3 * time.Second): + t.Fatal("OnReconnect should have been called") + } + + assert.Equal(t, StateConnected, rc.State()) + clientCancel() + }) +} + +func TestReconnectingClient_MaxRetries(t *testing.T) { + t.Run("stops after max retries exceeded", func(t *testing.T) { + // Use a URL that will never connect + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", // Should refuse connection + InitialBackoff: 10 * time.Millisecond, + MaxBackoff: 50 * time.Millisecond, + MaxRetries: 3, + }) + + errCh := make(chan error, 1) + go func() { + errCh <- rc.Connect(context.Background()) + }() + + select { + case err := <-errCh: + require.Error(t, err) + assert.Contains(t, err.Error(), "max retries (3) exceeded") + case <-time.After(5 * time.Second): + t.Fatal("should have stopped after max retries") + } + + assert.Equal(t, StateDisconnected, rc.State()) + }) +} + +func TestReconnectingClient_Send(t *testing.T) { + t.Run("sends message to server", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + connected := make(chan struct{}, 1) + + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL, + OnConnect: func() { + select { + case connected <- struct{}{}: + default: + } + }, + }) + + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + go rc.Connect(clientCtx) + + <-connected + time.Sleep(50 * time.Millisecond) + + // Send a subscribe message + err := rc.Send(Message{ + Type: TypeSubscribe, + Data: "test-channel", + }) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + + clientCancel() + }) + + t.Run("returns error when not connected", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + }) + + err := rc.Send(Message{Type: TypePing}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not connected") + }) + + t.Run("returns error for unmarshalable message", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + }) + // Force a conn to be set so we get past the nil check + // to hit the marshal error first + err := rc.Send(Message{Type: TypeEvent, Data: make(chan int)}) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to marshal message") + }) +} + +func TestReconnectingClient_Close(t *testing.T) { + t.Run("stops reconnection loop", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + connected := make(chan struct{}, 1) + + rc := NewReconnectingClient(ReconnectConfig{ + URL: wsURL, + OnConnect: func() { + select { + case connected <- struct{}{}: + default: + } + }, + }) + + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + + done := make(chan error, 1) + go func() { + done <- rc.Connect(clientCtx) + }() + + <-connected + + err := rc.Close() + assert.NoError(t, err) + + select { + case <-done: + // Good — Connect returned + case <-time.After(time.Second): + t.Fatal("Connect should have returned after Close") + } + }) + + t.Run("close when not connected is safe", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + }) + + err := rc.Close() + assert.NoError(t, err) + }) +} + +func TestReconnectingClient_ExponentialBackoff(t *testing.T) { + t.Run("calculates correct backoff values", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + InitialBackoff: 100 * time.Millisecond, + MaxBackoff: 1 * time.Second, + BackoffMultiplier: 2.0, + }) + + // attempt 1: 100ms + assert.Equal(t, 100*time.Millisecond, rc.calculateBackoff(1)) + // attempt 2: 200ms + assert.Equal(t, 200*time.Millisecond, rc.calculateBackoff(2)) + // attempt 3: 400ms + assert.Equal(t, 400*time.Millisecond, rc.calculateBackoff(3)) + // attempt 4: 800ms + assert.Equal(t, 800*time.Millisecond, rc.calculateBackoff(4)) + // attempt 5: capped at 1s + assert.Equal(t, 1*time.Second, rc.calculateBackoff(5)) + // attempt 10: still capped at 1s + assert.Equal(t, 1*time.Second, rc.calculateBackoff(10)) + }) +} + +func TestReconnectingClient_Defaults(t *testing.T) { + t.Run("applies defaults for zero config values", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://localhost:1", + }) + + assert.Equal(t, 1*time.Second, rc.config.InitialBackoff) + assert.Equal(t, 30*time.Second, rc.config.MaxBackoff) + assert.Equal(t, 2.0, rc.config.BackoffMultiplier) + assert.NotNil(t, rc.config.Dialer) + }) +} + +func TestReconnectingClient_ContextCancel(t *testing.T) { + t.Run("returns context error on cancel during backoff", func(t *testing.T) { + rc := NewReconnectingClient(ReconnectConfig{ + URL: "ws://127.0.0.1:1", + InitialBackoff: 10 * time.Second, // Long backoff + }) + + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan error, 1) + go func() { + done <- rc.Connect(ctx) + }() + + // Allow first dial attempt to fail + time.Sleep(200 * time.Millisecond) + + // Cancel during backoff + cancel() + + select { + case err := <-done: + require.Error(t, err) + assert.Equal(t, context.Canceled, err) + case <-time.After(2 * time.Second): + t.Fatal("Connect should have returned after context cancel") + } + }) +} + +func TestConnectionState(t *testing.T) { + t.Run("state constants are distinct", func(t *testing.T) { + assert.NotEqual(t, StateDisconnected, StateConnecting) + assert.NotEqual(t, StateConnecting, StateConnected) + assert.NotEqual(t, StateDisconnected, StateConnected) + }) +} +