go-ws/ws_test.go
Snider 3d6f14cd91
All checks were successful
Security Scan / security (push) Successful in 9s
Test / test (push) Successful in 3m49s
feat: modernise to Go 1.26 iterators and stdlib helpers
Add AllClients, AllChannels, AllSubscriptions iter.Seq iterators.
Use slices.Collect(maps.Keys()) in SendToChannel/Subscriptions,
range-over-int in calculateBackoff and benchmarks.

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 05:32:08 +00:00

2512 lines
64 KiB
Go

// SPDX-Licence-Identifier: EUPL-1.2
package ws
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// wsURL derives the WebSocket endpoint of an httptest server by replacing the
// leading "http" of its URL with "ws"; a TLS server ("https://...") therefore
// maps to "wss://...".
func wsURL(server *httptest.Server) string {
	rest := strings.TrimPrefix(server.URL, "http")
	return "ws" + rest
}
// TestNewHub verifies that a freshly constructed hub has every internal map
// and channel allocated.
func TestNewHub(t *testing.T) {
	t.Run("creates hub with initialized maps", func(t *testing.T) {
		h := NewHub()
		require.NotNil(t, h)

		for _, field := range []any{h.clients, h.broadcast, h.register, h.unregister, h.channels} {
			assert.NotNil(t, field)
		}
	})
}
// TestHub_Run checks that Run returns once its context is cancelled.
func TestHub_Run(t *testing.T) {
	t.Run("stops on context cancel", func(t *testing.T) {
		h := NewHub()
		ctx, stop := context.WithCancel(context.Background())
		stopped := make(chan struct{})

		go func() {
			h.Run(ctx)
			close(stopped)
		}()
		stop()

		select {
		case <-time.After(time.Second):
			t.Fatal("hub should have stopped on context cancel")
		case <-stopped:
			// Run returned after cancellation, as expected.
		}
	})
}
// TestHub_Broadcast covers the enqueue side of Hub.Broadcast: the message is
// marshalled (with a timestamp) onto the broadcast channel, and a full channel
// is reported as an error rather than blocking.
func TestHub_Broadcast(t *testing.T) {
	t.Run("marshals message with timestamp", func(t *testing.T) {
		// Run is deliberately not started so the marshalled payload stays
		// queued on the broadcast channel where the test can inspect it.
		hub := NewHub()

		err := hub.Broadcast(Message{
			Type: TypeEvent,
			Data: "test data",
		})
		require.NoError(t, err)

		select {
		case raw := <-hub.broadcast:
			var got Message
			require.NoError(t, json.Unmarshal(raw, &got))
			assert.Equal(t, TypeEvent, got.Type)
			assert.Equal(t, "test data", got.Data)
			assert.False(t, got.Timestamp.IsZero(), "broadcast message should carry a timestamp")
		default:
			t.Fatal("expected marshalled message on broadcast channel")
		}
	})
	t.Run("returns error when channel full", func(t *testing.T) {
		hub := NewHub()
		// Fill the broadcast channel to its real capacity instead of a
		// hard-coded size so the test follows NewHub's buffer configuration.
		for range cap(hub.broadcast) {
			hub.broadcast <- []byte("test")
		}
		err := hub.Broadcast(Message{Type: TypeEvent})
		// require (not assert): err.Error() below would panic on a nil error.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "broadcast channel full")
	})
}
// TestHub_Stats checks that Stats reports the number of registered clients
// and active channels.
func TestHub_Stats(t *testing.T) {
	t.Run("returns empty stats for new hub", func(t *testing.T) {
		got := NewHub().Stats()
		assert.Equal(t, 0, got.Clients)
		assert.Equal(t, 0, got.Channels)
	})
	t.Run("tracks client and channel counts", func(t *testing.T) {
		h := NewHub()

		// Seed hub state directly, holding the write lock as the hub does.
		h.mu.Lock()
		for range 2 {
			h.clients[&Client{subscriptions: make(map[string]bool)}] = true
		}
		h.channels["test-channel"] = make(map[*Client]bool)
		h.mu.Unlock()

		got := h.Stats()
		assert.Equal(t, 2, got.Clients)
		assert.Equal(t, 1, got.Channels)
	})
}
func TestHub_ClientCount(t *testing.T) {
t.Run("returns zero for empty hub", func(t *testing.T) {
hub := NewHub()
assert.Equal(t, 0, hub.ClientCount())
})
t.Run("counts connected clients", func(t *testing.T) {
hub := NewHub()
hub.mu.Lock()
hub.clients[&Client{}] = true
hub.clients[&Client{}] = true
hub.mu.Unlock()
assert.Equal(t, 2, hub.ClientCount())
})
}
func TestHub_ChannelCount(t *testing.T) {
t.Run("returns zero for empty hub", func(t *testing.T) {
hub := NewHub()
assert.Equal(t, 0, hub.ChannelCount())
})
t.Run("counts active channels", func(t *testing.T) {
hub := NewHub()
hub.mu.Lock()
hub.channels["channel1"] = make(map[*Client]bool)
hub.channels["channel2"] = make(map[*Client]bool)
hub.mu.Unlock()
assert.Equal(t, 2, hub.ChannelCount())
})
}
func TestHub_ChannelSubscriberCount(t *testing.T) {
t.Run("returns zero for non-existent channel", func(t *testing.T) {
hub := NewHub()
assert.Equal(t, 0, hub.ChannelSubscriberCount("non-existent"))
})
t.Run("counts subscribers in channel", func(t *testing.T) {
hub := NewHub()
hub.mu.Lock()
hub.channels["test-channel"] = make(map[*Client]bool)
hub.channels["test-channel"][&Client{}] = true
hub.channels["test-channel"][&Client{}] = true
hub.mu.Unlock()
assert.Equal(t, 2, hub.ChannelSubscriberCount("test-channel"))
})
}
// TestHub_Subscribe checks that Subscribe records the subscription on both
// the hub and the client, creating the channel on first use.
func TestHub_Subscribe(t *testing.T) {
	t.Run("adds client to channel", func(t *testing.T) {
		h := NewHub()
		c := &Client{hub: h, subscriptions: make(map[string]bool)}

		h.mu.Lock()
		h.clients[c] = true
		h.mu.Unlock()

		h.Subscribe(c, "test-channel")

		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))
		assert.True(t, c.subscriptions["test-channel"])
	})
	t.Run("creates channel if not exists", func(t *testing.T) {
		h := NewHub()
		c := &Client{hub: h, subscriptions: make(map[string]bool)}

		h.Subscribe(c, "new-channel")

		h.mu.RLock()
		_, found := h.channels["new-channel"]
		h.mu.RUnlock()
		assert.True(t, found)
	})
}
// TestHub_Unsubscribe checks removal from a channel, cleanup of channels that
// become empty, and tolerance of unknown channel names.
func TestHub_Unsubscribe(t *testing.T) {
	t.Run("removes client from channel", func(t *testing.T) {
		h := NewHub()
		c := &Client{hub: h, subscriptions: make(map[string]bool)}

		h.Subscribe(c, "test-channel")
		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))

		h.Unsubscribe(c, "test-channel")
		assert.Equal(t, 0, h.ChannelSubscriberCount("test-channel"))
		assert.False(t, c.subscriptions["test-channel"])
	})
	t.Run("cleans up empty channels", func(t *testing.T) {
		h := NewHub()
		c := &Client{hub: h, subscriptions: make(map[string]bool)}

		h.Subscribe(c, "temp-channel")
		h.Unsubscribe(c, "temp-channel")

		h.mu.RLock()
		_, stillPresent := h.channels["temp-channel"]
		h.mu.RUnlock()
		assert.False(t, stillPresent, "empty channel should be removed")
	})
	t.Run("handles non-existent channel gracefully", func(t *testing.T) {
		h := NewHub()
		c := &Client{hub: h, subscriptions: make(map[string]bool)}
		// Unsubscribing from a channel that never existed must not panic.
		h.Unsubscribe(c, "non-existent")
	})
}
// TestHub_SendToChannel checks that channel messages reach subscribers tagged
// with the channel name, and that an unknown channel is not an error.
func TestHub_SendToChannel(t *testing.T) {
	t.Run("sends to channel subscribers", func(t *testing.T) {
		h := NewHub()
		c := &Client{
			hub:           h,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		h.mu.Lock()
		h.clients[c] = true
		h.mu.Unlock()
		h.Subscribe(c, "test-channel")

		sendErr := h.SendToChannel("test-channel", Message{Type: TypeEvent, Data: "test"})
		require.NoError(t, sendErr)

		select {
		case <-time.After(time.Second):
			t.Fatal("expected message on client send channel")
		case raw := <-c.send:
			var got Message
			require.NoError(t, json.Unmarshal(raw, &got))
			assert.Equal(t, TypeEvent, got.Type)
			assert.Equal(t, "test-channel", got.Channel)
		}
	})
	t.Run("returns nil for non-existent channel", func(t *testing.T) {
		sendErr := NewHub().SendToChannel("non-existent", Message{Type: TypeEvent})
		assert.NoError(t, sendErr, "should not error for non-existent channel")
	})
}
// TestHub_SendProcessOutput checks that process output is routed to the
// "process:<id>" channel with the process ID and payload intact.
func TestHub_SendProcessOutput(t *testing.T) {
	t.Run("sends output to process channel", func(t *testing.T) {
		h := NewHub()
		c := &Client{
			hub:           h,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		h.mu.Lock()
		h.clients[c] = true
		h.mu.Unlock()
		h.Subscribe(c, "process:proc-1")

		require.NoError(t, h.SendProcessOutput("proc-1", "hello world"))

		select {
		case <-time.After(time.Second):
			t.Fatal("expected message on client send channel")
		case raw := <-c.send:
			var got Message
			require.NoError(t, json.Unmarshal(raw, &got))
			assert.Equal(t, TypeProcessOutput, got.Type)
			assert.Equal(t, "proc-1", got.ProcessID)
			assert.Equal(t, "hello world", got.Data)
		}
	})
}
// TestHub_SendProcessStatus checks the status payload (status string and exit
// code) delivered to a process channel subscriber.
func TestHub_SendProcessStatus(t *testing.T) {
	t.Run("sends status to process channel", func(t *testing.T) {
		h := NewHub()
		c := &Client{
			hub:           h,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		h.mu.Lock()
		h.clients[c] = true
		h.mu.Unlock()
		h.Subscribe(c, "process:proc-1")

		require.NoError(t, h.SendProcessStatus("proc-1", "exited", 0))

		select {
		case <-time.After(time.Second):
			t.Fatal("expected message on client send channel")
		case raw := <-c.send:
			var got Message
			require.NoError(t, json.Unmarshal(raw, &got))
			assert.Equal(t, TypeProcessStatus, got.Type)
			assert.Equal(t, "proc-1", got.ProcessID)

			// JSON objects decode into map[string]any with float64 numbers.
			payload, isMap := got.Data.(map[string]any)
			require.True(t, isMap)
			assert.Equal(t, "exited", payload["status"])
			assert.Equal(t, float64(0), payload["exitCode"])
		}
	})
}
// TestHub_SendError checks that SendError broadcasts a TypeError message to a
// client registered through the Run loop.
func TestHub_SendError(t *testing.T) {
	t.Run("broadcasts error message", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())

		c := &Client{
			hub:           h,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		h.register <- c
		time.Sleep(10 * time.Millisecond) // let Run process the registration

		require.NoError(t, h.SendError("something went wrong"))

		select {
		case <-time.After(time.Second):
			t.Fatal("expected error message on client send channel")
		case raw := <-c.send:
			var got Message
			require.NoError(t, json.Unmarshal(raw, &got))
			assert.Equal(t, TypeError, got.Type)
			assert.Equal(t, "something went wrong", got.Data)
		}
	})
}
// TestHub_SendEvent checks that SendEvent broadcasts a TypeEvent message whose
// data carries the event name.
func TestHub_SendEvent(t *testing.T) {
	t.Run("broadcasts event message", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())

		c := &Client{
			hub:           h,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		h.register <- c
		time.Sleep(10 * time.Millisecond) // let Run process the registration

		require.NoError(t, h.SendEvent("user_joined", map[string]string{"user": "alice"}))

		select {
		case <-time.After(time.Second):
			t.Fatal("expected event message on client send channel")
		case raw := <-c.send:
			var got Message
			require.NoError(t, json.Unmarshal(raw, &got))
			assert.Equal(t, TypeEvent, got.Type)

			payload, isMap := got.Data.(map[string]any)
			require.True(t, isMap)
			assert.Equal(t, "user_joined", payload["event"])
		}
	})
}
// TestClient_Subscriptions checks that Subscriptions lists every channel the
// client has joined via the hub.
func TestClient_Subscriptions(t *testing.T) {
	t.Run("returns copy of subscriptions", func(t *testing.T) {
		h := NewHub()
		c := &Client{hub: h, subscriptions: make(map[string]bool)}

		wanted := []string{"channel1", "channel2"}
		for _, name := range wanted {
			h.Subscribe(c, name)
		}

		subs := c.Subscriptions()
		assert.Len(t, subs, 2)
		for _, name := range wanted {
			assert.Contains(t, subs, name)
		}
	})
}
// TestClient_AllSubscriptions checks that the AllSubscriptions iterator yields
// exactly the client's subscribed channel names.
func TestClient_AllSubscriptions(t *testing.T) {
	t.Run("returns iterator over subscriptions", func(t *testing.T) {
		c := &Client{subscriptions: map[string]bool{
			"sub1": true,
			"sub2": true,
		}}

		names := slices.Collect(c.AllSubscriptions())
		assert.ElementsMatch(t, []string{"sub1", "sub2"}, names)
	})
}
// TestHub_AllClients checks that the AllClients iterator yields every
// registered client.
func TestHub_AllClients(t *testing.T) {
	t.Run("returns iterator over all clients", func(t *testing.T) {
		h := NewHub()
		registered := []*Client{
			{subscriptions: make(map[string]bool)},
			{subscriptions: make(map[string]bool)},
		}

		h.mu.Lock()
		for _, c := range registered {
			h.clients[c] = true
		}
		h.mu.Unlock()

		seen := slices.Collect(h.AllClients())
		assert.Len(t, seen, 2)
		for _, c := range registered {
			assert.Contains(t, seen, c)
		}
	})
}
// TestHub_AllChannels checks that the AllChannels iterator yields every
// active channel name.
func TestHub_AllChannels(t *testing.T) {
	t.Run("returns iterator over all active channels", func(t *testing.T) {
		h := NewHub()

		h.mu.Lock()
		h.channels["ch1"] = make(map[*Client]bool)
		h.channels["ch2"] = make(map[*Client]bool)
		h.mu.Unlock()

		names := slices.Collect(h.AllChannels())
		assert.ElementsMatch(t, []string{"ch1", "ch2"}, names)
	})
}
// TestMessage_JSON checks Message's JSON field names in both directions.
func TestMessage_JSON(t *testing.T) {
	t.Run("marshals correctly", func(t *testing.T) {
		raw, err := json.Marshal(Message{
			Type:      TypeProcessOutput,
			Channel:   "process:1",
			ProcessID: "1",
			Data:      "output line",
			Timestamp: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
		})
		require.NoError(t, err)

		encoded := string(raw)
		for _, fragment := range []string{
			`"type":"process_output"`,
			`"channel":"process:1"`,
			`"processId":"1"`,
			`"data":"output line"`,
		} {
			assert.Contains(t, encoded, fragment)
		}
	})
	t.Run("unmarshals correctly", func(t *testing.T) {
		var decoded Message
		require.NoError(t, json.Unmarshal([]byte(`{"type":"subscribe","data":"channel:test"}`), &decoded))
		assert.Equal(t, TypeSubscribe, decoded.Type)
		assert.Equal(t, "channel:test", decoded.Data)
	})
}
// TestHub_WebSocketHandler drives Hub.Handler end to end over real WebSocket
// connections: registration, subscribe/unsubscribe, ping/pong, broadcast
// delivery and cleanup on disconnect.
func TestHub_WebSocketHandler(t *testing.T) {
	t.Run("upgrades connection and registers client", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()

		time.Sleep(50 * time.Millisecond) // registration is asynchronous
		assert.Equal(t, 1, h.ClientCount())
	})
	t.Run("handles subscribe message", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()

		require.NoError(t, conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}))

		time.Sleep(50 * time.Millisecond) // let readPump apply the subscription
		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))
	})
	t.Run("handles unsubscribe message", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()

		require.NoError(t, conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}))
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))

		require.NoError(t, conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "test-channel"}))
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 0, h.ChannelSubscriberCount("test-channel"))
	})
	t.Run("responds to ping with pong", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		require.NoError(t, conn.WriteJSON(Message{Type: TypePing}))

		var reply Message
		_ = conn.SetReadDeadline(time.Now().Add(time.Second))
		require.NoError(t, conn.ReadJSON(&reply))
		assert.Equal(t, TypePong, reply.Type)
	})
	t.Run("broadcasts messages to clients", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		require.NoError(t, h.Broadcast(Message{
			Type: TypeEvent,
			Data: "broadcast test",
		}))

		var reply Message
		_ = conn.SetReadDeadline(time.Now().Add(time.Second))
		require.NoError(t, conn.ReadJSON(&reply))
		assert.Equal(t, TypeEvent, reply.Type)
		assert.Equal(t, "broadcast test", reply.Data)
	})
	t.Run("unregisters client on connection close", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, h.ClientCount())

		_ = conn.Close()
		time.Sleep(50 * time.Millisecond) // wait for the hub to notice
		assert.Equal(t, 0, h.ClientCount())
	})
	t.Run("removes client from channels on disconnect", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)

		require.NoError(t, conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}))
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))

		_ = conn.Close()
		time.Sleep(50 * time.Millisecond)
		// Disconnecting must also drop the client's channel membership.
		assert.Equal(t, 0, h.ChannelSubscriberCount("test-channel"))
	})
}
// TestHub_Concurrency exercises Subscribe and Broadcast from many goroutines
// at once; run with -race to catch unsynchronised access.
func TestHub_Concurrency(t *testing.T) {
	t.Run("handles concurrent subscriptions", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())

		const clientTotal = 100
		var wg sync.WaitGroup
		for range clientTotal {
			wg.Add(1)
			go func() {
				defer wg.Done()
				c := &Client{
					hub:           h,
					send:          make(chan []byte, 256),
					subscriptions: make(map[string]bool),
				}
				h.mu.Lock()
				h.clients[c] = true
				h.mu.Unlock()

				h.Subscribe(c, "shared-channel")
				h.Subscribe(c, "shared-channel") // a repeat subscribe must be idempotent
			}()
		}
		wg.Wait()

		assert.Equal(t, clientTotal, h.ChannelSubscriberCount("shared-channel"))
	})
	t.Run("handles concurrent broadcasts", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())

		c := &Client{
			hub:           h,
			send:          make(chan []byte, 1000),
			subscriptions: make(map[string]bool),
		}
		h.register <- c
		time.Sleep(10 * time.Millisecond)

		const broadcastTotal = 100
		var wg sync.WaitGroup
		for i := range broadcastTotal {
			wg.Add(1)
			go func() {
				defer wg.Done()
				_ = h.Broadcast(Message{Type: TypeEvent, Data: i})
			}()
		}
		wg.Wait()
		time.Sleep(100 * time.Millisecond) // let Run fan the broadcasts out

		// Drain whatever arrived within a short grace period.
		received := 0
		grace := time.After(100 * time.Millisecond)
		for draining := true; draining; {
			select {
			case <-c.send:
				received++
			case <-grace:
				draining = false
			}
		}

		assert.GreaterOrEqual(t, received, broadcastTotal-10, "should receive most broadcasts")
	})
}
// TestHub_HandleWebSocket checks that HandleWebSocket, used directly as an
// http.HandlerFunc, registers clients just like Handler does.
func TestHub_HandleWebSocket(t *testing.T) {
	t.Run("alias works same as Handler", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())

		srv := httptest.NewServer(http.HandlerFunc(h.HandleWebSocket))
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, h.ClientCount())
	})
}
// TestMustMarshal checks mustMarshal on encodable input and confirms that it
// does not panic on a value JSON cannot represent.
func TestMustMarshal(t *testing.T) {
	t.Run("marshals valid data", func(t *testing.T) {
		encoded := string(mustMarshal(Message{Type: TypePong}))
		assert.Contains(t, encoded, "pong")
	})
	t.Run("handles unmarshalable data without panic", func(t *testing.T) {
		// Channels have no JSON encoding; only the absence of a panic is
		// asserted, not what mustMarshal returns.
		unencodable := make(chan int)
		assert.NotPanics(t, func() { _ = mustMarshal(unencodable) })
	})
}
// --- Phase 0: Additional coverage tests ---
// TestHub_Run_ShutdownClosesClients checks that cancelling Run's context
// closes the send channel of every registered client.
func TestHub_Run_ShutdownClosesClients(t *testing.T) {
	t.Run("closes all client send channels on shutdown", func(t *testing.T) {
		hub := NewHub()
		ctx, cancel := context.WithCancel(context.Background())
		// Safety net: stop the hub goroutine even if an assertion aborts early.
		defer cancel()
		go hub.Run(ctx)

		// Register clients via the hub's Run loop.
		client1 := &Client{
			hub:           hub,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		client2 := &Client{
			hub:           hub,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		hub.register <- client1
		hub.register <- client2
		time.Sleep(20 * time.Millisecond)
		assert.Equal(t, 2, hub.ClientCount())

		// Cancel the context to trigger shutdown.
		cancel()

		// A receive on a closed channel returns immediately with ok == false.
		// Bound the wait so a missing close fails the test instead of hanging
		// it forever, as a bare receive would.
		expectClosed := func(c *Client, msg string) {
			t.Helper()
			select {
			case _, ok := <-c.send:
				assert.False(t, ok, msg)
			case <-time.After(time.Second):
				t.Fatal(msg)
			}
		}
		expectClosed(client1, "client1 send channel should be closed")
		expectClosed(client2, "client2 send channel should be closed")
	})
}
// TestHub_Run_BroadcastToClientWithFullBuffer checks that a client whose send
// buffer is full when a broadcast arrives gets dropped from the hub.
func TestHub_Run_BroadcastToClientWithFullBuffer(t *testing.T) {
	t.Run("unregisters client with full send buffer during broadcast", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())

		// A one-slot buffer is trivially easy to overflow.
		laggard := &Client{
			hub:           h,
			send:          make(chan []byte, 1),
			subscriptions: make(map[string]bool),
		}
		h.register <- laggard
		time.Sleep(20 * time.Millisecond)
		assert.Equal(t, 1, h.ClientCount())

		laggard.send <- []byte("blocking") // occupy the only slot

		require.NoError(t, h.Broadcast(Message{Type: TypeEvent, Data: "overflow"}))

		time.Sleep(100 * time.Millisecond) // give the unregister path time to run
		assert.Equal(t, 0, h.ClientCount(), "slow client should be unregistered")
	})
}
// TestHub_SendToChannel_ClientBufferFull checks that SendToChannel neither
// blocks nor errors when a subscriber's send buffer is already full.
func TestHub_SendToChannel_ClientBufferFull(t *testing.T) {
	t.Run("skips client with full send buffer", func(t *testing.T) {
		hub := NewHub()
		client := &Client{
			hub:           hub,
			send:          make(chan []byte, 1), // Tiny buffer
			subscriptions: make(map[string]bool),
		}
		hub.mu.Lock()
		hub.clients[client] = true
		hub.mu.Unlock()
		hub.Subscribe(client, "test-channel")

		// Fill the client buffer.
		client.send <- []byte("blocking")

		// Run the call in a goroutine so that a regression to a blocking send
		// fails the test with a clear message instead of hanging the suite.
		result := make(chan error, 1)
		go func() {
			result <- hub.SendToChannel("test-channel", Message{Type: TypeEvent, Data: "overflow"})
		}()
		select {
		case err := <-result:
			assert.NoError(t, err)
		case <-time.After(time.Second):
			t.Fatal("SendToChannel blocked on a client with a full send buffer")
		}

		// The pre-filled message is still the only one queued for the client.
		assert.Len(t, client.send, 1)
	})
}
func TestHub_Broadcast_MarshalError(t *testing.T) {
t.Run("returns error for unmarshalable message", func(t *testing.T) {
hub := NewHub()
// Channels cannot be marshalled to JSON
msg := Message{
Type: TypeEvent,
Data: make(chan int),
}
err := hub.Broadcast(msg)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to marshal message")
})
}
func TestHub_SendToChannel_MarshalError(t *testing.T) {
t.Run("returns error for unmarshalable message", func(t *testing.T) {
hub := NewHub()
msg := Message{
Type: TypeEvent,
Data: make(chan int),
}
err := hub.SendToChannel("any-channel", msg)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to marshal message")
})
}
// TestHub_Handler_UpgradeError checks that a plain HTTP request (no WebSocket
// upgrade headers) is rejected and does not register a client.
func TestHub_Handler_UpgradeError(t *testing.T) {
	t.Run("returns silently on non-websocket request", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())

		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		resp, err := http.Get(srv.URL)
		require.NoError(t, err)
		defer resp.Body.Close()

		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
		assert.Equal(t, 0, h.ClientCount())
	})
}
// TestClient_Close checks that Client.Close removes the client from the hub
// and that the peer then observes the connection ending.
func TestClient_Close(t *testing.T) {
	t.Run("unregisters and closes connection", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()
		endpoint := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ClientCount())

		// Grab the single registered client from the hub.
		hub.mu.RLock()
		var client *Client
		for c := range hub.clients {
			client = c
			break
		}
		hub.mu.RUnlock()
		require.NotNil(t, client)

		// Close may report an error if the connection is already shutting
		// down; only its side effects are under test here.
		_ = client.Close()

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 0, hub.ClientCount())

		// The dialled side must now fail to read. Reject a plain deadline
		// timeout, which would mean the connection was merely idle, not closed.
		_ = conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
		_, _, readErr := conn.ReadMessage()
		require.Error(t, readErr, "reading should fail after Client.Close")
		if netErr, ok := readErr.(net.Error); ok {
			assert.False(t, netErr.Timeout(), "connection should be closed, not just idle")
		}
	})
}
// TestReadPump_MalformedJSON checks that a non-JSON frame is skipped and the
// connection keeps processing later messages.
func TestReadPump_MalformedJSON(t *testing.T) {
	t.Run("ignores malformed JSON messages", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte("this is not json")))
		// A well-formed subscribe afterwards proves the client survived.
		require.NoError(t, conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}))

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))
	})
}
// TestReadPump_SubscribeWithNonStringData checks that a subscribe request
// whose data is not a string is ignored.
func TestReadPump_SubscribeWithNonStringData(t *testing.T) {
	t.Run("ignores subscribe with non-string data", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		numericSubscribe := map[string]any{
			"type": "subscribe",
			"data": 12345,
		}
		require.NoError(t, conn.WriteJSON(numericSubscribe))

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 0, h.ChannelCount()) // nothing should have been created
	})
}
// TestReadPump_UnsubscribeWithNonStringData checks that an unsubscribe request
// whose data is not a string leaves existing subscriptions untouched.
func TestReadPump_UnsubscribeWithNonStringData(t *testing.T) {
	t.Run("ignores unsubscribe with non-string data", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		require.NoError(t, conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}))
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))

		listUnsubscribe := map[string]any{
			"type": "unsubscribe",
			"data": []string{"test-channel"},
		}
		require.NoError(t, conn.WriteJSON(listUnsubscribe))
		time.Sleep(50 * time.Millisecond)

		assert.Equal(t, 1, h.ChannelSubscriberCount("test-channel"))
	})
}
// TestReadPump_UnknownMessageType checks that an unrecognised message type does
// not cause the client to be disconnected.
func TestReadPump_UnknownMessageType(t *testing.T) {
	t.Run("ignores unknown message types without disconnecting", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		require.NoError(t, conn.WriteJSON(Message{Type: "unknown_type", Data: "anything"}))

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, h.ClientCount())
	})
}
// TestWritePump_SendsCloseOnChannelClose checks that unregistering a client,
// which closes its send channel, ends the WebSocket seen by the peer.
func TestWritePump_SendsCloseOnChannelClose(t *testing.T) {
	t.Run("sends close message when send channel is closed", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()
		endpoint := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		// Fetch the registered client so it can be unregistered directly.
		hub.mu.RLock()
		var client *Client
		for c := range hub.clients {
			client = c
			break
		}
		hub.mu.RUnlock()
		// Guard: sending nil to unregister would not exercise the close path.
		require.NotNil(t, client, "expected the dialled connection to be registered")

		hub.unregister <- client
		time.Sleep(50 * time.Millisecond)

		// The peer should see the connection end. A bare deadline timeout would
		// mean nothing was sent or closed, so it does not count as success.
		_ = conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
		_, _, readErr := conn.ReadMessage()
		require.Error(t, readErr, "reading should fail after close")
		if netErr, ok := readErr.(net.Error); ok {
			assert.False(t, netErr.Timeout(), "read should fail because the connection closed, not because the deadline passed")
		}
	})
}
// TestWritePump_BatchesMessages queues several messages on a client's send
// channel back to back and checks that all of them reach the peer in order.
// Depending on timing, writePump may combine them into one newline-separated
// frame or send them separately; both shapes are accepted.
func TestWritePump_BatchesMessages(t *testing.T) {
	t.Run("batches multiple queued messages into a single write", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()
		endpoint := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		// Get the client.
		hub.mu.RLock()
		var client *Client
		for c := range hub.clients {
			client = c
			break
		}
		hub.mu.RUnlock()
		require.NotNil(t, client)

		// Marshal everything first, then enqueue back to back, so the messages
		// are likely to be pending together when writePump wakes up.
		want := []string{"batch-1", "batch-2", "batch-3"}
		queued := make([][]byte, 0, len(want))
		for _, payload := range want {
			queued = append(queued, mustMarshal(Message{Type: TypeEvent, Data: payload, Timestamp: time.Now()}))
		}
		for _, msg := range queued {
			client.send <- msg
		}

		// Read frames until every message has been seen; a frame may carry
		// several newline-separated JSON objects.
		var got []string
		for len(got) < len(want) {
			_ = conn.SetReadDeadline(time.Now().Add(time.Second))
			_, frame, readErr := conn.ReadMessage()
			require.NoError(t, readErr)
			for part := range strings.SplitSeq(strings.TrimSpace(string(frame)), "\n") {
				part = strings.TrimSpace(part)
				if part == "" {
					continue
				}
				var msg Message
				require.NoError(t, json.Unmarshal([]byte(part), &msg))
				data, ok := msg.Data.(string)
				require.True(t, ok, "expected string data, got %T", msg.Data)
				got = append(got, data)
			}
		}
		assert.Equal(t, want, got)
	})
}
// TestHub_MultipleClientsOnChannel checks that a channel message fans out to
// every connected subscriber.
func TestHub_MultipleClientsOnChannel(t *testing.T) {
	t.Run("delivers channel messages to all subscribers", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()
		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")

		conns := make([]*websocket.Conn, 0, 3)
		for range 3 {
			c, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
			require.NoError(t, err)
			defer c.Close()
			conns = append(conns, c)
		}
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 3, h.ClientCount())

		for _, c := range conns {
			require.NoError(t, c.WriteJSON(Message{Type: TypeSubscribe, Data: "shared"}))
		}
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 3, h.ChannelSubscriberCount("shared"))

		require.NoError(t, h.SendToChannel("shared", Message{Type: TypeEvent, Data: "hello all"}))

		for idx, c := range conns {
			_ = c.SetReadDeadline(time.Now().Add(time.Second))
			var got Message
			readErr := c.ReadJSON(&got)
			require.NoError(t, readErr, "client %d should receive message", idx)
			assert.Equal(t, TypeEvent, got.Type)
			assert.Equal(t, "hello all", got.Data)
		}
	})
}
// TestHub_ConcurrentSubscribeUnsubscribe checks that concurrent Subscribe and
// Unsubscribe calls across many clients leave consistent channel membership.
func TestHub_ConcurrentSubscribeUnsubscribe(t *testing.T) {
	t.Run("handles concurrent subscribe and unsubscribe safely", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		var wg sync.WaitGroup
		numClients := 50
		clients := make([]*Client, numClients)
		for i := range numClients {
			clients[i] = &Client{
				hub:           hub,
				send:          make(chan []byte, 256),
				subscriptions: make(map[string]bool),
			}
			hub.mu.Lock()
			hub.clients[clients[i]] = true
			hub.mu.Unlock()
		}

		// Phase 1: every client subscribes to race-channel concurrently.
		for i := range numClients {
			wg.Add(1)
			go func() {
				defer wg.Done()
				hub.Subscribe(clients[i], "race-channel")
			}()
		}
		wg.Wait()

		// Phase 2: even-indexed clients unsubscribe from race-channel while
		// odd-indexed clients concurrently subscribe to another-channel.
		for i := range numClients {
			wg.Add(1)
			go func() {
				defer wg.Done()
				if i%2 == 0 {
					hub.Unsubscribe(clients[i], "race-channel")
				} else {
					hub.Subscribe(clients[i], "another-channel")
				}
			}()
		}
		wg.Wait()

		// The odd half remains on race-channel and has also joined
		// another-channel; the even half has left race-channel.
		assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("race-channel"))
		assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("another-channel"))
	})
}
// TestHub_ProcessOutputEndToEnd streams several process-output lines through a
// real WebSocket and checks that they arrive complete and in order.
func TestHub_ProcessOutputEndToEnd(t *testing.T) {
	t.Run("end-to-end process output via WebSocket", func(t *testing.T) {
		h := NewHub()
		go h.Run(t.Context())
		srv := httptest.NewServer(h.Handler())
		defer srv.Close()

		endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(endpoint, nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		require.NoError(t, conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:build-42"}))
		time.Sleep(50 * time.Millisecond)

		want := []string{"Compiling...", "Linking...", "Done."}
		for _, line := range want {
			require.NoError(t, h.SendProcessOutput("build-42", line))
			time.Sleep(10 * time.Millisecond) // give writePump a chance to flush each line alone
		}

		// Frames may hold one message or several newline-separated ones.
		var got []Message
		for len(got) < 3 {
			_ = conn.SetReadDeadline(time.Now().Add(time.Second))
			_, frame, readErr := conn.ReadMessage()
			require.NoError(t, readErr)

			for chunk := range strings.SplitSeq(strings.TrimSpace(string(frame)), "\n") {
				chunk = strings.TrimSpace(chunk)
				if chunk == "" {
					continue
				}
				var m Message
				require.NoError(t, json.Unmarshal([]byte(chunk), &m))
				got = append(got, m)
			}
		}

		require.Len(t, got, 3)
		for idx, line := range want {
			assert.Equal(t, TypeProcessOutput, got[idx].Type)
			assert.Equal(t, "build-42", got[idx].ProcessID)
			assert.Equal(t, line, got[idx].Data)
		}
	})
}
// TestHub_ProcessStatusEndToEnd verifies that SendProcessStatus delivers a
// TypeProcessStatus message, carrying status and exit code, to a WebSocket
// client subscribed to "process:<id>".
func TestHub_ProcessStatusEndToEnd(t *testing.T) {
	t.Run("end-to-end process status via WebSocket", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
		require.NoError(t, err)
		defer conn.Close()
		time.Sleep(50 * time.Millisecond)

		// Subscribe
		err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:job-7"})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)

		// Send status
		err = hub.SendProcessStatus("job-7", "exited", 1)
		require.NoError(t, err)

		conn.SetReadDeadline(time.Now().Add(time.Second))
		var received Message
		err = conn.ReadJSON(&received)
		require.NoError(t, err)
		assert.Equal(t, TypeProcessStatus, received.Type)
		assert.Equal(t, "job-7", received.ProcessID)

		// JSON decoding into `any` yields map[string]any with float64 numbers.
		data, ok := received.Data.(map[string]any)
		require.True(t, ok)
		assert.Equal(t, "exited", data["status"])
		assert.Equal(t, float64(1), data["exitCode"])
	})
}
// --- Benchmarks ---
// BenchmarkBroadcast measures the cost of Hub.Broadcast with 100 clients
// registered through the Run loop. Broadcast errors (for example a full
// broadcast queue) are deliberately ignored; only the call is timed.
func BenchmarkBroadcast(b *testing.B) {
	hub := NewHub()
	ctx := b.Context()
	go hub.Run(ctx)

	// Register 100 clients
	for range 100 {
		client := &Client{
			hub:           hub,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		hub.register <- client
	}
	// Give the Run loop time to apply the registrations before timing starts.
	time.Sleep(50 * time.Millisecond)

	msg := Message{Type: TypeEvent, Data: "benchmark"}
	// b.Loop (Go 1.24+) starts the timer on its first call, so the setup above
	// is excluded without ResetTimer, and the body cannot be optimised away.
	for b.Loop() {
		_ = hub.Broadcast(msg)
	}
}
// BenchmarkSendToChannel measures Hub.SendToChannel fanning out to 50 clients
// subscribed to a single channel. Errors are deliberately ignored; only the
// call is timed.
func BenchmarkSendToChannel(b *testing.B) {
	hub := NewHub()
	ctx := b.Context()
	go hub.Run(ctx)

	// Register 50 clients directly in the hub map, all subscribed to the same channel.
	for range 50 {
		client := &Client{
			hub:           hub,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		hub.mu.Lock()
		hub.clients[client] = true
		hub.mu.Unlock()
		hub.Subscribe(client, "bench-channel")
	}

	msg := Message{Type: TypeEvent, Data: "benchmark"}
	// b.Loop (Go 1.24+) excludes the setup above from timing and keeps the
	// body from being optimised away.
	for b.Loop() {
		_ = hub.SendToChannel("bench-channel", msg)
	}
}
// --- Phase 1: Connection resilience tests ---
func TestNewHubWithConfig(t *testing.T) {
t.Run("uses custom heartbeat and pong timeout", func(t *testing.T) {
config := HubConfig{
HeartbeatInterval: 5 * time.Second,
PongTimeout: 10 * time.Second,
WriteTimeout: 3 * time.Second,
}
hub := NewHubWithConfig(config)
assert.Equal(t, 5*time.Second, hub.config.HeartbeatInterval)
assert.Equal(t, 10*time.Second, hub.config.PongTimeout)
assert.Equal(t, 3*time.Second, hub.config.WriteTimeout)
})
t.Run("applies defaults for zero values", func(t *testing.T) {
hub := NewHubWithConfig(HubConfig{})
assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval)
assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout)
assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout)
})
t.Run("applies defaults for negative values", func(t *testing.T) {
hub := NewHubWithConfig(HubConfig{
HeartbeatInterval: -1,
PongTimeout: -1,
WriteTimeout: -1,
})
assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval)
assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout)
assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout)
})
}
// TestDefaultHubConfig pins the values returned by DefaultHubConfig.
func TestDefaultHubConfig(t *testing.T) {
	t.Run("returns sensible defaults", func(t *testing.T) {
		cfg := DefaultHubConfig()

		// Timing defaults.
		assert.Equal(t, 30*time.Second, cfg.HeartbeatInterval)
		assert.Equal(t, time.Minute, cfg.PongTimeout)
		assert.Equal(t, 10*time.Second, cfg.WriteTimeout)

		// No lifecycle hooks are installed by default.
		assert.Nil(t, cfg.OnConnect)
		assert.Nil(t, cfg.OnDisconnect)
	})
}
// TestHub_ConnectionCallbacks checks that HubConfig.OnConnect and
// OnDisconnect fire for real WebSocket connections, and that OnDisconnect is
// not invoked when an unknown (never registered) client is unregistered.
func TestHub_ConnectionCallbacks(t *testing.T) {
	t.Run("calls OnConnect when client connects", func(t *testing.T) {
		connectCalled := make(chan *Client, 1)
		hub := NewHubWithConfig(HubConfig{
			OnConnect: func(client *Client) {
				connectCalled <- client
			},
		})
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
		require.NoError(t, err)
		defer conn.Close()

		select {
		case c := <-connectCalled:
			assert.NotNil(t, c)
		case <-time.After(time.Second):
			t.Fatal("OnConnect callback should have been called")
		}
	})

	t.Run("calls OnDisconnect when client disconnects", func(t *testing.T) {
		disconnectCalled := make(chan *Client, 1)
		hub := NewHubWithConfig(HubConfig{
			OnDisconnect: func(client *Client) {
				disconnectCalled <- client
			},
		})
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)

		// Close the connection to trigger disconnect
		conn.Close()

		select {
		case c := <-disconnectCalled:
			assert.NotNil(t, c)
		case <-time.After(time.Second):
			t.Fatal("OnDisconnect callback should have been called")
		}
	})

	t.Run("does not call OnDisconnect for unknown client", func(t *testing.T) {
		disconnectCalled := make(chan struct{}, 1)
		hub := NewHubWithConfig(HubConfig{
			OnDisconnect: func(client *Client) {
				disconnectCalled <- struct{}{}
			},
		})
		ctx := t.Context()
		go hub.Run(ctx)

		// Send an unregister for a client that was never registered
		unknownClient := &Client{
			hub:           hub,
			send:          make(chan []byte, 1),
			subscriptions: make(map[string]bool),
		}
		hub.unregister <- unknownClient

		select {
		case <-disconnectCalled:
			t.Fatal("OnDisconnect should not be called for unknown client")
		case <-time.After(100 * time.Millisecond):
			// Good — callback was not called
		}
	})
}
// TestHub_CustomHeartbeat verifies that a short HubConfig.HeartbeatInterval
// makes the server send a ping to the client within one second.
func TestHub_CustomHeartbeat(t *testing.T) {
	t.Run("uses custom heartbeat interval for server pings", func(t *testing.T) {
		// Use a very short heartbeat to test it actually fires
		hub := NewHubWithConfig(HubConfig{
			HeartbeatInterval: 100 * time.Millisecond,
			PongTimeout:       500 * time.Millisecond,
			WriteTimeout:      500 * time.Millisecond,
		})
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		pingReceived := make(chan struct{}, 1)
		dialer := websocket.Dialer{}
		conn, _, err := dialer.Dial(wsURL(server), nil)
		require.NoError(t, err)
		defer conn.Close()

		conn.SetPingHandler(func(appData string) error {
			// Non-blocking send: only the first ping needs to be observed.
			select {
			case pingReceived <- struct{}{}:
			default:
			}
			// Respond with pong
			return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second))
		})

		// Keep reading in the background so incoming control frames (pings)
		// are processed; the loop ends when the connection is closed.
		go func() {
			for {
				if _, _, err := conn.ReadMessage(); err != nil {
					return
				}
			}
		}()

		select {
		case <-pingReceived:
			// Server ping received within custom heartbeat interval
		case <-time.After(time.Second):
			t.Fatal("should have received server ping within custom heartbeat interval")
		}
	})
}
// TestReconnectingClient_Connect verifies that a ReconnectingClient connects
// to a hub, reports StateConnected, and delivers hub broadcasts to OnMessage.
func TestReconnectingClient_Connect(t *testing.T) {
	t.Run("connects to server and receives messages", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		connectCalled := make(chan struct{}, 1)
		msgReceived := make(chan Message, 1)
		rc := NewReconnectingClient(ReconnectConfig{
			URL: wsURL(server),
			OnConnect: func() {
				select {
				case connectCalled <- struct{}{}:
				default:
				}
			},
			OnMessage: func(msg Message) {
				select {
				case msgReceived <- msg:
				default:
				}
			},
		})

		// Run Connect in background
		clientCtx, clientCancel := context.WithCancel(context.Background())
		defer clientCancel()
		go rc.Connect(clientCtx)

		// Wait for connect
		select {
		case <-connectCalled:
			// Good
		case <-time.After(time.Second):
			t.Fatal("OnConnect should have been called")
		}
		assert.Equal(t, StateConnected, rc.State())

		// Wait for client to register
		time.Sleep(50 * time.Millisecond)

		// Broadcast a message
		err := hub.Broadcast(Message{Type: TypeEvent, Data: "hello"})
		require.NoError(t, err)

		select {
		case msg := <-msgReceived:
			assert.Equal(t, TypeEvent, msg.Type)
			assert.Equal(t, "hello", msg.Data)
		case <-time.After(time.Second):
			t.Fatal("should have received the broadcast message")
		}

		clientCancel()
		time.Sleep(50 * time.Millisecond)
	})
}
// TestReconnectingClient_Reconnect stops the server under a connected
// ReconnectingClient, then restarts a new hub on the same address. It checks
// that OnDisconnect fires and the client reconnects (OnReconnect with a
// positive attempt number).
func TestReconnectingClient_Reconnect(t *testing.T) {
	t.Run("reconnects after server restart", func(t *testing.T) {
		hub := NewHub()
		ctx, cancel := context.WithCancel(context.Background())
		// cancel is called explicitly below to simulate the outage; deferring it
		// too stops the first hub if the test aborts before that point.
		defer cancel()
		go hub.Run(ctx)

		// Use a net.Listener so we control the port and can rebind it later.
		listener, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		server := &httptest.Server{
			Listener: listener,
			Config:   &http.Server{Handler: hub.Handler()},
		}
		server.Start()
		// httptest.Server.Close is idempotent; this only matters on early exit.
		defer server.Close()
		serverURL := wsURL(server)
		addr := listener.Addr().String()

		reconnectCalled := make(chan int, 5)
		disconnectCalled := make(chan struct{}, 5)
		connectCalled := make(chan struct{}, 5)
		rc := NewReconnectingClient(ReconnectConfig{
			URL:            serverURL,
			InitialBackoff: 50 * time.Millisecond,
			MaxBackoff:     200 * time.Millisecond,
			OnConnect: func() {
				select {
				case connectCalled <- struct{}{}:
				default:
				}
			},
			OnDisconnect: func() {
				select {
				case disconnectCalled <- struct{}{}:
				default:
				}
			},
			OnReconnect: func(attempt int) {
				select {
				case reconnectCalled <- attempt:
				default:
				}
			},
		})

		clientCtx, clientCancel := context.WithCancel(context.Background())
		defer clientCancel()
		go rc.Connect(clientCtx)

		// Wait for initial connection
		select {
		case <-connectCalled:
		case <-time.After(time.Second):
			t.Fatal("initial connection should have succeeded")
		}
		assert.Equal(t, StateConnected, rc.State())

		// Shut down the server to simulate disconnect
		cancel()
		server.Close()

		// Wait for disconnect callback
		select {
		case <-disconnectCalled:
			// Good
		case <-time.After(time.Second):
			t.Fatal("OnDisconnect should have been called")
		}

		// Start a new server on the same address for reconnection
		hub2 := NewHub()
		ctx2 := t.Context()
		go hub2.Run(ctx2)
		listener2, err := net.Listen("tcp", addr)
		if err != nil {
			// Skipf ends the test immediately; nothing below runs.
			t.Skipf("could not bind to same port: %v", err)
		}
		newServer := &httptest.Server{
			Listener: listener2,
			Config:   &http.Server{Handler: hub2.Handler()},
		}
		newServer.Start()
		defer newServer.Close()

		// Wait for reconnection
		select {
		case attempt := <-reconnectCalled:
			assert.Greater(t, attempt, 0)
		case <-time.After(3 * time.Second):
			t.Fatal("OnReconnect should have been called")
		}
		assert.Equal(t, StateConnected, rc.State())

		clientCancel()
	})
}
func TestReconnectingClient_MaxRetries(t *testing.T) {
t.Run("stops after max retries exceeded", func(t *testing.T) {
// Use a URL that will never connect
rc := NewReconnectingClient(ReconnectConfig{
URL: "ws://127.0.0.1:1", // Should refuse connection
InitialBackoff: 10 * time.Millisecond,
MaxBackoff: 50 * time.Millisecond,
MaxRetries: 3,
})
errCh := make(chan error, 1)
go func() {
errCh <- rc.Connect(context.Background())
}()
select {
case err := <-errCh:
require.Error(t, err)
assert.Contains(t, err.Error(), "max retries (3) exceeded")
case <-time.After(5 * time.Second):
t.Fatal("should have stopped after max retries")
}
assert.Equal(t, StateDisconnected, rc.State())
})
}
// TestReconnectingClient_Send covers ReconnectingClient.Send: a message
// reaches the hub when connected, sending while disconnected fails, and an
// unmarshalable payload is rejected with a marshal error.
func TestReconnectingClient_Send(t *testing.T) {
	t.Run("sends message to server", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		connected := make(chan struct{}, 1)
		rc := NewReconnectingClient(ReconnectConfig{
			URL: wsURL(server),
			OnConnect: func() {
				select {
				case connected <- struct{}{}:
				default:
				}
			},
		})

		clientCtx, clientCancel := context.WithCancel(context.Background())
		defer clientCancel()
		go rc.Connect(clientCtx)

		// Bounded wait: a bare <-connected would hang the test if the client
		// never connects.
		select {
		case <-connected:
		case <-time.After(time.Second):
			t.Fatal("OnConnect should have been called")
		}
		time.Sleep(50 * time.Millisecond)

		// Send a subscribe message
		err := rc.Send(Message{
			Type: TypeSubscribe,
			Data: "test-channel",
		})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))

		clientCancel()
	})

	t.Run("returns error when not connected", func(t *testing.T) {
		rc := NewReconnectingClient(ReconnectConfig{
			URL: "ws://localhost:1",
		})
		err := rc.Send(Message{Type: TypePing})
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not connected")
	})

	t.Run("returns error for unmarshalable message", func(t *testing.T) {
		rc := NewReconnectingClient(ReconnectConfig{
			URL: "ws://localhost:1",
		})
		// No connection is established here. The assertion below expects Send
		// to marshal the message before checking for a live connection, so the
		// marshal error is reported rather than "not connected".
		err := rc.Send(Message{Type: TypeEvent, Data: make(chan int)})
		require.Error(t, err)
		assert.Contains(t, err.Error(), "failed to marshal message")
	})
}
// TestReconnectingClient_Close checks that Close makes a running Connect
// return, and that closing a never-connected client is a harmless no-op.
func TestReconnectingClient_Close(t *testing.T) {
	t.Run("stops reconnection loop", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)
		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		connected := make(chan struct{}, 1)
		rc := NewReconnectingClient(ReconnectConfig{
			URL: wsURL(server),
			OnConnect: func() {
				select {
				case connected <- struct{}{}:
				default:
				}
			},
		})

		clientCtx := t.Context()
		done := make(chan error, 1)
		go func() {
			done <- rc.Connect(clientCtx)
		}()

		// Bounded wait so a failed connection fails the test instead of hanging it.
		select {
		case <-connected:
		case <-time.After(time.Second):
			t.Fatal("OnConnect should have been called")
		}

		err := rc.Close()
		assert.NoError(t, err)

		select {
		case <-done:
			// Good — Connect returned
		case <-time.After(time.Second):
			t.Fatal("Connect should have returned after Close")
		}
	})

	t.Run("close when not connected is safe", func(t *testing.T) {
		rc := NewReconnectingClient(ReconnectConfig{
			URL: "ws://localhost:1",
		})
		err := rc.Close()
		assert.NoError(t, err)
	})
}
// TestReconnectingClient_ExponentialBackoff pins calculateBackoff for a
// 100ms initial delay, a multiplier of 2 and a 1s ceiling.
func TestReconnectingClient_ExponentialBackoff(t *testing.T) {
	t.Run("calculates correct backoff values", func(t *testing.T) {
		rc := NewReconnectingClient(ReconnectConfig{
			URL:               "ws://localhost:1",
			InitialBackoff:    100 * time.Millisecond,
			MaxBackoff:        1 * time.Second,
			BackoffMultiplier: 2.0,
		})

		expectations := []struct {
			attempt int
			want    time.Duration
		}{
			{attempt: 1, want: 100 * time.Millisecond},
			{attempt: 2, want: 200 * time.Millisecond},
			{attempt: 3, want: 400 * time.Millisecond},
			{attempt: 4, want: 800 * time.Millisecond},
			{attempt: 5, want: time.Second},  // capped at MaxBackoff
			{attempt: 10, want: time.Second}, // stays capped
		}

		for _, e := range expectations {
			assert.Equal(t, e.want, rc.calculateBackoff(e.attempt), "attempt %d", e.attempt)
		}
	})
}
// TestReconnectingClient_Defaults checks the values NewReconnectingClient
// fills in when only the URL is supplied.
func TestReconnectingClient_Defaults(t *testing.T) {
	t.Run("applies defaults for zero config values", func(t *testing.T) {
		rc := NewReconnectingClient(ReconnectConfig{URL: "ws://localhost:1"})

		// A dialer must always be present.
		assert.NotNil(t, rc.config.Dialer)

		// Backoff tuning falls back to 1s initial, 30s max, doubling each time.
		assert.Equal(t, time.Second, rc.config.InitialBackoff)
		assert.Equal(t, 30*time.Second, rc.config.MaxBackoff)
		assert.Equal(t, 2.0, rc.config.BackoffMultiplier)
	})
}
// TestReconnectingClient_ContextCancel checks that cancelling the context
// while Connect is waiting out a (long) backoff makes it return promptly
// with context.Canceled.
func TestReconnectingClient_ContextCancel(t *testing.T) {
	t.Run("returns context error on cancel during backoff", func(t *testing.T) {
		rc := NewReconnectingClient(ReconnectConfig{
			URL:            "ws://127.0.0.1:1",
			InitialBackoff: 10 * time.Second, // Long backoff
		})

		ctx, cancel := context.WithCancel(context.Background())
		done := make(chan error, 1)
		go func() {
			done <- rc.Connect(ctx)
		}()

		// Allow first dial attempt to fail
		time.Sleep(200 * time.Millisecond)

		// Cancel during backoff
		cancel()

		select {
		case err := <-done:
			require.Error(t, err)
			// errors.Is semantics: still passes if Connect wraps the
			// context error with extra context.
			assert.ErrorIs(t, err, context.Canceled)
		case <-time.After(2 * time.Second):
			t.Fatal("Connect should have returned after context cancel")
		}
	})
}
func TestConnectionState(t *testing.T) {
t.Run("state constants are distinct", func(t *testing.T) {
assert.NotEqual(t, StateDisconnected, StateConnecting)
assert.NotEqual(t, StateConnecting, StateConnected)
assert.NotEqual(t, StateDisconnected, StateConnected)
})
}
// ---------------------------------------------------------------------------
// Hub.Run lifecycle — register, broadcast delivery, unregister via channels
// ---------------------------------------------------------------------------
// TestHubRun_RegisterClient_Good checks that a client sent on hub.register is
// added to the hub by the Run loop.
func TestHubRun_RegisterClient_Good(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)

	client := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}
	hub.register <- client

	// Poll instead of sleeping a fixed 20ms: the Run loop applies the
	// registration asynchronously, and a fixed sleep is flaky under -race or
	// on a loaded CI machine.
	assert.Eventually(t, func() bool { return hub.ClientCount() == 1 },
		time.Second, 5*time.Millisecond, "client should be registered via hub loop")
}
// TestHubRun_BroadcastDelivery_Good checks that a Broadcast is delivered by
// the Run loop to a registered client's send channel.
func TestHubRun_BroadcastDelivery_Good(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)

	client := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}
	hub.register <- client
	// Wait until the registration has actually been applied (rather than a
	// fixed sleep) so the broadcast cannot race ahead of it.
	require.Eventually(t, func() bool { return hub.ClientCount() == 1 },
		time.Second, 5*time.Millisecond, "client should be registered before broadcasting")

	err := hub.Broadcast(Message{Type: TypeEvent, Data: "lifecycle-test"})
	require.NoError(t, err)

	// Hub.Run loop delivers the broadcast to the client's send channel
	select {
	case msg := <-client.send:
		var received Message
		require.NoError(t, json.Unmarshal(msg, &received))
		assert.Equal(t, TypeEvent, received.Type)
		assert.Equal(t, "lifecycle-test", received.Data)
	case <-time.After(time.Second):
		t.Fatal("broadcast should be delivered via hub loop")
	}
}
// TestHubRun_UnregisterClient_Good checks that unregistering through the Run
// loop removes the client and drops its channel subscriptions.
func TestHubRun_UnregisterClient_Good(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)

	client := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}
	hub.register <- client
	// Poll rather than sleep: registration is applied asynchronously by Run.
	require.Eventually(t, func() bool { return hub.ClientCount() == 1 },
		time.Second, 5*time.Millisecond, "client should be registered")

	// Subscribe so we can verify channel cleanup
	hub.Subscribe(client, "lifecycle-chan")
	assert.Equal(t, 1, hub.ChannelSubscriberCount("lifecycle-chan"))

	hub.unregister <- client
	assert.Eventually(t, func() bool { return hub.ClientCount() == 0 },
		time.Second, 5*time.Millisecond, "client should be removed")
	assert.Eventually(t, func() bool { return hub.ChannelSubscriberCount("lifecycle-chan") == 0 },
		time.Second, 5*time.Millisecond, "subscriptions should be cleaned up")
}
// TestHubRun_UnregisterIgnoresDuplicate_Bad sends the same client on
// hub.unregister twice. The second send must not block, and the Run loop must
// still be alive afterwards.
func TestHubRun_UnregisterIgnoresDuplicate_Bad(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)

	client := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}
	hub.register <- client
	require.Eventually(t, func() bool { return hub.ClientCount() == 1 },
		time.Second, 5*time.Millisecond, "client should be registered")

	hub.unregister <- client
	require.Eventually(t, func() bool { return hub.ClientCount() == 0 },
		time.Second, 5*time.Millisecond, "client should be unregistered")

	// Second unregister should not panic or block
	done := make(chan struct{})
	go func() {
		hub.unregister <- client
		close(done)
	}()

	select {
	case <-done:
		// Good -- no block
	case <-time.After(time.Second):
		t.Fatal("duplicate unregister should not block")
	}

	// The send above completing only means the loop received the duplicate.
	// Round-trip a fresh registration to prove the loop survived handling it
	// and is still processing requests.
	sentinel := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}
	select {
	case hub.register <- sentinel:
	case <-time.After(time.Second):
		t.Fatal("hub loop should still accept registrations after a duplicate unregister")
	}
	assert.Eventually(t, func() bool { return hub.ClientCount() == 1 },
		time.Second, 5*time.Millisecond, "hub loop should still be processing registrations")
}
// ---------------------------------------------------------------------------
// Subscribe / Unsubscribe — additional channel management tests
// ---------------------------------------------------------------------------
// TestSubscribe_MultipleChannels_Good subscribes one client to three channels
// and checks both the hub's channel count and the client's own view.
func TestSubscribe_MultipleChannels_Good(t *testing.T) {
	hub := NewHub()
	c := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}

	want := []string{"alpha", "beta", "gamma"}
	for _, name := range want {
		hub.Subscribe(c, name)
	}

	assert.Equal(t, len(want), hub.ChannelCount())
	// Order is not guaranteed, so compare as a multiset.
	assert.ElementsMatch(t, want, c.Subscriptions())
}
// TestSubscribe_IdempotentDoubleSubscribe_Good checks that subscribing the
// same client to the same channel twice is counted once.
func TestSubscribe_IdempotentDoubleSubscribe_Good(t *testing.T) {
	hub := NewHub()
	c := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}

	for range 2 {
		hub.Subscribe(c, "dupl")
	}

	assert.Equal(t, 1, hub.ChannelSubscriberCount("dupl"))
}
// TestUnsubscribe_PartialLeave_Good checks that when one of two subscribers
// leaves a channel, the count drops to one and the channel entry survives.
func TestUnsubscribe_PartialLeave_Good(t *testing.T) {
	hub := NewHub()
	newClient := func() *Client {
		return &Client{hub: hub, send: make(chan []byte, 256), subscriptions: make(map[string]bool)}
	}
	leaver, stayer := newClient(), newClient()

	for _, c := range []*Client{leaver, stayer} {
		hub.Subscribe(c, "shared")
	}
	assert.Equal(t, 2, hub.ChannelSubscriberCount("shared"))

	hub.Unsubscribe(leaver, "shared")
	assert.Equal(t, 1, hub.ChannelSubscriberCount("shared"))

	// Inspect the map directly: the entry must not be deleted while stayer
	// is still subscribed.
	hub.mu.RLock()
	_, present := hub.channels["shared"]
	hub.mu.RUnlock()
	assert.True(t, present, "channel should persist while subscribers remain")
}
// ---------------------------------------------------------------------------
// SendToChannel — multiple subscribers
// ---------------------------------------------------------------------------
// TestSendToChannel_MultipleSubscribers_Good checks that SendToChannel fans a
// message out to every subscriber, each tagged with the channel name.
func TestSendToChannel_MultipleSubscribers_Good(t *testing.T) {
	hub := NewHub()

	subscribers := make([]*Client, 0, 5)
	for range 5 {
		c := &Client{
			hub:           hub,
			send:          make(chan []byte, 256),
			subscriptions: make(map[string]bool),
		}
		hub.Subscribe(c, "multi")
		subscribers = append(subscribers, c)
	}

	require.NoError(t, hub.SendToChannel("multi", Message{Type: TypeEvent, Data: "fanout"}))

	for idx, c := range subscribers {
		select {
		case raw := <-c.send:
			var got Message
			require.NoError(t, json.Unmarshal(raw, &got))
			assert.Equal(t, "multi", got.Channel)
		case <-time.After(time.Second):
			t.Fatalf("client %d should have received the message", idx)
		}
	}
}
// ---------------------------------------------------------------------------
// SendProcessOutput / SendProcessStatus — edge cases
// ---------------------------------------------------------------------------
// TestSendProcessOutput_NoSubscribers_Good checks that output for a process
// nobody is watching is accepted without error.
func TestSendProcessOutput_NoSubscribers_Good(t *testing.T) {
	h := NewHub()
	assert.NoError(t,
		h.SendProcessOutput("orphan-proc", "some output"),
		"sending to a process with no subscribers should not error")
}
// TestSendProcessStatus_NonZeroExit_Good checks that a non-zero exit code is
// carried through SendProcessStatus to a subscriber of "process:<id>".
func TestSendProcessStatus_NonZeroExit_Good(t *testing.T) {
	hub := NewHub()
	client := &Client{
		hub:           hub,
		send:          make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}
	hub.Subscribe(client, "process:fail-1")

	err := hub.SendProcessStatus("fail-1", "exited", 137)
	require.NoError(t, err)

	select {
	case msg := <-client.send:
		var received Message
		require.NoError(t, json.Unmarshal(msg, &received))
		assert.Equal(t, TypeProcessStatus, received.Type)
		assert.Equal(t, "fail-1", received.ProcessID)

		// Checked assertion: an unexpected payload shape should fail the test
		// cleanly instead of panicking.
		data, ok := received.Data.(map[string]any)
		require.True(t, ok, "status payload should decode to an object, got %T", received.Data)
		assert.Equal(t, "exited", data["status"])
		assert.Equal(t, float64(137), data["exitCode"])
	case <-time.After(time.Second):
		t.Fatal("expected process status message")
	}
}
// ---------------------------------------------------------------------------
// readPump — ping with timestamp verification
// ---------------------------------------------------------------------------
// TestReadPump_PingTimestamp_Good sends an application-level ping and expects
// a pong message with a non-zero timestamp in reply.
func TestReadPump_PingTimestamp_Good(t *testing.T) {
	hub := NewHub()
	go hub.Run(t.Context())

	srv := httptest.NewServer(hub.Handler())
	defer srv.Close()

	ws, _, err := websocket.DefaultDialer.Dial(wsURL(srv), nil)
	require.NoError(t, err)
	defer ws.Close()
	// Give the hub a moment to register the connection.
	time.Sleep(50 * time.Millisecond)

	require.NoError(t, ws.WriteJSON(Message{Type: TypePing}))

	ws.SetReadDeadline(time.Now().Add(time.Second))
	var reply Message
	require.NoError(t, ws.ReadJSON(&reply))

	assert.Equal(t, TypePong, reply.Type)
	assert.False(t, reply.Timestamp.IsZero(), "pong should include a timestamp")
}
// ---------------------------------------------------------------------------
// writePump — batch sending with multiple messages
// ---------------------------------------------------------------------------
// TestWritePump_BatchMultipleMessages_Good queues several broadcasts quickly
// and checks that every one of them is received. The writer may batch queued
// messages into one frame, separated by newlines.
func TestWritePump_BatchMultipleMessages_Good(t *testing.T) {
	hub := NewHub()
	go hub.Run(t.Context())

	srv := httptest.NewServer(hub.Handler())
	defer srv.Close()

	conn, _, err := websocket.DefaultDialer.Dial(wsURL(srv), nil)
	require.NoError(t, err)
	defer conn.Close()
	time.Sleep(50 * time.Millisecond)

	const total = 10
	// Fire the broadcasts back-to-back so they pile up in the client's queue.
	for i := range total {
		require.NoError(t, hub.Broadcast(Message{
			Type: TypeEvent,
			Data: fmt.Sprintf("batch-%d", i),
		}))
	}
	time.Sleep(100 * time.Millisecond)

	// countInFrame reports how many well-formed JSON messages one frame holds.
	countInFrame := func(frame []byte) int {
		n := 0
		for line := range strings.SplitSeq(string(frame), "\n") {
			line = strings.TrimSpace(line)
			if line == "" {
				continue
			}
			var m Message
			if json.Unmarshal([]byte(line), &m) == nil {
				n++
			}
		}
		return n
	}

	got := 0
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	for got < total {
		_, frame, readErr := conn.ReadMessage()
		if readErr != nil {
			break
		}
		got += countInFrame(frame)
	}

	assert.Equal(t, total, got, "all batched messages should be received")
}
// ---------------------------------------------------------------------------
// Integration — unsubscribe stops delivery
// ---------------------------------------------------------------------------
// TestIntegration_UnsubscribeStopsDelivery_Good subscribes a real client to a
// channel, confirms delivery, unsubscribes, and then checks that a later
// message on the channel is not delivered. The final read must time out.
func TestIntegration_UnsubscribeStopsDelivery_Good(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)
	server := httptest.NewServer(hub.Handler())
	defer server.Close()

	conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
	require.NoError(t, err)
	defer conn.Close()
	time.Sleep(50 * time.Millisecond)

	// Subscribe
	err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "temp:feed"})
	require.NoError(t, err)
	time.Sleep(50 * time.Millisecond)

	// Verify we receive messages on the channel
	err = hub.SendToChannel("temp:feed", Message{Type: TypeEvent, Data: "before-unsub"})
	require.NoError(t, err)
	conn.SetReadDeadline(time.Now().Add(time.Second))
	var msg1 Message
	err = conn.ReadJSON(&msg1)
	require.NoError(t, err)
	assert.Equal(t, "before-unsub", msg1.Data)

	// Unsubscribe
	err = conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "temp:feed"})
	require.NoError(t, err)
	time.Sleep(50 * time.Millisecond)

	// Send another message -- client should NOT receive it
	err = hub.SendToChannel("temp:feed", Message{Type: TypeEvent, Data: "after-unsub"})
	require.NoError(t, err)

	// Try to read -- should time out (no message delivered)
	conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
	var msg2 Message
	err = conn.ReadJSON(&msg2)
	require.Error(t, err, "should not receive messages after unsubscribing")
	// Only a read timeout proves nothing was delivered. Any other error, such
	// as the server dropping the connection, would let the test pass for the
	// wrong reason.
	var netErr net.Error
	require.ErrorAs(t, err, &netErr, "expected a read timeout, got %v", err)
	assert.True(t, netErr.Timeout(), "expected a read timeout, got %v", err)
}
// ---------------------------------------------------------------------------
// Integration — broadcast reaches all clients (no channel subscription)
// ---------------------------------------------------------------------------
// TestIntegration_BroadcastReachesAllClients_Good connects several real
// clients without any channel subscription and checks that a Broadcast
// reaches each of them.
func TestIntegration_BroadcastReachesAllClients_Good(t *testing.T) {
	hub := NewHub()
	go hub.Run(t.Context())

	srv := httptest.NewServer(hub.Handler())
	defer srv.Close()

	const numClients = 3
	peers := make([]*websocket.Conn, 0, numClients)
	for range numClients {
		peer, _, err := websocket.DefaultDialer.Dial(wsURL(srv), nil)
		require.NoError(t, err)
		// Deliberately deferred in the loop: every connection stays open
		// until the test function returns.
		defer peer.Close()
		peers = append(peers, peer)
	}
	time.Sleep(100 * time.Millisecond)
	assert.Equal(t, numClients, hub.ClientCount())

	// Broadcast -- no channel subscription needed
	require.NoError(t, hub.Broadcast(Message{Type: TypeError, Data: "global-alert"}))

	for idx, peer := range peers {
		peer.SetReadDeadline(time.Now().Add(2 * time.Second))
		var got Message
		require.NoError(t, peer.ReadJSON(&got), "client %d should receive broadcast", idx)
		assert.Equal(t, TypeError, got.Type)
		assert.Equal(t, "global-alert", got.Data)
	}
}
// ---------------------------------------------------------------------------
// Integration — disconnect cleans up all subscriptions
// ---------------------------------------------------------------------------
// TestIntegration_DisconnectCleansUpEverything_Good subscribes a real client
// to two channels, disconnects it, and checks that the client and both
// (now empty) channels are removed from the hub.
func TestIntegration_DisconnectCleansUpEverything_Good(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)
	server := httptest.NewServer(hub.Handler())
	defer server.Close()

	conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
	require.NoError(t, err)
	// Avoid leaking the connection if an assertion aborts the test before the
	// explicit Close below; a second Close is harmless.
	defer conn.Close()

	// Subscribe to multiple channels
	err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "ch-a"})
	require.NoError(t, err)
	err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "ch-b"})
	require.NoError(t, err)

	// Poll instead of fixed sleeps: registration and subscription are applied
	// asynchronously by the server side.
	require.Eventually(t, func() bool {
		return hub.ClientCount() == 1 &&
			hub.ChannelSubscriberCount("ch-a") == 1 &&
			hub.ChannelSubscriberCount("ch-b") == 1
	}, time.Second, 5*time.Millisecond, "client should be registered and subscribed to both channels")

	// Disconnect
	conn.Close()

	assert.Eventually(t, func() bool {
		return hub.ClientCount() == 0 && hub.ChannelCount() == 0
	}, time.Second, 5*time.Millisecond, "disconnect should remove the client and its channels")
	// Individual checks for clearer failure diagnostics.
	assert.Equal(t, 0, hub.ClientCount())
	assert.Equal(t, 0, hub.ChannelSubscriberCount("ch-a"))
	assert.Equal(t, 0, hub.ChannelSubscriberCount("ch-b"))
	assert.Equal(t, 0, hub.ChannelCount(), "empty channels should be cleaned up")
}
// ---------------------------------------------------------------------------
// Concurrent broadcast + subscribe via hub loop (race test)
// ---------------------------------------------------------------------------
// TestConcurrentSubscribeAndBroadcast_Good races client registration (via
// hub.register) against Broadcast calls through the Run loop (run with
// -race), then checks that every client ended up registered.
func TestConcurrentSubscribeAndBroadcast_Good(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)

	const numClients = 50
	var wg sync.WaitGroup
	// Go 1.22+ gives each iteration its own i, so the closures can capture it.
	for i := range numClients {
		wg.Add(2)
		go func() {
			defer wg.Done()
			hub.register <- &Client{
				hub:           hub,
				send:          make(chan []byte, 256),
				subscriptions: make(map[string]bool),
			}
		}()
		go func() {
			defer wg.Done()
			// The error is ignored on purpose: this test is about race-free
			// concurrent use, not broadcast delivery.
			_ = hub.Broadcast(Message{Type: TypeEvent, Data: i})
		}()
	}
	wg.Wait()

	// Poll rather than sleep: the last registrations may still be in flight
	// inside the Run loop.
	assert.Eventually(t, func() bool { return hub.ClientCount() == numClients },
		time.Second, 5*time.Millisecond, "all clients should be registered")
}