go-ws/ws.go
Snider 3d6f14cd91
All checks were successful
Security Scan / security (push) Successful in 9s
Test / test (push) Successful in 3m49s
feat: modernise to Go 1.26 iterators and stdlib helpers
Add AllClients, AllChannels, AllSubscriptions iter.Seq iterators.
Use slices.Collect(maps.Keys()) in SendToChannel/Subscriptions,
range-over-int in calculateBackoff and benchmarks.

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 05:32:08 +00:00

848 lines
21 KiB
Go

// SPDX-Licence-Identifier: EUPL-1.2
// Package ws provides WebSocket support for real-time streaming.
//
// The ws package enables live process output, events, and bidirectional communication
// between the Go backend and web frontends. It implements a hub pattern for managing
// WebSocket connections and channel-based subscriptions.
//
// # Getting Started
//
// hub := ws.NewHub()
// go hub.Run(ctx)
//
// // Register HTTP handler
// http.HandleFunc("/ws", hub.Handler())
//
// # Authentication
//
// The hub supports optional token-based authentication on upgrade. Supply an
// Authenticator via HubConfig to gate connections:
//
// auth := ws.NewAPIKeyAuth(map[string]string{"secret-key": "user-1"})
// hub := ws.NewHubWithConfig(ws.HubConfig{Authenticator: auth})
// go hub.Run(ctx)
//
// When no Authenticator is set (nil), all connections are accepted — preserving
// backward compatibility.
//
// # Message Types
//
// The package defines several message types for different purposes:
// - TypeProcessOutput: Real-time process output streaming
// - TypeProcessStatus: Process status updates (running, exited, etc.)
// - TypeEvent: Generic events
// - TypeError: Error messages
// - TypePing/TypePong: Keep-alive messages
// - TypeSubscribe/TypeUnsubscribe: Channel subscription management
//
// # Channel Subscriptions
//
// Clients can subscribe to specific channels to receive targeted messages:
//
// // Client sends: {"type": "subscribe", "data": "process:proc-1"}
// // Server broadcasts only to subscribers of "process:proc-1"
//
// # Integration with Core
//
// The Hub can receive process events via Core.ACTION and forward them to WebSocket clients:
//
// core.RegisterAction(func(c *framework.Core, msg framework.Message) error {
// switch m := msg.(type) {
// case process.ActionProcessOutput:
// hub.SendProcessOutput(m.ID, m.Line)
// case process.ActionProcessExited:
// hub.SendProcessStatus(m.ID, "exited", m.ExitCode)
// }
// return nil
// })
package ws
import (
"context"
"encoding/json"
"fmt"
"iter"
"maps"
"net/http"
"slices"
"sync"
"time"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins for local development
},
}
// Default timing values for heartbeat and pong timeout.
const (
DefaultHeartbeatInterval = 30 * time.Second
DefaultPongTimeout = 60 * time.Second
DefaultWriteTimeout = 10 * time.Second
)
// ConnectionState represents the current state of a reconnecting client.
type ConnectionState int
const (
// StateDisconnected indicates the client is not connected.
StateDisconnected ConnectionState = iota
// StateConnecting indicates the client is attempting to connect.
StateConnecting
// StateConnected indicates the client has an active connection.
StateConnected
)
// HubConfig holds configuration for the Hub and its managed connections.
type HubConfig struct {
// HeartbeatInterval is the interval between server-side ping messages.
// Defaults to 30 seconds.
HeartbeatInterval time.Duration
// PongTimeout is how long the server waits for a pong before
// considering the connection dead. Must be greater than HeartbeatInterval.
// Defaults to 60 seconds.
PongTimeout time.Duration
// WriteTimeout is the deadline for write operations.
// Defaults to 10 seconds.
WriteTimeout time.Duration
// OnConnect is called when a client connects to the hub.
OnConnect func(client *Client)
// OnDisconnect is called when a client disconnects from the hub.
OnDisconnect func(client *Client)
// Authenticator validates incoming WebSocket connections during the
// HTTP upgrade handshake. When nil, all connections are accepted
// (backward compatible). When set, connections that fail authentication
// receive an HTTP 401 response and are not upgraded.
Authenticator Authenticator
// OnAuthFailure is called when a connection is rejected by the
// Authenticator. Useful for logging or metrics. Optional.
OnAuthFailure func(r *http.Request, result AuthResult)
}
// DefaultHubConfig returns a HubConfig with sensible defaults.
func DefaultHubConfig() HubConfig {
return HubConfig{
HeartbeatInterval: DefaultHeartbeatInterval,
PongTimeout: DefaultPongTimeout,
WriteTimeout: DefaultWriteTimeout,
}
}
// MessageType identifies the type of WebSocket message.
type MessageType string
const (
// TypeProcessOutput indicates real-time process output.
TypeProcessOutput MessageType = "process_output"
// TypeProcessStatus indicates a process status change.
TypeProcessStatus MessageType = "process_status"
// TypeEvent indicates a generic event.
TypeEvent MessageType = "event"
// TypeError indicates an error message.
TypeError MessageType = "error"
// TypePing is a client-to-server keep-alive request.
TypePing MessageType = "ping"
// TypePong is the server response to ping.
TypePong MessageType = "pong"
// TypeSubscribe requests subscription to a channel.
TypeSubscribe MessageType = "subscribe"
// TypeUnsubscribe requests unsubscription from a channel.
TypeUnsubscribe MessageType = "unsubscribe"
)
// Message is the standard WebSocket message format.
type Message struct {
Type MessageType `json:"type"`
Channel string `json:"channel,omitempty"`
ProcessID string `json:"processId,omitempty"`
Data any `json:"data,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
// Client represents a connected WebSocket client.
type Client struct {
hub *Hub
conn *websocket.Conn
send chan []byte
subscriptions map[string]bool
mu sync.RWMutex
// UserID is the authenticated user's identifier, set during the
// upgrade handshake when an Authenticator is configured. Empty
// when no Authenticator is set.
UserID string
// Claims holds arbitrary authentication metadata (e.g. roles,
// scopes). Nil when no Authenticator is set.
Claims map[string]any
}
// Hub manages WebSocket connections and message broadcasting.
type Hub struct {
clients map[*Client]bool
broadcast chan []byte
register chan *Client
unregister chan *Client
channels map[string]map[*Client]bool
config HubConfig
mu sync.RWMutex
}
// NewHub creates a new WebSocket hub with default configuration.
func NewHub() *Hub {
return NewHubWithConfig(DefaultHubConfig())
}
// NewHubWithConfig creates a new WebSocket hub with the given configuration.
func NewHubWithConfig(config HubConfig) *Hub {
if config.HeartbeatInterval <= 0 {
config.HeartbeatInterval = DefaultHeartbeatInterval
}
if config.PongTimeout <= 0 {
config.PongTimeout = DefaultPongTimeout
}
if config.WriteTimeout <= 0 {
config.WriteTimeout = DefaultWriteTimeout
}
return &Hub{
clients: make(map[*Client]bool),
broadcast: make(chan []byte, 256),
register: make(chan *Client),
unregister: make(chan *Client),
channels: make(map[string]map[*Client]bool),
config: config,
}
}
// Run starts the hub's main loop. It should be called in a goroutine.
// The loop exits when the context is canceled.
func (h *Hub) Run(ctx context.Context) {
for {
select {
case <-ctx.Done():
// Close all client connections on shutdown
h.mu.Lock()
for client := range h.clients {
close(client.send)
delete(h.clients, client)
}
h.mu.Unlock()
return
case client := <-h.register:
h.mu.Lock()
h.clients[client] = true
h.mu.Unlock()
if h.config.OnConnect != nil {
h.config.OnConnect(client)
}
case client := <-h.unregister:
h.mu.Lock()
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
close(client.send)
// Remove from all channels
for channel := range client.subscriptions {
if clients, ok := h.channels[channel]; ok {
delete(clients, client)
// Clean up empty channels
if len(clients) == 0 {
delete(h.channels, channel)
}
}
}
h.mu.Unlock()
if h.config.OnDisconnect != nil {
h.config.OnDisconnect(client)
}
} else {
h.mu.Unlock()
}
case message := <-h.broadcast:
h.mu.RLock()
for client := range h.clients {
select {
case client.send <- message:
default:
// Client buffer full, will be cleaned up
go func(c *Client) {
h.unregister <- c
}(client)
}
}
h.mu.RUnlock()
}
}
}
// Subscribe adds a client to a channel.
func (h *Hub) Subscribe(client *Client, channel string) {
h.mu.Lock()
defer h.mu.Unlock()
if _, ok := h.channels[channel]; !ok {
h.channels[channel] = make(map[*Client]bool)
}
h.channels[channel][client] = true
client.mu.Lock()
client.subscriptions[channel] = true
client.mu.Unlock()
}
// Unsubscribe removes a client from a channel.
func (h *Hub) Unsubscribe(client *Client, channel string) {
h.mu.Lock()
defer h.mu.Unlock()
if clients, ok := h.channels[channel]; ok {
delete(clients, client)
// Clean up empty channels
if len(clients) == 0 {
delete(h.channels, channel)
}
}
client.mu.Lock()
delete(client.subscriptions, channel)
client.mu.Unlock()
}
// Broadcast sends a message to all connected clients.
func (h *Hub) Broadcast(msg Message) error {
msg.Timestamp = time.Now()
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
select {
case h.broadcast <- data:
default:
return fmt.Errorf("broadcast channel full")
}
return nil
}
// SendToChannel sends a message to all clients subscribed to a channel.
func (h *Hub) SendToChannel(channel string, msg Message) error {
msg.Timestamp = time.Now()
msg.Channel = channel
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
h.mu.RLock()
clients, ok := h.channels[channel]
if !ok {
h.mu.RUnlock()
return nil // No subscribers, not an error
}
// Copy client references under lock to avoid races during iteration
targets := slices.Collect(maps.Keys(clients))
h.mu.RUnlock()
for _, client := range targets {
select {
case client.send <- data:
default:
// Client buffer full, skip
}
}
return nil
}
// SendProcessOutput sends process output to subscribers of the process channel.
func (h *Hub) SendProcessOutput(processID string, output string) error {
return h.SendToChannel("process:"+processID, Message{
Type: TypeProcessOutput,
ProcessID: processID,
Data: output,
})
}
// SendProcessStatus sends a process status update to subscribers.
func (h *Hub) SendProcessStatus(processID string, status string, exitCode int) error {
return h.SendToChannel("process:"+processID, Message{
Type: TypeProcessStatus,
ProcessID: processID,
Data: map[string]any{
"status": status,
"exitCode": exitCode,
},
})
}
// SendError sends an error message to all connected clients.
func (h *Hub) SendError(errMsg string) error {
return h.Broadcast(Message{
Type: TypeError,
Data: errMsg,
})
}
// SendEvent sends a generic event to all connected clients.
func (h *Hub) SendEvent(eventType string, data any) error {
return h.Broadcast(Message{
Type: TypeEvent,
Data: map[string]any{
"event": eventType,
"data": data,
},
})
}
// ClientCount returns the number of connected clients.
func (h *Hub) ClientCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.clients)
}
// ChannelCount returns the number of active channels.
func (h *Hub) ChannelCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.channels)
}
// ChannelSubscriberCount returns the number of subscribers for a channel.
func (h *Hub) ChannelSubscriberCount(channel string) int {
h.mu.RLock()
defer h.mu.RUnlock()
if clients, ok := h.channels[channel]; ok {
return len(clients)
}
return 0
}
// AllClients returns an iterator for all connected clients.
func (h *Hub) AllClients() iter.Seq[*Client] {
h.mu.RLock()
defer h.mu.RUnlock()
return slices.Values(slices.Collect(maps.Keys(h.clients)))
}
// AllChannels returns an iterator for all active channels.
func (h *Hub) AllChannels() iter.Seq[string] {
h.mu.RLock()
defer h.mu.RUnlock()
return slices.Values(slices.Collect(maps.Keys(h.channels)))
}
// HubStats contains hub statistics.
type HubStats struct {
Clients int `json:"clients"`
Channels int `json:"channels"`
}
// Stats returns current hub statistics.
func (h *Hub) Stats() HubStats {
h.mu.RLock()
defer h.mu.RUnlock()
return HubStats{
Clients: len(h.clients),
Channels: len(h.channels),
}
}
// HandleWebSocket is an alias for Handler for clearer API.
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
h.Handler()(w, r)
}
// Handler returns an HTTP handler for WebSocket connections.
func (h *Hub) Handler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Authenticate if an Authenticator is configured.
var authResult AuthResult
if h.config.Authenticator != nil {
authResult = h.config.Authenticator.Authenticate(r)
if !authResult.Valid {
if h.config.OnAuthFailure != nil {
h.config.OnAuthFailure(r, authResult)
}
http.Error(w, "Unauthorised", http.StatusUnauthorized)
return
}
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
client := &Client{
hub: h,
conn: conn,
send: make(chan []byte, 256),
subscriptions: make(map[string]bool),
}
// Populate auth fields when authentication succeeded.
if h.config.Authenticator != nil {
client.UserID = authResult.UserID
client.Claims = authResult.Claims
}
h.register <- client
go client.writePump()
go client.readPump()
}
}
// readPump handles incoming messages from the client.
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
pongTimeout := c.hub.config.PongTimeout
c.conn.SetReadLimit(65536)
c.conn.SetReadDeadline(time.Now().Add(pongTimeout))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongTimeout))
return nil
})
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
break
}
var msg Message
if err := json.Unmarshal(message, &msg); err != nil {
continue
}
switch msg.Type {
case TypeSubscribe:
if channel, ok := msg.Data.(string); ok {
c.hub.Subscribe(c, channel)
}
case TypeUnsubscribe:
if channel, ok := msg.Data.(string); ok {
c.hub.Unsubscribe(c, channel)
}
case TypePing:
c.send <- mustMarshal(Message{Type: TypePong, Timestamp: time.Now()})
}
}
}
// writePump sends messages to the client.
func (c *Client) writePump() {
heartbeat := c.hub.config.HeartbeatInterval
writeTimeout := c.hub.config.WriteTimeout
ticker := time.NewTicker(heartbeat)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
w.Write(message)
// Batch queued messages
n := len(c.send)
for range n {
w.Write([]byte{'\n'})
w.Write(<-c.send)
}
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
func mustMarshal(v any) []byte {
data, _ := json.Marshal(v)
return data
}
// Subscriptions returns a copy of the client's current subscriptions.
func (c *Client) Subscriptions() []string {
c.mu.RLock()
defer c.mu.RUnlock()
return slices.Collect(maps.Keys(c.subscriptions))
}
// AllSubscriptions returns an iterator for the client's current subscriptions.
func (c *Client) AllSubscriptions() iter.Seq[string] {
c.mu.RLock()
defer c.mu.RUnlock()
return slices.Values(slices.Collect(maps.Keys(c.subscriptions)))
}
// Close closes the client connection.
func (c *Client) Close() error {
c.hub.unregister <- c
return c.conn.Close()
}
// ReconnectConfig holds configuration for the reconnecting WebSocket client.
type ReconnectConfig struct {
// URL is the WebSocket server URL to connect to.
URL string
// InitialBackoff is the delay before the first reconnection attempt.
// Defaults to 1 second.
InitialBackoff time.Duration
// MaxBackoff is the maximum delay between reconnection attempts.
// Defaults to 30 seconds.
MaxBackoff time.Duration
// BackoffMultiplier controls exponential growth of the backoff.
// Defaults to 2.0.
BackoffMultiplier float64
// MaxRetries is the maximum number of consecutive reconnection attempts.
// Zero means unlimited retries.
MaxRetries int
// OnConnect is called when the client successfully connects.
OnConnect func()
// OnDisconnect is called when the client loses its connection.
OnDisconnect func()
// OnReconnect is called when the client successfully reconnects
// after a disconnection. The attempt count is passed in.
OnReconnect func(attempt int)
// OnMessage is called when a message is received from the server.
OnMessage func(msg Message)
// Dialer is the WebSocket dialer to use. Defaults to websocket.DefaultDialer.
Dialer *websocket.Dialer
// Headers are additional HTTP headers to send during the handshake.
Headers http.Header
}
// ReconnectingClient is a WebSocket client that automatically reconnects
// with exponential backoff when the connection drops.
type ReconnectingClient struct {
config ReconnectConfig
conn *websocket.Conn
send chan []byte
state ConnectionState
mu sync.RWMutex
done chan struct{}
ctx context.Context
cancel context.CancelFunc
}
// NewReconnectingClient creates a new reconnecting WebSocket client.
func NewReconnectingClient(config ReconnectConfig) *ReconnectingClient {
if config.InitialBackoff <= 0 {
config.InitialBackoff = 1 * time.Second
}
if config.MaxBackoff <= 0 {
config.MaxBackoff = 30 * time.Second
}
if config.BackoffMultiplier <= 0 {
config.BackoffMultiplier = 2.0
}
if config.Dialer == nil {
config.Dialer = websocket.DefaultDialer
}
return &ReconnectingClient{
config: config,
send: make(chan []byte, 256),
state: StateDisconnected,
done: make(chan struct{}),
}
}
// Connect starts the reconnecting client. It blocks until the context is
// cancelled. The client will automatically reconnect on connection loss.
func (rc *ReconnectingClient) Connect(ctx context.Context) error {
rc.ctx, rc.cancel = context.WithCancel(ctx)
defer rc.cancel()
attempt := 0
wasConnected := false
for {
select {
case <-rc.ctx.Done():
rc.setState(StateDisconnected)
return rc.ctx.Err()
default:
}
rc.setState(StateConnecting)
attempt++
conn, _, err := rc.config.Dialer.DialContext(rc.ctx, rc.config.URL, rc.config.Headers)
if err != nil {
if rc.config.MaxRetries > 0 && attempt > rc.config.MaxRetries {
rc.setState(StateDisconnected)
return fmt.Errorf("max retries (%d) exceeded: %w", rc.config.MaxRetries, err)
}
backoff := rc.calculateBackoff(attempt)
select {
case <-rc.ctx.Done():
rc.setState(StateDisconnected)
return rc.ctx.Err()
case <-time.After(backoff):
continue
}
}
// Connected successfully
rc.mu.Lock()
rc.conn = conn
rc.mu.Unlock()
rc.setState(StateConnected)
if wasConnected {
if rc.config.OnReconnect != nil {
rc.config.OnReconnect(attempt)
}
} else {
if rc.config.OnConnect != nil {
rc.config.OnConnect()
}
}
// Reset attempt counter after a successful connection
attempt = 0
wasConnected = true
// Run the read loop — blocks until connection drops
rc.readLoop()
// Connection lost
rc.mu.Lock()
rc.conn = nil
rc.mu.Unlock()
if rc.config.OnDisconnect != nil {
rc.config.OnDisconnect()
}
}
}
// Send sends a message to the server. Returns an error if not connected.
func (rc *ReconnectingClient) Send(msg Message) error {
msg.Timestamp = time.Now()
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
rc.mu.RLock()
conn := rc.conn
rc.mu.RUnlock()
if conn == nil {
return fmt.Errorf("not connected")
}
rc.mu.Lock()
defer rc.mu.Unlock()
return rc.conn.WriteMessage(websocket.TextMessage, data)
}
// State returns the current connection state.
func (rc *ReconnectingClient) State() ConnectionState {
rc.mu.RLock()
defer rc.mu.RUnlock()
return rc.state
}
// Close gracefully shuts down the reconnecting client.
func (rc *ReconnectingClient) Close() error {
if rc.cancel != nil {
rc.cancel()
}
rc.mu.RLock()
conn := rc.conn
rc.mu.RUnlock()
if conn != nil {
return conn.Close()
}
return nil
}
func (rc *ReconnectingClient) setState(state ConnectionState) {
rc.mu.Lock()
rc.state = state
rc.mu.Unlock()
}
func (rc *ReconnectingClient) calculateBackoff(attempt int) time.Duration {
backoff := rc.config.InitialBackoff
for range attempt - 1 {
backoff = time.Duration(float64(backoff) * rc.config.BackoffMultiplier)
if backoff > rc.config.MaxBackoff {
backoff = rc.config.MaxBackoff
break
}
}
return backoff
}
func (rc *ReconnectingClient) readLoop() {
rc.mu.RLock()
conn := rc.conn
rc.mu.RUnlock()
if conn == nil {
return
}
for {
_, data, err := conn.ReadMessage()
if err != nil {
return
}
if rc.config.OnMessage != nil {
var msg Message
if jsonErr := json.Unmarshal(data, &msg); jsonErr == nil {
rc.config.OnMessage(msg)
}
}
}
}