commit e9425004e2c2378acedf579faa943cd9dd31dd1e Author: Snider Date: Thu Feb 19 16:08:52 2026 +0000 feat: extract go-ws from core/go pkg/ws WebSocket hub with channel-based pub/sub. Zero internal dependencies. Module: forge.lthn.ai/core/go-ws Co-Authored-By: Virgil diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..fa72d28 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,25 @@ +# CLAUDE.md + +## What This Is + +WebSocket hub with channel-based pub/sub for real-time streaming. Module: `forge.lthn.ai/core/go-ws` + +## Commands + +```bash +go test ./... # Run all tests +go test -v -run Name # Run single test +``` + +## Architecture + +- `Hub` manages WebSocket connections and channel subscriptions +- Messages types: `process_output`, `process_status`, `event`, `error`, `ping/pong`, `subscribe/unsubscribe` +- `hub.SendProcessOutput(id, line)` broadcasts to subscribers + +## Coding Standards + +- UK English +- `go test ./...` must pass before commit +- Conventional commits: `type(scope): description` +- Co-Author: `Co-Authored-By: Virgil ` diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6889543 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module forge.lthn.ai/core/go-ws + +go 1.25.5 + +require ( + github.com/gorilla/websocket v1.5.3 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4b33f39 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ws.go b/ws.go new file mode 100644 index 0000000..16dd6f7 --- /dev/null +++ b/ws.go @@ -0,0 +1,465 @@ +// Package ws provides WebSocket support for real-time streaming. +// +// The ws package enables live process output, events, and bidirectional communication +// between the Go backend and web frontends. It implements a hub pattern for managing +// WebSocket connections and channel-based subscriptions. +// +// # Getting Started +// +// hub := ws.NewHub() +// go hub.Run(ctx) +// +// // Register HTTP handler +// http.HandleFunc("/ws", hub.Handler()) +// +// # Message Types +// +// The package defines several message types for different purposes: +// - TypeProcessOutput: Real-time process output streaming +// - TypeProcessStatus: Process status updates (running, exited, etc.) +// - TypeEvent: Generic events +// - TypeError: Error messages +// - TypePing/TypePong: Keep-alive messages +// - TypeSubscribe/TypeUnsubscribe: Channel subscription management +// +// # Channel Subscriptions +// +// Clients can subscribe to specific channels to receive targeted messages: +// +// // Client sends: {"type": "subscribe", "data": "process:proc-1"} +// // Server broadcasts only to subscribers of "process:proc-1" +// +// # Integration with Core +// +// The Hub can receive process events via Core.ACTION and forward them to WebSocket clients: +// +// core.RegisterAction(func(c *framework.Core, msg framework.Message) error { +// switch m := msg.(type) { +// case process.ActionProcessOutput: +// hub.SendProcessOutput(m.ID, m.Line) +// case process.ActionProcessExited: +// hub.SendProcessStatus(m.ID, "exited", m.ExitCode) +// } +// return nil +// }) +package ws + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins for local development + }, +} + +// MessageType identifies the type of WebSocket message. +type MessageType string + +const ( + // TypeProcessOutput indicates real-time process output. + TypeProcessOutput MessageType = "process_output" + // TypeProcessStatus indicates a process status change. + TypeProcessStatus MessageType = "process_status" + // TypeEvent indicates a generic event. + TypeEvent MessageType = "event" + // TypeError indicates an error message. + TypeError MessageType = "error" + // TypePing is a client-to-server keep-alive request. + TypePing MessageType = "ping" + // TypePong is the server response to ping. + TypePong MessageType = "pong" + // TypeSubscribe requests subscription to a channel. + TypeSubscribe MessageType = "subscribe" + // TypeUnsubscribe requests unsubscription from a channel. + TypeUnsubscribe MessageType = "unsubscribe" +) + +// Message is the standard WebSocket message format. +type Message struct { + Type MessageType `json:"type"` + Channel string `json:"channel,omitempty"` + ProcessID string `json:"processId,omitempty"` + Data any `json:"data,omitempty"` + Timestamp time.Time `json:"timestamp"` +} + +// Client represents a connected WebSocket client. +type Client struct { + hub *Hub + conn *websocket.Conn + send chan []byte + subscriptions map[string]bool + mu sync.RWMutex +} + +// Hub manages WebSocket connections and message broadcasting. +type Hub struct { + clients map[*Client]bool + broadcast chan []byte + register chan *Client + unregister chan *Client + channels map[string]map[*Client]bool + mu sync.RWMutex +} + +// NewHub creates a new WebSocket hub. +func NewHub() *Hub { + return &Hub{ + clients: make(map[*Client]bool), + broadcast: make(chan []byte, 256), + register: make(chan *Client), + unregister: make(chan *Client), + channels: make(map[string]map[*Client]bool), + } +} + +// Run starts the hub's main loop. It should be called in a goroutine. +// The loop exits when the context is canceled. +func (h *Hub) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + // Close all client connections on shutdown + h.mu.Lock() + for client := range h.clients { + close(client.send) + delete(h.clients, client) + } + h.mu.Unlock() + return + case client := <-h.register: + h.mu.Lock() + h.clients[client] = true + h.mu.Unlock() + case client := <-h.unregister: + h.mu.Lock() + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + // Remove from all channels + for channel := range client.subscriptions { + if clients, ok := h.channels[channel]; ok { + delete(clients, client) + // Clean up empty channels + if len(clients) == 0 { + delete(h.channels, channel) + } + } + } + } + h.mu.Unlock() + case message := <-h.broadcast: + h.mu.RLock() + for client := range h.clients { + select { + case client.send <- message: + default: + // Client buffer full, will be cleaned up + go func(c *Client) { + h.unregister <- c + }(client) + } + } + h.mu.RUnlock() + } + } +} + +// Subscribe adds a client to a channel. +func (h *Hub) Subscribe(client *Client, channel string) { + h.mu.Lock() + defer h.mu.Unlock() + + if _, ok := h.channels[channel]; !ok { + h.channels[channel] = make(map[*Client]bool) + } + h.channels[channel][client] = true + + client.mu.Lock() + client.subscriptions[channel] = true + client.mu.Unlock() +} + +// Unsubscribe removes a client from a channel. +func (h *Hub) Unsubscribe(client *Client, channel string) { + h.mu.Lock() + defer h.mu.Unlock() + + if clients, ok := h.channels[channel]; ok { + delete(clients, client) + // Clean up empty channels + if len(clients) == 0 { + delete(h.channels, channel) + } + } + + client.mu.Lock() + delete(client.subscriptions, channel) + client.mu.Unlock() +} + +// Broadcast sends a message to all connected clients. +func (h *Hub) Broadcast(msg Message) error { + msg.Timestamp = time.Now() + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + select { + case h.broadcast <- data: + default: + return fmt.Errorf("broadcast channel full") + } + return nil +} + +// SendToChannel sends a message to all clients subscribed to a channel. +func (h *Hub) SendToChannel(channel string, msg Message) error { + msg.Timestamp = time.Now() + msg.Channel = channel + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + h.mu.RLock() + clients, ok := h.channels[channel] + h.mu.RUnlock() + + if !ok { + return nil // No subscribers, not an error + } + + for client := range clients { + select { + case client.send <- data: + default: + // Client buffer full, skip + } + } + return nil +} + +// SendProcessOutput sends process output to subscribers of the process channel. +func (h *Hub) SendProcessOutput(processID string, output string) error { + return h.SendToChannel("process:"+processID, Message{ + Type: TypeProcessOutput, + ProcessID: processID, + Data: output, + }) +} + +// SendProcessStatus sends a process status update to subscribers. +func (h *Hub) SendProcessStatus(processID string, status string, exitCode int) error { + return h.SendToChannel("process:"+processID, Message{ + Type: TypeProcessStatus, + ProcessID: processID, + Data: map[string]any{ + "status": status, + "exitCode": exitCode, + }, + }) +} + +// SendError sends an error message to all connected clients. +func (h *Hub) SendError(errMsg string) error { + return h.Broadcast(Message{ + Type: TypeError, + Data: errMsg, + }) +} + +// SendEvent sends a generic event to all connected clients. +func (h *Hub) SendEvent(eventType string, data any) error { + return h.Broadcast(Message{ + Type: TypeEvent, + Data: map[string]any{ + "event": eventType, + "data": data, + }, + }) +} + +// ClientCount returns the number of connected clients. +func (h *Hub) ClientCount() int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.clients) +} + +// ChannelCount returns the number of active channels. +func (h *Hub) ChannelCount() int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.channels) +} + +// ChannelSubscriberCount returns the number of subscribers for a channel. +func (h *Hub) ChannelSubscriberCount(channel string) int { + h.mu.RLock() + defer h.mu.RUnlock() + if clients, ok := h.channels[channel]; ok { + return len(clients) + } + return 0 +} + +// HubStats contains hub statistics. +type HubStats struct { + Clients int `json:"clients"` + Channels int `json:"channels"` +} + +// Stats returns current hub statistics. +func (h *Hub) Stats() HubStats { + h.mu.RLock() + defer h.mu.RUnlock() + return HubStats{ + Clients: len(h.clients), + Channels: len(h.channels), + } +} + +// HandleWebSocket is an alias for Handler for clearer API. +func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + h.Handler()(w, r) +} + +// Handler returns an HTTP handler for WebSocket connections. +func (h *Hub) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + client := &Client{ + hub: h, + conn: conn, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + h.register <- client + + go client.writePump() + go client.readPump() + } +} + +// readPump handles incoming messages from the client. +func (c *Client) readPump() { + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + + c.conn.SetReadLimit(65536) + c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + break + } + + var msg Message + if err := json.Unmarshal(message, &msg); err != nil { + continue + } + + switch msg.Type { + case TypeSubscribe: + if channel, ok := msg.Data.(string); ok { + c.hub.Subscribe(c, channel) + } + case TypeUnsubscribe: + if channel, ok := msg.Data.(string); ok { + c.hub.Unsubscribe(c, channel) + } + case TypePing: + c.send <- mustMarshal(Message{Type: TypePong, Timestamp: time.Now()}) + } + } +} + +// writePump sends messages to the client. +func (c *Client) writePump() { + ticker := time.NewTicker(30 * time.Second) + defer func() { + ticker.Stop() + c.conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if !ok { + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + w.Write(message) + + // Batch queued messages + n := len(c.send) + for i := 0; i < n; i++ { + w.Write([]byte{'\n'}) + w.Write(<-c.send) + } + + if err := w.Close(); err != nil { + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +func mustMarshal(v any) []byte { + data, _ := json.Marshal(v) + return data +} + +// Subscriptions returns a copy of the client's current subscriptions. +func (c *Client) Subscriptions() []string { + c.mu.RLock() + defer c.mu.RUnlock() + + result := make([]string, 0, len(c.subscriptions)) + for channel := range c.subscriptions { + result = append(result, channel) + } + return result +} + +// Close closes the client connection. +func (c *Client) Close() error { + c.hub.unregister <- c + return c.conn.Close() +} diff --git a/ws_test.go b/ws_test.go new file mode 100644 index 0000000..0632568 --- /dev/null +++ b/ws_test.go @@ -0,0 +1,792 @@ +package ws + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewHub(t *testing.T) { + t.Run("creates hub with initialized maps", func(t *testing.T) { + hub := NewHub() + + require.NotNil(t, hub) + assert.NotNil(t, hub.clients) + assert.NotNil(t, hub.broadcast) + assert.NotNil(t, hub.register) + assert.NotNil(t, hub.unregister) + assert.NotNil(t, hub.channels) + }) +} + +func TestHub_Run(t *testing.T) { + t.Run("stops on context cancel", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + hub.Run(ctx) + close(done) + }() + + cancel() + + select { + case <-done: + // Good - hub stopped + case <-time.After(time.Second): + t.Fatal("hub should have stopped on context cancel") + } + }) +} + +func TestHub_Broadcast(t *testing.T) { + t.Run("marshals message with timestamp", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + msg := Message{ + Type: TypeEvent, + Data: "test data", + } + + err := hub.Broadcast(msg) + require.NoError(t, err) + }) + + t.Run("returns error when channel full", func(t *testing.T) { + hub := NewHub() + // Fill the broadcast channel + for i := 0; i < 256; i++ { + hub.broadcast <- []byte("test") + } + + err := hub.Broadcast(Message{Type: TypeEvent}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "broadcast channel full") + }) +} + +func TestHub_Stats(t *testing.T) { + t.Run("returns empty stats for new hub", func(t *testing.T) { + hub := NewHub() + + stats := hub.Stats() + + assert.Equal(t, 0, stats.Clients) + assert.Equal(t, 0, stats.Channels) + }) + + t.Run("tracks client and channel counts", func(t *testing.T) { + hub := NewHub() + + // Manually add clients for testing + hub.mu.Lock() + client1 := &Client{subscriptions: make(map[string]bool)} + client2 := &Client{subscriptions: make(map[string]bool)} + hub.clients[client1] = true + hub.clients[client2] = true + hub.channels["test-channel"] = make(map[*Client]bool) + hub.mu.Unlock() + + stats := hub.Stats() + + assert.Equal(t, 2, stats.Clients) + assert.Equal(t, 1, stats.Channels) + }) +} + +func TestHub_ClientCount(t *testing.T) { + t.Run("returns zero for empty hub", func(t *testing.T) { + hub := NewHub() + assert.Equal(t, 0, hub.ClientCount()) + }) + + t.Run("counts connected clients", func(t *testing.T) { + hub := NewHub() + + hub.mu.Lock() + hub.clients[&Client{}] = true + hub.clients[&Client{}] = true + hub.mu.Unlock() + + assert.Equal(t, 2, hub.ClientCount()) + }) +} + +func TestHub_ChannelCount(t *testing.T) { + t.Run("returns zero for empty hub", func(t *testing.T) { + hub := NewHub() + assert.Equal(t, 0, hub.ChannelCount()) + }) + + t.Run("counts active channels", func(t *testing.T) { + hub := NewHub() + + hub.mu.Lock() + hub.channels["channel1"] = make(map[*Client]bool) + hub.channels["channel2"] = make(map[*Client]bool) + hub.mu.Unlock() + + assert.Equal(t, 2, hub.ChannelCount()) + }) +} + +func TestHub_ChannelSubscriberCount(t *testing.T) { + t.Run("returns zero for non-existent channel", func(t *testing.T) { + hub := NewHub() + assert.Equal(t, 0, hub.ChannelSubscriberCount("non-existent")) + }) + + t.Run("counts subscribers in channel", func(t *testing.T) { + hub := NewHub() + + hub.mu.Lock() + hub.channels["test-channel"] = make(map[*Client]bool) + hub.channels["test-channel"][&Client{}] = true + hub.channels["test-channel"][&Client{}] = true + hub.mu.Unlock() + + assert.Equal(t, 2, hub.ChannelSubscriberCount("test-channel")) + }) +} + +func TestHub_Subscribe(t *testing.T) { + t.Run("adds client to channel", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + + hub.Subscribe(client, "test-channel") + + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + assert.True(t, client.subscriptions["test-channel"]) + }) + + t.Run("creates channel if not exists", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + hub.Subscribe(client, "new-channel") + + hub.mu.RLock() + _, exists := hub.channels["new-channel"] + hub.mu.RUnlock() + + assert.True(t, exists) + }) +} + +func TestHub_Unsubscribe(t *testing.T) { + t.Run("removes client from channel", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + hub.Subscribe(client, "test-channel") + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + + hub.Unsubscribe(client, "test-channel") + assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel")) + assert.False(t, client.subscriptions["test-channel"]) + }) + + t.Run("cleans up empty channels", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + hub.Subscribe(client, "temp-channel") + hub.Unsubscribe(client, "temp-channel") + + hub.mu.RLock() + _, exists := hub.channels["temp-channel"] + hub.mu.RUnlock() + + assert.False(t, exists, "empty channel should be removed") + }) + + t.Run("handles non-existent channel gracefully", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + // Should not panic + hub.Unsubscribe(client, "non-existent") + }) +} + +func TestHub_SendToChannel(t *testing.T) { + t.Run("sends to channel subscribers", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + hub.Subscribe(client, "test-channel") + + err := hub.SendToChannel("test-channel", Message{ + Type: TypeEvent, + Data: "test", + }) + require.NoError(t, err) + + select { + case msg := <-client.send: + var received Message + err := json.Unmarshal(msg, &received) + require.NoError(t, err) + assert.Equal(t, TypeEvent, received.Type) + assert.Equal(t, "test-channel", received.Channel) + case <-time.After(time.Second): + t.Fatal("expected message on client send channel") + } + }) + + t.Run("returns nil for non-existent channel", func(t *testing.T) { + hub := NewHub() + + err := hub.SendToChannel("non-existent", Message{Type: TypeEvent}) + assert.NoError(t, err, "should not error for non-existent channel") + }) +} + +func TestHub_SendProcessOutput(t *testing.T) { + t.Run("sends output to process channel", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + hub.Subscribe(client, "process:proc-1") + + err := hub.SendProcessOutput("proc-1", "hello world") + require.NoError(t, err) + + select { + case msg := <-client.send: + var received Message + err := json.Unmarshal(msg, &received) + require.NoError(t, err) + assert.Equal(t, TypeProcessOutput, received.Type) + assert.Equal(t, "proc-1", received.ProcessID) + assert.Equal(t, "hello world", received.Data) + case <-time.After(time.Second): + t.Fatal("expected message on client send channel") + } + }) +} + +func TestHub_SendProcessStatus(t *testing.T) { + t.Run("sends status to process channel", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + hub.Subscribe(client, "process:proc-1") + + err := hub.SendProcessStatus("proc-1", "exited", 0) + require.NoError(t, err) + + select { + case msg := <-client.send: + var received Message + err := json.Unmarshal(msg, &received) + require.NoError(t, err) + assert.Equal(t, TypeProcessStatus, received.Type) + assert.Equal(t, "proc-1", received.ProcessID) + + data, ok := received.Data.(map[string]any) + require.True(t, ok) + assert.Equal(t, "exited", data["status"]) + assert.Equal(t, float64(0), data["exitCode"]) + case <-time.After(time.Second): + t.Fatal("expected message on client send channel") + } + }) +} + +func TestHub_SendError(t *testing.T) { + t.Run("broadcasts error message", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.register <- client + // Give time for registration + time.Sleep(10 * time.Millisecond) + + err := hub.SendError("something went wrong") + require.NoError(t, err) + + select { + case msg := <-client.send: + var received Message + err := json.Unmarshal(msg, &received) + require.NoError(t, err) + assert.Equal(t, TypeError, received.Type) + assert.Equal(t, "something went wrong", received.Data) + case <-time.After(time.Second): + t.Fatal("expected error message on client send channel") + } + }) +} + +func TestHub_SendEvent(t *testing.T) { + t.Run("broadcasts event message", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.register <- client + time.Sleep(10 * time.Millisecond) + + err := hub.SendEvent("user_joined", map[string]string{"user": "alice"}) + require.NoError(t, err) + + select { + case msg := <-client.send: + var received Message + err := json.Unmarshal(msg, &received) + require.NoError(t, err) + assert.Equal(t, TypeEvent, received.Type) + + data, ok := received.Data.(map[string]any) + require.True(t, ok) + assert.Equal(t, "user_joined", data["event"]) + case <-time.After(time.Second): + t.Fatal("expected event message on client send channel") + } + }) +} + +func TestClient_Subscriptions(t *testing.T) { + t.Run("returns copy of subscriptions", func(t *testing.T) { + hub := NewHub() + client := &Client{ + hub: hub, + subscriptions: make(map[string]bool), + } + + hub.Subscribe(client, "channel1") + hub.Subscribe(client, "channel2") + + subs := client.Subscriptions() + + assert.Len(t, subs, 2) + assert.Contains(t, subs, "channel1") + assert.Contains(t, subs, "channel2") + }) +} + +func TestMessage_JSON(t *testing.T) { + t.Run("marshals correctly", func(t *testing.T) { + msg := Message{ + Type: TypeProcessOutput, + Channel: "process:1", + ProcessID: "1", + Data: "output line", + Timestamp: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + assert.Contains(t, string(data), `"type":"process_output"`) + assert.Contains(t, string(data), `"channel":"process:1"`) + assert.Contains(t, string(data), `"processId":"1"`) + assert.Contains(t, string(data), `"data":"output line"`) + }) + + t.Run("unmarshals correctly", func(t *testing.T) { + jsonStr := `{"type":"subscribe","data":"channel:test"}` + + var msg Message + err := json.Unmarshal([]byte(jsonStr), &msg) + require.NoError(t, err) + + assert.Equal(t, TypeSubscribe, msg.Type) + assert.Equal(t, "channel:test", msg.Data) + }) +} + +func TestHub_WebSocketHandler(t *testing.T) { + t.Run("upgrades connection and registers client", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + // Give time for registration + time.Sleep(50 * time.Millisecond) + + assert.Equal(t, 1, hub.ClientCount()) + }) + + t.Run("handles subscribe message", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + // Send subscribe message + subscribeMsg := Message{ + Type: TypeSubscribe, + Data: "test-channel", + } + err = conn.WriteJSON(subscribeMsg) + require.NoError(t, err) + + // Give time for subscription + time.Sleep(50 * time.Millisecond) + + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + }) + + t.Run("handles unsubscribe message", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + // Subscribe first + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) + require.NoError(t, err) + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + + // Unsubscribe + err = conn.WriteJSON(Message{Type: TypeUnsubscribe, Data: "test-channel"}) + require.NoError(t, err) + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel")) + }) + + t.Run("responds to ping with pong", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + // Give time for registration + time.Sleep(50 * time.Millisecond) + + // Send ping + err = conn.WriteJSON(Message{Type: TypePing}) + require.NoError(t, err) + + // Read pong response + var response Message + conn.SetReadDeadline(time.Now().Add(time.Second)) + err = conn.ReadJSON(&response) + require.NoError(t, err) + + assert.Equal(t, TypePong, response.Type) + }) + + t.Run("broadcasts messages to clients", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + // Give time for registration + time.Sleep(50 * time.Millisecond) + + // Broadcast a message + err = hub.Broadcast(Message{ + Type: TypeEvent, + Data: "broadcast test", + }) + require.NoError(t, err) + + // Read the broadcast + var response Message + conn.SetReadDeadline(time.Now().Add(time.Second)) + err = conn.ReadJSON(&response) + require.NoError(t, err) + + assert.Equal(t, TypeEvent, response.Type) + assert.Equal(t, "broadcast test", response.Data) + }) + + t.Run("unregisters client on connection close", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + // Wait for registration + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) + + // Close connection + conn.Close() + + // Wait for unregistration + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, hub.ClientCount()) + }) + + t.Run("removes client from channels on disconnect", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + // Subscribe to channel + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "test-channel"}) + require.NoError(t, err) + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ChannelSubscriberCount("test-channel")) + + // Close connection + conn.Close() + time.Sleep(50 * time.Millisecond) + + // Channel should be cleaned up + assert.Equal(t, 0, hub.ChannelSubscriberCount("test-channel")) + }) +} + +func TestHub_Concurrency(t *testing.T) { + t.Run("handles concurrent subscriptions", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + var wg sync.WaitGroup + numClients := 100 + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + client := &Client{ + hub: hub, + send: make(chan []byte, 256), + subscriptions: make(map[string]bool), + } + + hub.mu.Lock() + hub.clients[client] = true + hub.mu.Unlock() + + hub.Subscribe(client, "shared-channel") + hub.Subscribe(client, "shared-channel") // Double subscribe should be safe + }(i) + } + + wg.Wait() + + assert.Equal(t, numClients, hub.ChannelSubscriberCount("shared-channel")) + }) + + t.Run("handles concurrent broadcasts", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + client := &Client{ + hub: hub, + send: make(chan []byte, 1000), + subscriptions: make(map[string]bool), + } + + hub.register <- client + time.Sleep(10 * time.Millisecond) + + var wg sync.WaitGroup + numBroadcasts := 100 + + for i := 0; i < numBroadcasts; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + _ = hub.Broadcast(Message{ + Type: TypeEvent, + Data: id, + }) + }(i) + } + + wg.Wait() + + // Give time for broadcasts to be delivered + time.Sleep(100 * time.Millisecond) + + // Count received messages + received := 0 + timeout := time.After(100 * time.Millisecond) + loop: + for { + select { + case <-client.send: + received++ + case <-timeout: + break loop + } + } + + // All or most broadcasts should be received + assert.GreaterOrEqual(t, received, numBroadcasts-10, "should receive most broadcasts") + }) +} + +func TestHub_HandleWebSocket(t *testing.T) { + t.Run("alias works same as Handler", func(t *testing.T) { + hub := NewHub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go hub.Run(ctx) + + // Test with HandleWebSocket directly + server := httptest.NewServer(http.HandlerFunc(hub.HandleWebSocket)) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) + }) +} + +func TestMustMarshal(t *testing.T) { + t.Run("marshals valid data", func(t *testing.T) { + data := mustMarshal(Message{Type: TypePong}) + assert.Contains(t, string(data), "pong") + }) + + t.Run("handles unmarshalable data without panic", func(t *testing.T) { + // Create a channel which cannot be marshaled + // This should not panic, even if it returns nil + ch := make(chan int) + assert.NotPanics(t, func() { + _ = mustMarshal(ch) + }) + }) +}