Add AllClients, AllChannels, AllSubscriptions iter.Seq iterators. Use slices.Collect(maps.Keys()) in SendToChannel/Subscriptions, range-over-int in calculateBackoff and benchmarks. Co-Authored-By: Gemini <noreply@google.com> Co-Authored-By: Virgil <virgil@lethean.io>
2512 lines
64 KiB
Go
2512 lines
64 KiB
Go
// SPDX-License-Identifier: EUPL-1.2
|
|
|
|
package ws
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// wsURL converts an httptest server URL to a WebSocket URL.
|
|
func wsURL(server *httptest.Server) string {
|
|
return "ws" + strings.TrimPrefix(server.URL, "http")
|
|
}
|
|
|
|
func TestNewHub(t *testing.T) {
|
|
t.Run("creates hub with initialized maps", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
require.NotNil(t, hub)
|
|
assert.NotNil(t, hub.clients)
|
|
assert.NotNil(t, hub.broadcast)
|
|
assert.NotNil(t, hub.register)
|
|
assert.NotNil(t, hub.unregister)
|
|
assert.NotNil(t, hub.channels)
|
|
})
|
|
}
|
|
|
|
func TestHub_Run(t *testing.T) {
|
|
t.Run("stops on context cancel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
hub.Run(ctx)
|
|
close(done)
|
|
}()
|
|
|
|
cancel()
|
|
|
|
select {
|
|
case <-done:
|
|
// Good - hub stopped
|
|
case <-time.After(time.Second):
|
|
t.Fatal("hub should have stopped on context cancel")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHub_Broadcast(t *testing.T) {
|
|
t.Run("marshals message with timestamp", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
msg := Message{
|
|
Type: TypeEvent,
|
|
Data: "test data",
|
|
}
|
|
|
|
err := hub.Broadcast(msg)
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("returns error when channel full", func(t *testing.T) {
|
|
hub := NewHub()
|
|
// Fill the broadcast channel
|
|
for range 256 {
|
|
hub.broadcast <- []byte("test")
|
|
}
|
|
|
|
err := hub.Broadcast(Message{Type: TypeEvent})
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "broadcast channel full")
|
|
})
|
|
}
|
|
|
|
func TestHub_Stats(t *testing.T) {
|
|
t.Run("returns empty stats for new hub", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
stats := hub.Stats()
|
|
|
|
assert.Equal(t, 0, stats.Clients)
|
|
assert.Equal(t, 0, stats.Channels)
|
|
})
|
|
|
|
t.Run("tracks client and channel counts", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
// Manually add clients for testing
|
|
hub.mu.Lock()
|
|
client1 := &Client{subscriptions: make(map[string]bool)}
|
|
client2 := &Client{subscriptions: make(map[string]bool)}
|
|
hub.clients[client1] = true
|
|
hub.clients[client2] = true
|
|
hub.channels["test-channel"] = make(map[*Client]bool)
|
|
hub.mu.Unlock()
|
|
|
|
stats := hub.Stats()
|
|
|
|
assert.Equal(t, 2, stats.Clients)
|
|
assert.Equal(t, 1, stats.Channels)
|
|
})
|
|
}
|
|
|
|
func TestHub_ClientCount(t *testing.T) {
|
|
t.Run("returns zero for empty hub", func(t *testing.T) {
|
|
hub := NewHub()
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
})
|
|
|
|
t.Run("counts connected clients", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[&Client{}] = true
|
|
hub.clients[&Client{}] = true
|
|
hub.mu.Unlock()
|
|
|
|
assert.Equal(t, 2, hub.ClientCount())
|
|
})
|
|
}
|
|
|
|
func TestHub_ChannelCount(t *testing.T) {
|
|
t.Run("returns zero for empty hub", func(t *testing.T) {
|
|
hub := NewHub()
|
|
assert.Equal(t, 0, hub.ChannelCount())
|
|
})
|
|
|
|
t.Run("counts active channels", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
hub.mu.Lock()
|
|
hub.channels["channel1"] = make(map[*Client]bool)
|
|
hub.channels["channel2"] = make(map[*Client]bool)
|
|
hub.mu.Unlock()
|
|
|
|
assert.Equal(t, 2, hub.ChannelCount())
|
|
})
|
|
}
|
|
|
|
func TestHub_ChannelSubscriberCount(t *testing.T) {
|
|
t.Run("returns zero for non-existent channel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
assert.Equal(t, 0, hub.ChannelSubscriberCount("non-existent"))
|
|
})
|
|
|
|
t.Run("counts subscribers in channel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
hub.mu.Lock()
|
|
hub.channels["test-channel"] = make(map[*Client]bool)
|
|
hub.channels["test-channel"][&Client{}] = true
|
|
hub.channels["test-channel"][&Client{}] = true
|
|
hub.mu.Unlock()
|
|
|
|
assert.Equal(t, 2, hub.ChannelSubscriberCount("test-channel"))
|
|
})
|
|
}
|
|
|
|
func TestHub_Subscribe(t *testing.T) {
|
|
t.Run("adds client to channel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[client] = true
|
|
hub.mu.Unlock()
|
|
|
|
hub.Subscribe(client, "test-channel")
|
|
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))
|
|
assert.True(t, client.subscriptions["test-channel"])
|
|
})
|
|
|
|
t.Run("creates channel if not exists", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.Subscribe(client, "new-channel")
|
|
|
|
hub.mu.RLock()
|
|
_, exists := hub.channels["new-channel"]
|
|
hub.mu.RUnlock()
|
|
|
|
assert.True(t, exists)
|
|
})
|
|
}
|
|
|
|
func TestHub_Unsubscribe(t *testing.T) {
|
|
t.Run("removes client from channel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.Subscribe(client, "test-channel")
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))
|
|
|
|
hub.Unsubscribe(client, "test-channel")
|
|
assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel"))
|
|
assert.False(t, client.subscriptions["test-channel"])
|
|
})
|
|
|
|
t.Run("cleans up empty channels", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.Subscribe(client, "temp-channel")
|
|
hub.Unsubscribe(client, "temp-channel")
|
|
|
|
hub.mu.RLock()
|
|
_, exists := hub.channels["temp-channel"]
|
|
hub.mu.RUnlock()
|
|
|
|
assert.False(t, exists, "empty channel should be removed")
|
|
})
|
|
|
|
t.Run("handles non-existent channel gracefully", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
// Should not panic
|
|
hub.Unsubscribe(client, "non-existent")
|
|
})
|
|
}
|
|
|
|
func TestHub_SendToChannel(t *testing.T) {
|
|
t.Run("sends to channel subscribers", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[client] = true
|
|
hub.mu.Unlock()
|
|
hub.Subscribe(client, "test-channel")
|
|
|
|
err := hub.SendToChannel("test-channel", Message{
|
|
Type: TypeEvent,
|
|
Data: "test",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case msg := <-client.send:
|
|
var received Message
|
|
err := json.Unmarshal(msg, &received)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypeEvent, received.Type)
|
|
assert.Equal(t, "test-channel", received.Channel)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected message on client send channel")
|
|
}
|
|
})
|
|
|
|
t.Run("returns nil for non-existent channel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
err := hub.SendToChannel("non-existent", Message{Type: TypeEvent})
|
|
assert.NoError(t, err, "should not error for non-existent channel")
|
|
})
|
|
}
|
|
|
|
func TestHub_SendProcessOutput(t *testing.T) {
|
|
t.Run("sends output to process channel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[client] = true
|
|
hub.mu.Unlock()
|
|
hub.Subscribe(client, "process:proc-1")
|
|
|
|
err := hub.SendProcessOutput("proc-1", "hello world")
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case msg := <-client.send:
|
|
var received Message
|
|
err := json.Unmarshal(msg, &received)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypeProcessOutput, received.Type)
|
|
assert.Equal(t, "proc-1", received.ProcessID)
|
|
assert.Equal(t, "hello world", received.Data)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected message on client send channel")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHub_SendProcessStatus(t *testing.T) {
|
|
t.Run("sends status to process channel", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[client] = true
|
|
hub.mu.Unlock()
|
|
hub.Subscribe(client, "process:proc-1")
|
|
|
|
err := hub.SendProcessStatus("proc-1", "exited", 0)
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case msg := <-client.send:
|
|
var received Message
|
|
err := json.Unmarshal(msg, &received)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypeProcessStatus, received.Type)
|
|
assert.Equal(t, "proc-1", received.ProcessID)
|
|
|
|
data, ok := received.Data.(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "exited", data["status"])
|
|
assert.Equal(t, float64(0), data["exitCode"])
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected message on client send channel")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHub_SendError(t *testing.T) {
|
|
t.Run("broadcasts error message", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- client
|
|
// Give time for registration
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
err := hub.SendError("something went wrong")
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case msg := <-client.send:
|
|
var received Message
|
|
err := json.Unmarshal(msg, &received)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypeError, received.Type)
|
|
assert.Equal(t, "something went wrong", received.Data)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected error message on client send channel")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHub_SendEvent(t *testing.T) {
|
|
t.Run("broadcasts event message", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- client
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
err := hub.SendEvent("user_joined", map[string]string{"user": "alice"})
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case msg := <-client.send:
|
|
var received Message
|
|
err := json.Unmarshal(msg, &received)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypeEvent, received.Type)
|
|
|
|
data, ok := received.Data.(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "user_joined", data["event"])
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected event message on client send channel")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestClient_Subscriptions(t *testing.T) {
|
|
t.Run("returns copy of subscriptions", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.Subscribe(client, "channel1")
|
|
hub.Subscribe(client, "channel2")
|
|
|
|
subs := client.Subscriptions()
|
|
|
|
assert.Len(t, subs, 2)
|
|
assert.Contains(t, subs, "channel1")
|
|
assert.Contains(t, subs, "channel2")
|
|
})
|
|
}
|
|
|
|
func TestClient_AllSubscriptions(t *testing.T) {
|
|
t.Run("returns iterator over subscriptions", func(t *testing.T) {
|
|
client := &Client{subscriptions: make(map[string]bool)}
|
|
client.subscriptions["sub1"] = true
|
|
client.subscriptions["sub2"] = true
|
|
|
|
subs := slices.Collect(client.AllSubscriptions())
|
|
assert.Len(t, subs, 2)
|
|
assert.Contains(t, subs, "sub1")
|
|
assert.Contains(t, subs, "sub2")
|
|
})
|
|
}
|
|
|
|
func TestHub_AllClients(t *testing.T) {
|
|
t.Run("returns iterator over all clients", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client1 := &Client{subscriptions: make(map[string]bool)}
|
|
client2 := &Client{subscriptions: make(map[string]bool)}
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[client1] = true
|
|
hub.clients[client2] = true
|
|
hub.mu.Unlock()
|
|
|
|
clients := slices.Collect(hub.AllClients())
|
|
assert.Len(t, clients, 2)
|
|
assert.Contains(t, clients, client1)
|
|
assert.Contains(t, clients, client2)
|
|
})
|
|
}
|
|
|
|
func TestHub_AllChannels(t *testing.T) {
|
|
t.Run("returns iterator over all active channels", func(t *testing.T) {
|
|
hub := NewHub()
|
|
hub.mu.Lock()
|
|
hub.channels["ch1"] = make(map[*Client]bool)
|
|
hub.channels["ch2"] = make(map[*Client]bool)
|
|
hub.mu.Unlock()
|
|
|
|
channels := slices.Collect(hub.AllChannels())
|
|
assert.Len(t, channels, 2)
|
|
assert.Contains(t, channels, "ch1")
|
|
assert.Contains(t, channels, "ch2")
|
|
})
|
|
}
|
|
|
|
func TestMessage_JSON(t *testing.T) {
|
|
t.Run("marshals correctly", func(t *testing.T) {
|
|
msg := Message{
|
|
Type: TypeProcessOutput,
|
|
Channel: "process:1",
|
|
ProcessID: "1",
|
|
Data: "output line",
|
|
Timestamp: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
}
|
|
|
|
data, err := json.Marshal(msg)
|
|
require.NoError(t, err)
|
|
|
|
assert.Contains(t, string(data), `"type":"process_output"`)
|
|
assert.Contains(t, string(data), `"channel":"process:1"`)
|
|
assert.Contains(t, string(data), `"processId":"1"`)
|
|
assert.Contains(t, string(data), `"data":"output line"`)
|
|
})
|
|
|
|
t.Run("unmarshals correctly", func(t *testing.T) {
|
|
jsonStr := `{"type":"subscribe","data":"channel:test"}`
|
|
|
|
var msg Message
|
|
err := json.Unmarshal([]byte(jsonStr), &msg)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, TypeSubscribe, msg.Type)
|
|
assert.Equal(t, "channel:test", msg.Data)
|
|
})
|
|
}
|
|
|
|
// TestHub_WebSocketHandler exercises the full HTTP→WebSocket path served by
// hub.Handler(): upgrade and registration, subscribe/unsubscribe handling,
// ping/pong, broadcast delivery, and cleanup on disconnect. The fixed sleeps
// give the hub's Run loop time to process register/unregister events — an
// inherent raciness of integration-testing an asynchronous hub.
func TestHub_WebSocketHandler(t *testing.T) {
	t.Run("upgrades connection and registers client", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		// NOTE(review): shadows the package-level wsURL helper with a local
		// string — harmless but the helper could be used instead.
		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")

		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Give time for registration
		time.Sleep(50 * time.Millisecond)

		assert.Equal(t, 1, hub.ClientCount())
	})

	t.Run("handles subscribe message", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")

		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Send subscribe message
		subscribeMsg := Message{
			Type: TypeSubscribe,
			Data: "test-channel",
		}
		err = conn.WriteJSON(subscribeMsg)
		require.NoError(t, err)

		// Give time for subscription
		time.Sleep(50 * time.Millisecond)

		assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))
	})

	t.Run("handles unsubscribe message", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")

		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Subscribe first
		err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))

		// Unsubscribe
		err = conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "test-channel"})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel"))
	})

	t.Run("responds to ping with pong", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")

		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Give time for registration
		time.Sleep(50 * time.Millisecond)

		// Send ping (application-level, not the WebSocket control frame)
		err = conn.WriteJSON(Message{Type: TypePing})
		require.NoError(t, err)

		// Read pong response
		var response Message
		conn.SetReadDeadline(time.Now().Add(time.Second))
		err = conn.ReadJSON(&response)
		require.NoError(t, err)

		assert.Equal(t, TypePong, response.Type)
	})

	t.Run("broadcasts messages to clients", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")

		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Give time for registration
		time.Sleep(50 * time.Millisecond)

		// Broadcast a message
		err = hub.Broadcast(Message{
			Type: TypeEvent,
			Data: "broadcast test",
		})
		require.NoError(t, err)

		// Read the broadcast
		var response Message
		conn.SetReadDeadline(time.Now().Add(time.Second))
		err = conn.ReadJSON(&response)
		require.NoError(t, err)

		assert.Equal(t, TypeEvent, response.Type)
		assert.Equal(t, "broadcast test", response.Data)
	})

	t.Run("unregisters client on connection close", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")

		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)

		// Wait for registration
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ClientCount())

		// Close connection
		conn.Close()

		// Wait for unregistration
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 0, hub.ClientCount())
	})

	t.Run("removes client from channels on disconnect", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")

		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)

		// Subscribe to channel
		err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))

		// Close connection
		conn.Close()
		time.Sleep(50 * time.Millisecond)

		// Channel should be cleaned up
		assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel"))
	})
}
|
|
|
|
func TestHub_Concurrency(t *testing.T) {
|
|
t.Run("handles concurrent subscriptions", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
var wg sync.WaitGroup
|
|
numClients := 100
|
|
|
|
for i := range numClients {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[client] = true
|
|
hub.mu.Unlock()
|
|
|
|
hub.Subscribe(client, "shared-channel")
|
|
hub.Subscribe(client, "shared-channel") // Double subscribe should be safe
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
assert.Equal(t, numClients, hub.ChannelSubscriberCount("shared-channel"))
|
|
})
|
|
|
|
t.Run("handles concurrent broadcasts", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 1000),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- client
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
var wg sync.WaitGroup
|
|
numBroadcasts := 100
|
|
|
|
for i := range numBroadcasts {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
_ = hub.Broadcast(Message{
|
|
Type: TypeEvent,
|
|
Data: id,
|
|
})
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Give time for broadcasts to be delivered
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Count received messages
|
|
received := 0
|
|
timeout := time.After(100 * time.Millisecond)
|
|
loop:
|
|
for {
|
|
select {
|
|
case <-client.send:
|
|
received++
|
|
case <-timeout:
|
|
break loop
|
|
}
|
|
}
|
|
|
|
// All or most broadcasts should be received
|
|
assert.GreaterOrEqual(t, received, numBroadcasts-10, "should receive most broadcasts")
|
|
})
|
|
}
|
|
|
|
func TestHub_HandleWebSocket(t *testing.T) {
|
|
t.Run("alias works same as Handler", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
// Test with HandleWebSocket directly
|
|
server := httptest.NewServer(http.HandlerFunc(hub.HandleWebSocket))
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 1, hub.ClientCount())
|
|
})
|
|
}
|
|
|
|
func TestMustMarshal(t *testing.T) {
|
|
t.Run("marshals valid data", func(t *testing.T) {
|
|
data := mustMarshal(Message{Type: TypePong})
|
|
assert.Contains(t, string(data), "pong")
|
|
})
|
|
|
|
t.Run("handles unmarshalable data without panic", func(t *testing.T) {
|
|
// Create a channel which cannot be marshaled
|
|
// This should not panic, even if it returns nil
|
|
ch := make(chan int)
|
|
assert.NotPanics(t, func() {
|
|
_ = mustMarshal(ch)
|
|
})
|
|
})
|
|
}
|
|
|
|
// --- Phase 0: Additional coverage tests ---
|
|
|
|
func TestHub_Run_ShutdownClosesClients(t *testing.T) {
|
|
t.Run("closes all client send channels on shutdown", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go hub.Run(ctx)
|
|
|
|
// Register clients via the hub's Run loop
|
|
client1 := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
client2 := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- client1
|
|
hub.register <- client2
|
|
time.Sleep(20 * time.Millisecond)
|
|
|
|
assert.Equal(t, 2, hub.ClientCount())
|
|
|
|
// Cancel context to trigger shutdown
|
|
cancel()
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send channels should be closed
|
|
_, ok1 := <-client1.send
|
|
assert.False(t, ok1, "client1 send channel should be closed")
|
|
_, ok2 := <-client2.send
|
|
assert.False(t, ok2, "client2 send channel should be closed")
|
|
})
|
|
}
|
|
|
|
func TestHub_Run_BroadcastToClientWithFullBuffer(t *testing.T) {
|
|
t.Run("unregisters client with full send buffer during broadcast", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
// Create a client with a tiny buffer that will overflow
|
|
slowClient := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 1), // Very small buffer
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- slowClient
|
|
time.Sleep(20 * time.Millisecond)
|
|
assert.Equal(t, 1, hub.ClientCount())
|
|
|
|
// Fill the client's send buffer
|
|
slowClient.send <- []byte("blocking")
|
|
|
|
// Broadcast should trigger the overflow path
|
|
err := hub.Broadcast(Message{Type: TypeEvent, Data: "overflow"})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the unregister goroutine to fire
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
assert.Equal(t, 0, hub.ClientCount(), "slow client should be unregistered")
|
|
})
|
|
}
|
|
|
|
func TestHub_SendToChannel_ClientBufferFull(t *testing.T) {
|
|
t.Run("skips client with full send buffer", func(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 1), // Tiny buffer
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.mu.Lock()
|
|
hub.clients[client] = true
|
|
hub.mu.Unlock()
|
|
hub.Subscribe(client, "test-channel")
|
|
|
|
// Fill the client buffer
|
|
client.send <- []byte("blocking")
|
|
|
|
// SendToChannel should not block; it skips the full client
|
|
err := hub.SendToChannel("test-channel", Message{Type: TypeEvent, Data: "overflow"})
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestHub_Broadcast_MarshalError(t *testing.T) {
|
|
t.Run("returns error for unmarshalable message", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
// Channels cannot be marshalled to JSON
|
|
msg := Message{
|
|
Type: TypeEvent,
|
|
Data: make(chan int),
|
|
}
|
|
|
|
err := hub.Broadcast(msg)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to marshal message")
|
|
})
|
|
}
|
|
|
|
func TestHub_SendToChannel_MarshalError(t *testing.T) {
|
|
t.Run("returns error for unmarshalable message", func(t *testing.T) {
|
|
hub := NewHub()
|
|
|
|
msg := Message{
|
|
Type: TypeEvent,
|
|
Data: make(chan int),
|
|
}
|
|
|
|
err := hub.SendToChannel("any-channel", msg)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to marshal message")
|
|
})
|
|
}
|
|
|
|
func TestHub_Handler_UpgradeError(t *testing.T) {
|
|
t.Run("returns silently on non-websocket request", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
// Make a plain HTTP request (not a WebSocket upgrade)
|
|
resp, err := http.Get(server.URL)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
// The handler should have returned an error response
|
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
})
|
|
}
|
|
|
|
// TestClient_Close verifies that Client.Close unregisters the client from
// the hub and tears down the underlying connection.
func TestClient_Close(t *testing.T) {
	t.Run("unregisters and closes connection", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)

		// Allow the hub's Run loop to register the client.
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ClientCount())

		// Get the client from the hub (any element; there is exactly one)
		hub.mu.RLock()
		var client *Client
		for c := range hub.clients {
			client = c
			break
		}
		hub.mu.RUnlock()
		require.NotNil(t, client)

		// Close via Client.Close()
		err = client.Close()
		// conn.Close may return an error if already closing, that is acceptable
		_ = err

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 0, hub.ClientCount())

		// Connection should be closed — writing should fail
		_ = conn.Close() // ensure clean up
	})
}
|
|
|
|
// TestReadPump_MalformedJSON verifies that a client sending invalid JSON is
// not disconnected and can continue issuing valid messages afterwards.
func TestReadPump_MalformedJSON(t *testing.T) {
	t.Run("ignores malformed JSON messages", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Allow the hub to register the client.
		time.Sleep(50 * time.Millisecond)

		// Send malformed JSON — should be ignored without disconnecting
		err = conn.WriteMessage(websocket.TextMessage, []byte("this is not json"))
		require.NoError(t, err)

		// Send a valid subscribe after the bad message — client should still be alive
		err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"})
		require.NoError(t, err)

		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))
	})
}
|
|
|
|
// TestReadPump_SubscribeWithNonStringData verifies that a subscribe whose
// data field is not a string is ignored rather than creating a channel.
func TestReadPump_SubscribeWithNonStringData(t *testing.T) {
	t.Run("ignores subscribe with non-string data", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Allow the hub to register the client.
		time.Sleep(50 * time.Millisecond)

		// Send subscribe with numeric data instead of string
		err = conn.WriteJSON(map[string]any{
			"type": "subscribe",
			"data": 12345,
		})
		require.NoError(t, err)

		time.Sleep(50 * time.Millisecond)

		// No channels should have been created
		assert.Equal(t, 0, hub.ChannelCount())
	})
}
|
|
|
|
// TestReadPump_UnsubscribeWithNonStringData verifies that an unsubscribe
// whose data field is not a string is ignored and leaves subscriptions
// untouched.
func TestReadPump_UnsubscribeWithNonStringData(t *testing.T) {
	t.Run("ignores unsubscribe with non-string data", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Allow the hub to register the client.
		time.Sleep(50 * time.Millisecond)

		// Subscribe first with valid data
		err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))

		// Send unsubscribe with non-string data — should be ignored
		err = conn.WriteJSON(map[string]any{
			"type": "unsubscribe",
			"data": []string{"test-channel"},
		})
		require.NoError(t, err)

		time.Sleep(50 * time.Millisecond)

		// Channel should still have the subscriber
		assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))
	})
}
|
|
|
|
// TestReadPump_UnknownMessageType verifies that messages with an
// unrecognized type are silently dropped without closing the connection.
func TestReadPump_UnknownMessageType(t *testing.T) {
	t.Run("ignores unknown message types without disconnecting", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		server := httptest.NewServer(hub.Handler())
		defer server.Close()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		require.NoError(t, err)
		defer conn.Close()

		// Allow the hub to register the client.
		time.Sleep(50 * time.Millisecond)

		// Send a message with an unknown type
		err = conn.WriteJSON(Message{Type: "unknown_type", Data: "anything"})
		require.NoError(t, err)

		// Client should still be connected
		time.Sleep(50 * time.Millisecond)
		assert.Equal(t, 1, hub.ClientCount())
	})
}
|
|
|
|
func TestWritePump_SendsCloseOnChannelClose(t *testing.T) {
|
|
t.Run("sends close message when send channel is closed", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Unregister the client (which closes its send channel)
|
|
hub.mu.RLock()
|
|
var client *Client
|
|
for c := range hub.clients {
|
|
client = c
|
|
break
|
|
}
|
|
hub.mu.RUnlock()
|
|
|
|
hub.unregister <- client
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// The client should receive a close message and the connection should end
|
|
conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
|
_, _, readErr := conn.ReadMessage()
|
|
assert.Error(t, readErr, "reading should fail after close")
|
|
})
|
|
}
|
|
|
|
func TestWritePump_BatchesMessages(t *testing.T) {
|
|
t.Run("batches multiple queued messages into a single write", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Get the client
|
|
hub.mu.RLock()
|
|
var client *Client
|
|
for c := range hub.clients {
|
|
client = c
|
|
break
|
|
}
|
|
hub.mu.RUnlock()
|
|
require.NotNil(t, client)
|
|
|
|
// Queue multiple messages rapidly on the send channel directly
|
|
msg1 := mustMarshal(Message{Type: TypeEvent, Data: "batch-1", Timestamp: time.Now()})
|
|
msg2 := mustMarshal(Message{Type: TypeEvent, Data: "batch-2", Timestamp: time.Now()})
|
|
msg3 := mustMarshal(Message{Type: TypeEvent, Data: "batch-3", Timestamp: time.Now()})
|
|
client.send <- msg1
|
|
client.send <- msg2
|
|
client.send <- msg3
|
|
|
|
// Read the batched response — it should contain all three messages
|
|
// separated by newlines
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, data, readErr := conn.ReadMessage()
|
|
require.NoError(t, readErr)
|
|
|
|
content := string(data)
|
|
assert.Contains(t, content, "batch-1")
|
|
// Depending on timing, messages may be batched or separate
|
|
// At minimum, the first message should be received
|
|
})
|
|
}
|
|
|
|
func TestHub_MultipleClientsOnChannel(t *testing.T) {
|
|
t.Run("delivers channel messages to all subscribers", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
// Connect three clients and subscribe them all to the same channel
|
|
conns := make([]*websocket.Conn, 3)
|
|
for i := range conns {
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
conns[i] = conn
|
|
}
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 3, hub.ClientCount())
|
|
|
|
// Subscribe all to "shared"
|
|
for _, conn := range conns {
|
|
err := conn.WriteJSON(Message{Type: TypeSubscribe, Data: "shared"})
|
|
require.NoError(t, err)
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 3, hub.ChannelSubscriberCount("shared"))
|
|
|
|
// Send to channel
|
|
err := hub.SendToChannel("shared", Message{Type: TypeEvent, Data: "hello all"})
|
|
require.NoError(t, err)
|
|
|
|
// All three clients should receive the message
|
|
for i, conn := range conns {
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var received Message
|
|
err := conn.ReadJSON(&received)
|
|
require.NoError(t, err, "client %d should receive message", i)
|
|
assert.Equal(t, TypeEvent, received.Type)
|
|
assert.Equal(t, "hello all", received.Data)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestHub_ConcurrentSubscribeUnsubscribe hammers Subscribe/Unsubscribe from
// many goroutines to surface data races in the hub's channel bookkeeping
// (run with -race for the full benefit).
func TestHub_ConcurrentSubscribeUnsubscribe(t *testing.T) {
	t.Run("handles concurrent subscribe and unsubscribe safely", func(t *testing.T) {
		hub := NewHub()
		ctx := t.Context()
		go hub.Run(ctx)

		var wg sync.WaitGroup
		numClients := 50

		// Register clients directly under the hub mutex, bypassing the
		// register channel, so setup is deterministic.
		clients := make([]*Client, numClients)
		for i := range numClients {
			clients[i] = &Client{
				hub: hub,
				send: make(chan []byte, 256),
				subscriptions: make(map[string]bool),
			}
			hub.mu.Lock()
			hub.clients[clients[i]] = true
			hub.mu.Unlock()
		}

		// Half subscribe, half unsubscribe concurrently
		for i := range numClients {
			wg.Add(1)
			go func(idx int) {
				defer wg.Done()
				hub.Subscribe(clients[idx], "race-channel")
			}(i)
		}
		wg.Wait()

		// Now concurrently unsubscribe half and subscribe the other half to a new channel
		for i := range numClients {
			wg.Add(1)
			go func(idx int) {
				defer wg.Done()
				if idx%2 == 0 {
					hub.Unsubscribe(clients[idx], "race-channel")
				} else {
					hub.Subscribe(clients[idx], "another-channel")
				}
			}(i)
		}
		wg.Wait()

		// Verify: half should remain on race-channel, half should be on another-channel
		assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("race-channel"))
		assert.Equal(t, numClients/2, hub.ChannelSubscriberCount("another-channel"))
	})
}
|
|
|
|
func TestHub_ProcessOutputEndToEnd(t *testing.T) {
|
|
t.Run("end-to-end process output via WebSocket", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Subscribe to process channel
|
|
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:build-42"})
|
|
require.NoError(t, err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send lines one at a time with a small delay to avoid batching
|
|
lines := []string{"Compiling...", "Linking...", "Done."}
|
|
for _, line := range lines {
|
|
err = hub.SendProcessOutput("build-42", line)
|
|
require.NoError(t, err)
|
|
time.Sleep(10 * time.Millisecond) // Allow writePump to flush each individually
|
|
}
|
|
|
|
// Read all three — messages may arrive individually or batched
|
|
// with newline separators. ReadMessage gives raw frames.
|
|
var received []Message
|
|
for len(received) < 3 {
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
_, data, readErr := conn.ReadMessage()
|
|
require.NoError(t, readErr)
|
|
|
|
// A single frame may contain multiple newline-separated JSON objects
|
|
parts := strings.SplitSeq(strings.TrimSpace(string(data)), "\n")
|
|
for part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if part == "" {
|
|
continue
|
|
}
|
|
var msg Message
|
|
err := json.Unmarshal([]byte(part), &msg)
|
|
require.NoError(t, err)
|
|
received = append(received, msg)
|
|
}
|
|
}
|
|
|
|
require.Len(t, received, 3)
|
|
for i, expected := range lines {
|
|
assert.Equal(t, TypeProcessOutput, received[i].Type)
|
|
assert.Equal(t, "build-42", received[i].ProcessID)
|
|
assert.Equal(t, expected, received[i].Data)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHub_ProcessStatusEndToEnd(t *testing.T) {
|
|
t.Run("end-to-end process status via WebSocket", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Subscribe
|
|
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "process:job-7"})
|
|
require.NoError(t, err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send status
|
|
err = hub.SendProcessStatus("job-7", "exited", 1)
|
|
require.NoError(t, err)
|
|
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var received Message
|
|
err = conn.ReadJSON(&received)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypeProcessStatus, received.Type)
|
|
assert.Equal(t, "job-7", received.ProcessID)
|
|
|
|
data, ok := received.Data.(map[string]any)
|
|
require.True(t, ok)
|
|
assert.Equal(t, "exited", data["status"])
|
|
assert.Equal(t, float64(1), data["exitCode"])
|
|
})
|
|
}
|
|
|
|
// --- Benchmarks ---
|
|
|
|
func BenchmarkBroadcast(b *testing.B) {
|
|
hub := NewHub()
|
|
ctx := b.Context()
|
|
go hub.Run(ctx)
|
|
|
|
// Register 100 clients
|
|
for range 100 {
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
hub.register <- client
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
msg := Message{Type: TypeEvent, Data: "benchmark"}
|
|
|
|
b.ResetTimer()
|
|
for range b.N {
|
|
_ = hub.Broadcast(msg)
|
|
}
|
|
}
|
|
|
|
func BenchmarkSendToChannel(b *testing.B) {
|
|
hub := NewHub()
|
|
ctx := b.Context()
|
|
go hub.Run(ctx)
|
|
|
|
// Register 50 clients, all subscribed to the same channel
|
|
for range 50 {
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
hub.mu.Lock()
|
|
hub.clients[client] = true
|
|
hub.mu.Unlock()
|
|
hub.Subscribe(client, "bench-channel")
|
|
}
|
|
|
|
msg := Message{Type: TypeEvent, Data: "benchmark"}
|
|
|
|
b.ResetTimer()
|
|
for range b.N {
|
|
_ = hub.SendToChannel("bench-channel", msg)
|
|
}
|
|
}
|
|
|
|
// --- Phase 1: Connection resilience tests ---
|
|
|
|
func TestNewHubWithConfig(t *testing.T) {
|
|
t.Run("uses custom heartbeat and pong timeout", func(t *testing.T) {
|
|
config := HubConfig{
|
|
HeartbeatInterval: 5 * time.Second,
|
|
PongTimeout: 10 * time.Second,
|
|
WriteTimeout: 3 * time.Second,
|
|
}
|
|
hub := NewHubWithConfig(config)
|
|
|
|
assert.Equal(t, 5*time.Second, hub.config.HeartbeatInterval)
|
|
assert.Equal(t, 10*time.Second, hub.config.PongTimeout)
|
|
assert.Equal(t, 3*time.Second, hub.config.WriteTimeout)
|
|
})
|
|
|
|
t.Run("applies defaults for zero values", func(t *testing.T) {
|
|
hub := NewHubWithConfig(HubConfig{})
|
|
|
|
assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval)
|
|
assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout)
|
|
assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout)
|
|
})
|
|
|
|
t.Run("applies defaults for negative values", func(t *testing.T) {
|
|
hub := NewHubWithConfig(HubConfig{
|
|
HeartbeatInterval: -1,
|
|
PongTimeout: -1,
|
|
WriteTimeout: -1,
|
|
})
|
|
|
|
assert.Equal(t, DefaultHeartbeatInterval, hub.config.HeartbeatInterval)
|
|
assert.Equal(t, DefaultPongTimeout, hub.config.PongTimeout)
|
|
assert.Equal(t, DefaultWriteTimeout, hub.config.WriteTimeout)
|
|
})
|
|
}
|
|
|
|
func TestDefaultHubConfig(t *testing.T) {
|
|
t.Run("returns sensible defaults", func(t *testing.T) {
|
|
config := DefaultHubConfig()
|
|
|
|
assert.Equal(t, 30*time.Second, config.HeartbeatInterval)
|
|
assert.Equal(t, 60*time.Second, config.PongTimeout)
|
|
assert.Equal(t, 10*time.Second, config.WriteTimeout)
|
|
assert.Nil(t, config.OnConnect)
|
|
assert.Nil(t, config.OnDisconnect)
|
|
})
|
|
}
|
|
|
|
func TestHub_ConnectionCallbacks(t *testing.T) {
|
|
t.Run("calls OnConnect when client connects", func(t *testing.T) {
|
|
connectCalled := make(chan *Client, 1)
|
|
|
|
hub := NewHubWithConfig(HubConfig{
|
|
OnConnect: func(client *Client) {
|
|
connectCalled <- client
|
|
},
|
|
})
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
select {
|
|
case c := <-connectCalled:
|
|
assert.NotNil(t, c)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("OnConnect callback should have been called")
|
|
}
|
|
})
|
|
|
|
t.Run("calls OnDisconnect when client disconnects", func(t *testing.T) {
|
|
disconnectCalled := make(chan *Client, 1)
|
|
|
|
hub := NewHubWithConfig(HubConfig{
|
|
OnDisconnect: func(client *Client) {
|
|
disconnectCalled <- client
|
|
},
|
|
})
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Close the connection to trigger disconnect
|
|
conn.Close()
|
|
|
|
select {
|
|
case c := <-disconnectCalled:
|
|
assert.NotNil(t, c)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("OnDisconnect callback should have been called")
|
|
}
|
|
})
|
|
|
|
t.Run("does not call OnDisconnect for unknown client", func(t *testing.T) {
|
|
disconnectCalled := make(chan struct{}, 1)
|
|
|
|
hub := NewHubWithConfig(HubConfig{
|
|
OnDisconnect: func(client *Client) {
|
|
disconnectCalled <- struct{}{}
|
|
},
|
|
})
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
// Send an unregister for a client that was never registered
|
|
unknownClient := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 1),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
hub.unregister <- unknownClient
|
|
|
|
select {
|
|
case <-disconnectCalled:
|
|
t.Fatal("OnDisconnect should not be called for unknown client")
|
|
case <-time.After(100 * time.Millisecond):
|
|
// Good — callback was not called
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHub_CustomHeartbeat(t *testing.T) {
|
|
t.Run("uses custom heartbeat interval for server pings", func(t *testing.T) {
|
|
// Use a very short heartbeat to test it actually fires
|
|
hub := NewHubWithConfig(HubConfig{
|
|
HeartbeatInterval: 100 * time.Millisecond,
|
|
PongTimeout: 500 * time.Millisecond,
|
|
WriteTimeout: 500 * time.Millisecond,
|
|
})
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
pingReceived := make(chan struct{}, 1)
|
|
dialer := websocket.Dialer{}
|
|
conn, _, err := dialer.Dial(wsURL, nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
conn.SetPingHandler(func(appData string) error {
|
|
select {
|
|
case pingReceived <- struct{}{}:
|
|
default:
|
|
}
|
|
// Respond with pong
|
|
return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second))
|
|
})
|
|
|
|
// Start reading in background to process control frames
|
|
go func() {
|
|
for {
|
|
_, _, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-pingReceived:
|
|
// Server ping received within custom heartbeat interval
|
|
case <-time.After(time.Second):
|
|
t.Fatal("should have received server ping within custom heartbeat interval")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestReconnectingClient_Connect(t *testing.T) {
|
|
t.Run("connects to server and receives messages", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
connectCalled := make(chan struct{}, 1)
|
|
msgReceived := make(chan Message, 1)
|
|
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: wsURL,
|
|
OnConnect: func() {
|
|
select {
|
|
case connectCalled <- struct{}{}:
|
|
default:
|
|
}
|
|
},
|
|
OnMessage: func(msg Message) {
|
|
select {
|
|
case msgReceived <- msg:
|
|
default:
|
|
}
|
|
},
|
|
})
|
|
|
|
// Run Connect in background
|
|
clientCtx, clientCancel := context.WithCancel(context.Background())
|
|
defer clientCancel()
|
|
go rc.Connect(clientCtx)
|
|
|
|
// Wait for connect
|
|
select {
|
|
case <-connectCalled:
|
|
// Good
|
|
case <-time.After(time.Second):
|
|
t.Fatal("OnConnect should have been called")
|
|
}
|
|
|
|
assert.Equal(t, StateConnected, rc.State())
|
|
|
|
// Wait for client to register
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Broadcast a message
|
|
err := hub.Broadcast(Message{Type: TypeEvent, Data: "hello"})
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case msg := <-msgReceived:
|
|
assert.Equal(t, TypeEvent, msg.Type)
|
|
assert.Equal(t, "hello", msg.Data)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("should have received the broadcast message")
|
|
}
|
|
|
|
clientCancel()
|
|
time.Sleep(50 * time.Millisecond)
|
|
})
|
|
}
|
|
|
|
// TestReconnectingClient_Reconnect simulates a full server restart: the
// client connects, the server is torn down, a replacement is started on the
// same TCP address, and the client is expected to reconnect on its own.
// A manual net.Listener is used (rather than httptest.NewServer) so the
// replacement server can bind the same port.
func TestReconnectingClient_Reconnect(t *testing.T) {
	t.Run("reconnects after server restart", func(t *testing.T) {
		hub := NewHub()
		ctx, cancel := context.WithCancel(context.Background())
		go hub.Run(ctx)

		// Use a net.Listener so we control the port
		listener, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)

		server := &httptest.Server{
			Listener: listener,
			Config: &http.Server{Handler: hub.Handler()},
		}
		server.Start()

		wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
		// Remember the address so the replacement server can reuse it.
		addr := listener.Addr().String()

		// Buffered so callbacks never block the client's internal goroutines.
		reconnectCalled := make(chan int, 5)
		disconnectCalled := make(chan struct{}, 5)
		connectCalled := make(chan struct{}, 5)

		// Short backoff keeps the reconnect attempt well inside the test's
		// 3-second deadline below.
		rc := NewReconnectingClient(ReconnectConfig{
			URL: wsURL,
			InitialBackoff: 50 * time.Millisecond,
			MaxBackoff: 200 * time.Millisecond,
			OnConnect: func() {
				select {
				case connectCalled <- struct{}{}:
				default:
				}
			},
			OnDisconnect: func() {
				select {
				case disconnectCalled <- struct{}{}:
				default:
				}
			},
			OnReconnect: func(attempt int) {
				select {
				case reconnectCalled <- attempt:
				default:
				}
			},
		})

		clientCtx, clientCancel := context.WithCancel(context.Background())
		defer clientCancel()
		go rc.Connect(clientCtx)

		// Wait for initial connection
		select {
		case <-connectCalled:
		case <-time.After(time.Second):
			t.Fatal("initial connection should have succeeded")
		}
		assert.Equal(t, StateConnected, rc.State())

		// Shut down the server to simulate disconnect
		cancel()
		server.Close()

		// Wait for disconnect callback
		select {
		case <-disconnectCalled:
			// Good
		case <-time.After(time.Second):
			t.Fatal("OnDisconnect should have been called")
		}

		// Start a new server on the same address for reconnection
		hub2 := NewHub()
		ctx2 := t.Context()
		go hub2.Run(ctx2)

		// Rebinding the freed port can race with the OS; skip rather than
		// fail when the address is still in TIME_WAIT.
		listener2, err := net.Listen("tcp", addr)
		if err != nil {
			t.Skipf("could not bind to same port: %v", err)
			return
		}

		newServer := &httptest.Server{
			Listener: listener2,
			Config: &http.Server{Handler: hub2.Handler()},
		}
		newServer.Start()
		defer newServer.Close()

		// Wait for reconnection
		select {
		case attempt := <-reconnectCalled:
			assert.Greater(t, attempt, 0)
		case <-time.After(3 * time.Second):
			t.Fatal("OnReconnect should have been called")
		}

		assert.Equal(t, StateConnected, rc.State())
		clientCancel()
	})
}
|
|
|
|
func TestReconnectingClient_MaxRetries(t *testing.T) {
|
|
t.Run("stops after max retries exceeded", func(t *testing.T) {
|
|
// Use a URL that will never connect
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: "ws://127.0.0.1:1", // Should refuse connection
|
|
InitialBackoff: 10 * time.Millisecond,
|
|
MaxBackoff: 50 * time.Millisecond,
|
|
MaxRetries: 3,
|
|
})
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- rc.Connect(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "max retries (3) exceeded")
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("should have stopped after max retries")
|
|
}
|
|
|
|
assert.Equal(t, StateDisconnected, rc.State())
|
|
})
|
|
}
|
|
|
|
func TestReconnectingClient_Send(t *testing.T) {
|
|
t.Run("sends message to server", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
connected := make(chan struct{}, 1)
|
|
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: wsURL,
|
|
OnConnect: func() {
|
|
select {
|
|
case connected <- struct{}{}:
|
|
default:
|
|
}
|
|
},
|
|
})
|
|
|
|
clientCtx, clientCancel := context.WithCancel(context.Background())
|
|
defer clientCancel()
|
|
go rc.Connect(clientCtx)
|
|
|
|
<-connected
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send a subscribe message
|
|
err := rc.Send(Message{
|
|
Type: TypeSubscribe,
|
|
Data: "test-channel",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel"))
|
|
|
|
clientCancel()
|
|
})
|
|
|
|
t.Run("returns error when not connected", func(t *testing.T) {
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: "ws://localhost:1",
|
|
})
|
|
|
|
err := rc.Send(Message{Type: TypePing})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "not connected")
|
|
})
|
|
|
|
t.Run("returns error for unmarshalable message", func(t *testing.T) {
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: "ws://localhost:1",
|
|
})
|
|
// Force a conn to be set so we get past the nil check
|
|
// to hit the marshal error first
|
|
err := rc.Send(Message{Type: TypeEvent, Data: make(chan int)})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to marshal message")
|
|
})
|
|
}
|
|
|
|
func TestReconnectingClient_Close(t *testing.T) {
|
|
t.Run("stops reconnection loop", func(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
|
|
|
connected := make(chan struct{}, 1)
|
|
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: wsURL,
|
|
OnConnect: func() {
|
|
select {
|
|
case connected <- struct{}{}:
|
|
default:
|
|
}
|
|
},
|
|
})
|
|
|
|
clientCtx := t.Context()
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- rc.Connect(clientCtx)
|
|
}()
|
|
|
|
<-connected
|
|
|
|
err := rc.Close()
|
|
assert.NoError(t, err)
|
|
|
|
select {
|
|
case <-done:
|
|
// Good — Connect returned
|
|
case <-time.After(time.Second):
|
|
t.Fatal("Connect should have returned after Close")
|
|
}
|
|
})
|
|
|
|
t.Run("close when not connected is safe", func(t *testing.T) {
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: "ws://localhost:1",
|
|
})
|
|
|
|
err := rc.Close()
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestReconnectingClient_ExponentialBackoff(t *testing.T) {
|
|
t.Run("calculates correct backoff values", func(t *testing.T) {
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: "ws://localhost:1",
|
|
InitialBackoff: 100 * time.Millisecond,
|
|
MaxBackoff: 1 * time.Second,
|
|
BackoffMultiplier: 2.0,
|
|
})
|
|
|
|
// attempt 1: 100ms
|
|
assert.Equal(t, 100*time.Millisecond, rc.calculateBackoff(1))
|
|
// attempt 2: 200ms
|
|
assert.Equal(t, 200*time.Millisecond, rc.calculateBackoff(2))
|
|
// attempt 3: 400ms
|
|
assert.Equal(t, 400*time.Millisecond, rc.calculateBackoff(3))
|
|
// attempt 4: 800ms
|
|
assert.Equal(t, 800*time.Millisecond, rc.calculateBackoff(4))
|
|
// attempt 5: capped at 1s
|
|
assert.Equal(t, 1*time.Second, rc.calculateBackoff(5))
|
|
// attempt 10: still capped at 1s
|
|
assert.Equal(t, 1*time.Second, rc.calculateBackoff(10))
|
|
})
|
|
}
|
|
|
|
func TestReconnectingClient_Defaults(t *testing.T) {
|
|
t.Run("applies defaults for zero config values", func(t *testing.T) {
|
|
rc := NewReconnectingClient(ReconnectConfig{
|
|
URL: "ws://localhost:1",
|
|
})
|
|
|
|
assert.Equal(t, 1*time.Second, rc.config.InitialBackoff)
|
|
assert.Equal(t, 30*time.Second, rc.config.MaxBackoff)
|
|
assert.Equal(t, 2.0, rc.config.BackoffMultiplier)
|
|
assert.NotNil(t, rc.config.Dialer)
|
|
})
|
|
}
|
|
|
|
// TestReconnectingClient_ContextCancel ensures Connect honors context
// cancellation while it is sleeping out a (deliberately long) backoff.
func TestReconnectingClient_ContextCancel(t *testing.T) {
	t.Run("returns context error on cancel during backoff", func(t *testing.T) {
		rc := NewReconnectingClient(ReconnectConfig{
			URL: "ws://127.0.0.1:1",
			InitialBackoff: 10 * time.Second, // Long backoff
		})

		ctx, cancel := context.WithCancel(context.Background())

		done := make(chan error, 1)
		go func() {
			done <- rc.Connect(ctx)
		}()

		// Allow first dial attempt to fail
		time.Sleep(200 * time.Millisecond)

		// Cancel during backoff
		cancel()

		select {
		case err := <-done:
			require.Error(t, err)
			// The context's own error must surface unchanged.
			assert.Equal(t, context.Canceled, err)
		case <-time.After(2 * time.Second):
			t.Fatal("Connect should have returned after context cancel")
		}
	})
}
|
|
|
|
func TestConnectionState(t *testing.T) {
|
|
t.Run("state constants are distinct", func(t *testing.T) {
|
|
assert.NotEqual(t, StateDisconnected, StateConnecting)
|
|
assert.NotEqual(t, StateConnecting, StateConnected)
|
|
assert.NotEqual(t, StateDisconnected, StateConnected)
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Hub.Run lifecycle — register, broadcast delivery, unregister via channels
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestHubRun_RegisterClient_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- client
|
|
time.Sleep(20 * time.Millisecond)
|
|
|
|
assert.Equal(t, 1, hub.ClientCount(), "client should be registered via hub loop")
|
|
}
|
|
|
|
func TestHubRun_BroadcastDelivery_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- client
|
|
time.Sleep(20 * time.Millisecond)
|
|
|
|
err := hub.Broadcast(Message{Type: TypeEvent, Data: "lifecycle-test"})
|
|
require.NoError(t, err)
|
|
|
|
// Hub.Run loop delivers the broadcast to the client's send channel
|
|
select {
|
|
case msg := <-client.send:
|
|
var received Message
|
|
require.NoError(t, json.Unmarshal(msg, &received))
|
|
assert.Equal(t, TypeEvent, received.Type)
|
|
assert.Equal(t, "lifecycle-test", received.Data)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("broadcast should be delivered via hub loop")
|
|
}
|
|
}
|
|
|
|
func TestHubRun_UnregisterClient_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.register <- client
|
|
time.Sleep(20 * time.Millisecond)
|
|
assert.Equal(t, 1, hub.ClientCount())
|
|
|
|
// Subscribe so we can verify channel cleanup
|
|
hub.Subscribe(client, "lifecycle-chan")
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("lifecycle-chan"))
|
|
|
|
hub.unregister <- client
|
|
time.Sleep(20 * time.Millisecond)
|
|
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
assert.Equal(t, 0, hub.ChannelSubscriberCount("lifecycle-chan"))
|
|
}
|
|
|
|
// TestHubRun_UnregisterIgnoresDuplicate_Bad sends the same client to the
// unregister channel twice; the second send must neither panic (e.g. from a
// double close of the send channel) nor block the sender.
func TestHubRun_UnregisterIgnoresDuplicate_Bad(t *testing.T) {
	hub := NewHub()
	ctx := t.Context()
	go hub.Run(ctx)

	client := &Client{
		hub: hub,
		send: make(chan []byte, 256),
		subscriptions: make(map[string]bool),
	}

	hub.register <- client
	time.Sleep(20 * time.Millisecond)

	hub.unregister <- client
	time.Sleep(20 * time.Millisecond)

	// Second unregister should not panic or block
	done := make(chan struct{})
	go func() {
		hub.unregister <- client
		close(done)
	}()

	select {
	case <-done:
		// Good -- no panic, no block
	case <-time.After(time.Second):
		t.Fatal("duplicate unregister should not block")
	}
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Subscribe / Unsubscribe — additional channel management tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestSubscribe_MultipleChannels_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.Subscribe(client, "alpha")
|
|
hub.Subscribe(client, "beta")
|
|
hub.Subscribe(client, "gamma")
|
|
|
|
assert.Equal(t, 3, hub.ChannelCount())
|
|
subs := client.Subscriptions()
|
|
assert.Len(t, subs, 3)
|
|
assert.Contains(t, subs, "alpha")
|
|
assert.Contains(t, subs, "beta")
|
|
assert.Contains(t, subs, "gamma")
|
|
}
|
|
|
|
func TestSubscribe_IdempotentDoubleSubscribe_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
|
|
hub.Subscribe(client, "dupl")
|
|
hub.Subscribe(client, "dupl")
|
|
|
|
// Still only one subscriber entry in the channel map
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("dupl"))
|
|
}
|
|
|
|
func TestUnsubscribe_PartialLeave_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
client1 := &Client{hub: hub, send: make(chan []byte, 256), subscriptions: make(map[string]bool)}
|
|
client2 := &Client{hub: hub, send: make(chan []byte, 256), subscriptions: make(map[string]bool)}
|
|
|
|
hub.Subscribe(client1, "shared")
|
|
hub.Subscribe(client2, "shared")
|
|
assert.Equal(t, 2, hub.ChannelSubscriberCount("shared"))
|
|
|
|
hub.Unsubscribe(client1, "shared")
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("shared"))
|
|
|
|
// Channel still exists because client2 is subscribed
|
|
hub.mu.RLock()
|
|
_, exists := hub.channels["shared"]
|
|
hub.mu.RUnlock()
|
|
assert.True(t, exists, "channel should persist while subscribers remain")
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// SendToChannel — multiple subscribers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestSendToChannel_MultipleSubscribers_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
clients := make([]*Client, 5)
|
|
for i := range clients {
|
|
clients[i] = &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
hub.Subscribe(clients[i], "multi")
|
|
}
|
|
|
|
err := hub.SendToChannel("multi", Message{Type: TypeEvent, Data: "fanout"})
|
|
require.NoError(t, err)
|
|
|
|
for i, c := range clients {
|
|
select {
|
|
case msg := <-c.send:
|
|
var received Message
|
|
require.NoError(t, json.Unmarshal(msg, &received))
|
|
assert.Equal(t, "multi", received.Channel)
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("client %d should have received the message", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// SendProcessOutput / SendProcessStatus — edge cases
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestSendProcessOutput_NoSubscribers_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
err := hub.SendProcessOutput("orphan-proc", "some output")
|
|
assert.NoError(t, err, "sending to a process with no subscribers should not error")
|
|
}
|
|
|
|
func TestSendProcessStatus_NonZeroExit_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
hub.Subscribe(client, "process:fail-1")
|
|
|
|
err := hub.SendProcessStatus("fail-1", "exited", 137)
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case msg := <-client.send:
|
|
var received Message
|
|
require.NoError(t, json.Unmarshal(msg, &received))
|
|
assert.Equal(t, TypeProcessStatus, received.Type)
|
|
assert.Equal(t, "fail-1", received.ProcessID)
|
|
data := received.Data.(map[string]any)
|
|
assert.Equal(t, "exited", data["status"])
|
|
assert.Equal(t, float64(137), data["exitCode"])
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected process status message")
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// readPump — ping with timestamp verification
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestReadPump_PingTimestamp_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
err = conn.WriteJSON(Message{Type: TypePing})
|
|
require.NoError(t, err)
|
|
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var pong Message
|
|
err = conn.ReadJSON(&pong)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypePong, pong.Type)
|
|
assert.False(t, pong.Timestamp.IsZero(), "pong should include a timestamp")
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// writePump — batch sending with multiple messages
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestWritePump_BatchMultipleMessages_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Rapidly send multiple broadcasts so they queue up
|
|
numMessages := 10
|
|
for i := range numMessages {
|
|
err := hub.Broadcast(Message{
|
|
Type: TypeEvent,
|
|
Data: fmt.Sprintf("batch-%d", i),
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Read all messages — batched with newline separators
|
|
received := 0
|
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
for received < numMessages {
|
|
_, raw, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
parts := strings.SplitSeq(string(raw), "\n")
|
|
for part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if part == "" {
|
|
continue
|
|
}
|
|
var msg Message
|
|
if json.Unmarshal([]byte(part), &msg) == nil {
|
|
received++
|
|
}
|
|
}
|
|
}
|
|
|
|
assert.Equal(t, numMessages, received, "all batched messages should be received")
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Integration — unsubscribe stops delivery
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestIntegration_UnsubscribeStopsDelivery_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Subscribe
|
|
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "temp:feed"})
|
|
require.NoError(t, err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Verify we receive messages on the channel
|
|
err = hub.SendToChannel("temp:feed", Message{Type: TypeEvent, Data: "before-unsub"})
|
|
require.NoError(t, err)
|
|
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var msg1 Message
|
|
err = conn.ReadJSON(&msg1)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "before-unsub", msg1.Data)
|
|
|
|
// Unsubscribe
|
|
err = conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "temp:feed"})
|
|
require.NoError(t, err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send another message -- client should NOT receive it
|
|
err = hub.SendToChannel("temp:feed", Message{Type: TypeEvent, Data: "after-unsub"})
|
|
require.NoError(t, err)
|
|
|
|
// Try to read -- should timeout (no message delivered)
|
|
conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
|
var msg2 Message
|
|
err = conn.ReadJSON(&msg2)
|
|
assert.Error(t, err, "should not receive messages after unsubscribing")
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Integration — broadcast reaches all clients (no channel subscription)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestIntegration_BroadcastReachesAllClients_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
numClients := 3
|
|
conns := make([]*websocket.Conn, numClients)
|
|
for i := range numClients {
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
conns[i] = conn
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
assert.Equal(t, numClients, hub.ClientCount())
|
|
|
|
// Broadcast -- no channel subscription needed
|
|
err := hub.Broadcast(Message{Type: TypeError, Data: "global-alert"})
|
|
require.NoError(t, err)
|
|
|
|
for i, conn := range conns {
|
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
var received Message
|
|
err := conn.ReadJSON(&received)
|
|
require.NoError(t, err, "client %d should receive broadcast", i)
|
|
assert.Equal(t, TypeError, received.Type)
|
|
assert.Equal(t, "global-alert", received.Data)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Integration — disconnect cleans up all subscriptions
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestIntegration_DisconnectCleansUpEverything_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
defer server.Close()
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server), nil)
|
|
require.NoError(t, err)
|
|
|
|
// Subscribe to multiple channels
|
|
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "ch-a"})
|
|
require.NoError(t, err)
|
|
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "ch-b"})
|
|
require.NoError(t, err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
assert.Equal(t, 1, hub.ClientCount())
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("ch-a"))
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("ch-b"))
|
|
|
|
// Disconnect
|
|
conn.Close()
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
assert.Equal(t, 0, hub.ChannelSubscriberCount("ch-a"))
|
|
assert.Equal(t, 0, hub.ChannelSubscriberCount("ch-b"))
|
|
assert.Equal(t, 0, hub.ChannelCount(), "empty channels should be cleaned up")
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Concurrent broadcast + subscribe via hub loop (race test)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestConcurrentSubscribeAndBroadcast_Good(t *testing.T) {
|
|
hub := NewHub()
|
|
ctx := t.Context()
|
|
go hub.Run(ctx)
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
for i := range 50 {
|
|
wg.Add(2)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
client := &Client{
|
|
hub: hub,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
hub.register <- client
|
|
}(i)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
_ = hub.Broadcast(Message{Type: TypeEvent, Data: id})
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
assert.Equal(t, 50, hub.ClientCount())
|
|
}
|