diff --git a/CLAUDE.md b/CLAUDE.md index 635c8ae..c0e66b6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,6 +17,11 @@ go test -v -run Name # Run single test - Messages types: `process_output`, `process_status`, `event`, `error`, `ping/pong`, `subscribe/unsubscribe` - `hub.SendProcessOutput(id, line)` broadcasts to subscribers - `HubConfig` provides configurable heartbeat, pong timeout, write timeout, and connection callbacks +- `Authenticator` interface for token-based auth on upgrade — nil means all connections accepted (backward compat) +- `APIKeyAuthenticator` validates `Authorization: Bearer ` against a static key→userID map +- `AuthenticatorFunc` adapter lets plain functions satisfy the `Authenticator` interface +- `Client.UserID` and `Client.Claims` populated during authenticated upgrade +- `OnAuthFailure` callback on `HubConfig` for logging/metrics on rejected connections - `ReconnectingClient` provides client-side reconnection with exponential backoff - `ConnectionState`: `StateDisconnected`, `StateConnecting`, `StateConnected` - Coverage: 98.5% diff --git a/TODO.md b/TODO.md index 463b71e..17d8dc0 100644 --- a/TODO.md +++ b/TODO.md @@ -18,13 +18,13 @@ Dispatched from core/go orchestration. Pick up tasks in order. - [x] Tune heartbeat interval and pong timeout for flaky networks — `HubConfig` with `HeartbeatInterval`, `PongTimeout`, `WriteTimeout`. `NewHubWithConfig()` constructor. Defaults: 30s heartbeat, 60s pong timeout, 10s write timeout. - [x] Add connection state callbacks (onConnect, onDisconnect, onReconnect) — Hub-level `OnConnect`/`OnDisconnect` callbacks in `HubConfig`. Client-level `OnConnect`/`OnDisconnect`/`OnReconnect` callbacks in `ReconnectConfig`. `ConnectionState` enum: `StateDisconnected`, `StateConnecting`, `StateConnected`. -## Phase 2: Auth +## Phase 2: Auth (complete) Token-based authentication on WebSocket upgrade handshake. Pure Go, no JWT library dependency — consumers bring their own validation logic via an interface. ### 2.1 Authenticator Interface -- [ ] **Create `auth.go`** — Define the auth abstraction: +- [x] **Create `auth.go`** — Define the auth abstraction: - `type AuthResult struct { Valid bool; UserID string; Claims map[string]any; Error error }` — result of authentication - `type Authenticator interface { Authenticate(r *http.Request) AuthResult }` — validates the HTTP request during upgrade. Implementations can check headers (`Authorization: Bearer `), query params (`?token=xxx`), or cookies. - `type AuthenticatorFunc func(r *http.Request) AuthResult` — adapter for using functions as Authenticators (implements the interface) @@ -33,22 +33,22 @@ Token-based authentication on WebSocket upgrade handshake. Pure Go, no JWT libra ### 2.2 Wire Into Hub -- [ ] **Add `Authenticator` to `HubConfig`** — Optional field. When nil, all connections are accepted (backward compatible). When set, `Handler()` calls `Authenticate(r)` before upgrading. -- [ ] **Update `Handler()`** — If `h.config.Authenticator != nil`, call `Authenticate(r)`. If `!result.Valid`, respond with `http.StatusUnauthorized` (or `http.StatusForbidden` if `result.Error` indicates a different status) and return without upgrading. If valid, store `result.UserID` and `result.Claims` on the `Client` struct. -- [ ] **Add auth fields to `Client`** — `UserID string` and `Claims map[string]any` fields. Set during authenticated upgrade. Empty for unauthenticated hubs (nil authenticator). -- [ ] **Expose `OnAuthFailure` callback** — Optional `OnAuthFailure func(r *http.Request, result AuthResult)` on `HubConfig` for logging/metrics on rejected connections. +- [x] **Add `Authenticator` to `HubConfig`** — Optional field. When nil, all connections are accepted (backward compatible). When set, `Handler()` calls `Authenticate(r)` before upgrading. +- [x] **Update `Handler()`** — If `h.config.Authenticator != nil`, call `Authenticate(r)`. If `!result.Valid`, respond with `http.StatusUnauthorized` (or `http.StatusForbidden` if `result.Error` indicates a different status) and return without upgrading. If valid, store `result.UserID` and `result.Claims` on the `Client` struct. +- [x] **Add auth fields to `Client`** — `UserID string` and `Claims map[string]any` fields. Set during authenticated upgrade. Empty for unauthenticated hubs (nil authenticator). +- [x] **Expose `OnAuthFailure` callback** — Optional `OnAuthFailure func(r *http.Request, result AuthResult)` on `HubConfig` for logging/metrics on rejected connections. ### 2.3 Tests -- [ ] **Unit tests** — (a) APIKeyAuthenticator valid key, (b) invalid key, (c) missing header, (d) malformed header ("Bearer" without token, wrong scheme), (e) AuthenticatorFunc adapter, (f) nil Authenticator (backward compat — all connections accepted) -- [ ] **Integration tests** — Using httptest + gorilla/websocket Dial: +- [x] **Unit tests** — (a) APIKeyAuthenticator valid key, (b) invalid key, (c) missing header, (d) malformed header ("Bearer" without token, wrong scheme), (e) AuthenticatorFunc adapter, (f) nil Authenticator (backward compat — all connections accepted) +- [x] **Integration tests** — Using httptest + gorilla/websocket Dial: - (a) Authenticated connect with valid API key → upgrade succeeds, client.UserID set - (b) Rejected connect with invalid key → HTTP 401, no WebSocket upgrade - (c) Rejected connect with no auth header → HTTP 401 - (d) Nil authenticator → all connections accepted (existing behaviour preserved) - (e) OnAuthFailure callback fires on rejection - (f) Multiple clients with different API keys → each gets correct UserID -- [ ] **Existing tests still pass** — No authenticator set = backward compatible +- [x] **Existing tests still pass** — No authenticator set = backward compatible ## Phase 3: Scaling diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..0513a81 --- /dev/null +++ b/auth.go @@ -0,0 +1,99 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ws + +import ( + "net/http" + "strings" +) + +// AuthResult holds the outcome of an authentication attempt. +type AuthResult struct { + // Valid indicates whether authentication succeeded. + Valid bool + + // UserID is the authenticated user's identifier. + UserID string + + // Claims holds arbitrary metadata from the authentication source + // (e.g. roles, scopes, tenant ID). + Claims map[string]any + + // Error holds the reason for authentication failure, if any. + Error error +} + +// Authenticator validates an HTTP request during the WebSocket upgrade +// handshake. Implementations may inspect headers, query parameters, +// cookies, or any other request attribute. +type Authenticator interface { + Authenticate(r *http.Request) AuthResult +} + +// AuthenticatorFunc is an adapter that allows ordinary functions to be +// used as Authenticators. If f is a function with the appropriate +// signature, AuthenticatorFunc(f) is an Authenticator that calls f. +type AuthenticatorFunc func(r *http.Request) AuthResult + +// Authenticate calls f(r). +func (f AuthenticatorFunc) Authenticate(r *http.Request) AuthResult { + return f(r) +} + +// APIKeyAuthenticator validates requests against a static map of API +// keys. It expects the key in the Authorization header as a Bearer +// token: `Authorization: Bearer `. Each key maps to a user ID. +type APIKeyAuthenticator struct { + // Keys maps API key values to user IDs. + Keys map[string]string +} + +// NewAPIKeyAuth creates an APIKeyAuthenticator from the given key→userID +// mapping. The returned authenticator validates `Authorization: Bearer ` +// headers against the provided keys. +func NewAPIKeyAuth(keys map[string]string) *APIKeyAuthenticator { + return &APIKeyAuthenticator{Keys: keys} +} + +// Authenticate checks the Authorization header for a valid Bearer token. +func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult { + header := r.Header.Get("Authorization") + if header == "" { + return AuthResult{ + Valid: false, + Error: ErrMissingAuthHeader, + } + } + + parts := strings.SplitN(header, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return AuthResult{ + Valid: false, + Error: ErrMalformedAuthHeader, + } + } + + token := strings.TrimSpace(parts[1]) + if token == "" { + return AuthResult{ + Valid: false, + Error: ErrMalformedAuthHeader, + } + } + + userID, ok := a.Keys[token] + if !ok { + return AuthResult{ + Valid: false, + Error: ErrInvalidAPIKey, + } + } + + return AuthResult{ + Valid: true, + UserID: userID, + Claims: map[string]any{ + "auth_method": "api_key", + }, + } +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..d2a460d --- /dev/null +++ b/auth_test.go @@ -0,0 +1,449 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ws + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Unit tests — APIKeyAuthenticator +// --------------------------------------------------------------------------- + +func TestAPIKeyAuthenticator_ValidKey(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + "key-def": "user-2", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-abc") + + result := auth.Authenticate(r) + + assert.True(t, result.Valid) + assert.Equal(t, "user-1", result.UserID) + assert.Equal(t, "api_key", result.Claims["auth_method"]) + assert.NoError(t, result.Error) +} + +func TestAPIKeyAuthenticator_InvalidKey(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer wrong-key") + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.Empty(t, result.UserID) + assert.True(t, errors.Is(result.Error, ErrInvalidAPIKey)) +} + +func TestAPIKeyAuthenticator_MissingHeader(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + // No Authorization header set + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.True(t, errors.Is(result.Error, ErrMissingAuthHeader)) +} + +func TestAPIKeyAuthenticator_MalformedHeader(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + }) + + tests := []struct { + name string + header string + }{ + {"bearer without token", "Bearer "}, + {"bearer only", "Bearer"}, + {"wrong scheme", "Basic key-abc"}, + {"no scheme", "key-abc"}, + {"empty bearer with spaces", "Bearer "}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", tt.header) + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.True(t, errors.Is(result.Error, ErrMalformedAuthHeader)) + }) + } +} + +func TestAPIKeyAuthenticator_CaseInsensitiveScheme(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "bearer key-abc") + + result := auth.Authenticate(r) + + assert.True(t, result.Valid) + assert.Equal(t, "user-1", result.UserID) +} + +func TestAPIKeyAuthenticator_SecondKey(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-abc": "user-1", + "key-def": "user-2", + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer key-def") + + result := auth.Authenticate(r) + + assert.True(t, result.Valid) + assert.Equal(t, "user-2", result.UserID) +} + +// --------------------------------------------------------------------------- +// Unit tests — AuthenticatorFunc adapter +// --------------------------------------------------------------------------- + +func TestAuthenticatorFunc_Adapter(t *testing.T) { + called := false + fn := AuthenticatorFunc(func(r *http.Request) AuthResult { + called = true + return AuthResult{Valid: true, UserID: "func-user"} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + result := fn.Authenticate(r) + + assert.True(t, called) + assert.True(t, result.Valid) + assert.Equal(t, "func-user", result.UserID) +} + +func TestAuthenticatorFunc_Rejection(t *testing.T) { + fn := AuthenticatorFunc(func(r *http.Request) AuthResult { + return AuthResult{Valid: false, Error: errors.New("custom rejection")} + }) + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + result := fn.Authenticate(r) + + assert.False(t, result.Valid) + assert.EqualError(t, result.Error, "custom rejection") +} + +// --------------------------------------------------------------------------- +// Unit tests — nil Authenticator (backward compat) +// --------------------------------------------------------------------------- + +func TestNilAuthenticator_AllConnectionsAccepted(t *testing.T) { + hub := NewHub() // No authenticator set + assert.Nil(t, hub.config.Authenticator) +} + +// --------------------------------------------------------------------------- +// Integration tests — httptest + gorilla/websocket Dial +// --------------------------------------------------------------------------- + +// helper: start a hub with the given config, return server + cleanup +func startAuthTestHub(t *testing.T, config HubConfig) (*httptest.Server, *Hub, context.CancelFunc) { + t.Helper() + hub := NewHubWithConfig(config) + ctx, cancel := context.WithCancel(context.Background()) + go hub.Run(ctx) + + server := httptest.NewServer(hub.Handler()) + t.Cleanup(func() { + server.Close() + cancel() + }) + return server, hub, cancel +} + +func authWSURL(server *httptest.Server) string { + return "ws" + strings.TrimPrefix(server.URL, "http") +} + +func TestIntegration_AuthenticatedConnect(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "valid-key": "user-42", + }) + + var connectedClient *Client + var mu sync.Mutex + + server, _, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + OnConnect: func(client *Client) { + mu.Lock() + connectedClient = client + mu.Unlock() + }, + }) + + header := http.Header{} + header.Set("Authorization", "Bearer valid-key") + + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) + require.NoError(t, err) + defer conn.Close() + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + // Give the hub a moment to process registration + time.Sleep(50 * time.Millisecond) + + mu.Lock() + client := connectedClient + mu.Unlock() + + require.NotNil(t, client, "OnConnect should have fired") + assert.Equal(t, "user-42", client.UserID) + assert.Equal(t, "api_key", client.Claims["auth_method"]) +} + +func TestIntegration_RejectedConnect_InvalidKey(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "valid-key": "user-42", + }) + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + }) + + header := http.Header{} + header.Set("Authorization", "Bearer wrong-key") + + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) + if conn != nil { + conn.Close() + } + + require.Error(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, 0, hub.ClientCount()) +} + +func TestIntegration_RejectedConnect_NoAuthHeader(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "valid-key": "user-42", + }) + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + }) + + // No Authorization header + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) + if conn != nil { + conn.Close() + } + + require.Error(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, 0, hub.ClientCount()) +} + +func TestIntegration_NilAuthenticator_BackwardCompat(t *testing.T) { + // No authenticator — all connections should be accepted + server, hub, _ := startAuthTestHub(t, HubConfig{}) + + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) + require.NoError(t, err) + defer conn.Close() + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) +} + +func TestIntegration_OnAuthFailure_Callback(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "valid-key": "user-42", + }) + + var failureMu sync.Mutex + var failureResult AuthResult + var failureRequest *http.Request + failureCalled := false + + server, _, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + OnAuthFailure: func(r *http.Request, result AuthResult) { + failureMu.Lock() + failureCalled = true + failureResult = result + failureRequest = r + failureMu.Unlock() + }, + }) + + header := http.Header{} + header.Set("Authorization", "Bearer bad-key") + + conn, _, _ := websocket.DefaultDialer.Dial(authWSURL(server), header) + if conn != nil { + conn.Close() + } + + // Give callback time to execute + time.Sleep(50 * time.Millisecond) + + failureMu.Lock() + defer failureMu.Unlock() + + assert.True(t, failureCalled, "OnAuthFailure should have been called") + assert.False(t, failureResult.Valid) + assert.True(t, errors.Is(failureResult.Error, ErrInvalidAPIKey)) + assert.NotNil(t, failureRequest) +} + +func TestIntegration_MultipleClients_DifferentKeys(t *testing.T) { + auth := NewAPIKeyAuth(map[string]string{ + "key-alpha": "user-alpha", + "key-beta": "user-beta", + "key-gamma": "user-gamma", + }) + + var mu sync.Mutex + connectedClients := make(map[string]*Client) + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + OnConnect: func(client *Client) { + mu.Lock() + connectedClients[client.UserID] = client + mu.Unlock() + }, + }) + + keys := []struct { + key string + userID string + }{ + {"key-alpha", "user-alpha"}, + {"key-beta", "user-beta"}, + {"key-gamma", "user-gamma"}, + } + + conns := make([]*websocket.Conn, 0, len(keys)) + for _, k := range keys { + header := http.Header{} + header.Set("Authorization", "Bearer "+k.key) + + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) + require.NoError(t, err, "key %s should connect", k.key) + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + conns = append(conns, conn) + } + defer func() { + for _, c := range conns { + c.Close() + } + }() + + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, 3, hub.ClientCount()) + + mu.Lock() + defer mu.Unlock() + for _, k := range keys { + client, ok := connectedClients[k.userID] + require.True(t, ok, "should have client for %s", k.userID) + assert.Equal(t, k.userID, client.UserID) + } +} + +func TestIntegration_AuthenticatorFunc_WithHub(t *testing.T) { + // Use AuthenticatorFunc as the hub's authenticator + fn := AuthenticatorFunc(func(r *http.Request) AuthResult { + token := r.URL.Query().Get("token") + if token == "magic" { + return AuthResult{ + Valid: true, + UserID: "magic-user", + Claims: map[string]any{"source": "query_param"}, + } + } + return AuthResult{Valid: false, Error: errors.New("bad token")} + }) + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: fn, + }) + + // Valid token via query parameter + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=magic", nil) + require.NoError(t, err) + defer conn.Close() + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) + + // Invalid token + conn2, resp2, _ := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=wrong", nil) + if conn2 != nil { + conn2.Close() + } + assert.Equal(t, http.StatusUnauthorized, resp2.StatusCode) +} + +func TestIntegration_AuthenticatedClient_ReceivesMessages(t *testing.T) { + // Verify that an authenticated client can still receive messages normally + auth := NewAPIKeyAuth(map[string]string{ + "key-1": "user-1", + }) + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + }) + + header := http.Header{} + header.Set("Authorization", "Bearer key-1") + + conn, _, err := websocket.DefaultDialer.Dial(authWSURL(server), header) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Broadcast a message + err = hub.Broadcast(Message{Type: TypeEvent, Data: "hello"}) + require.NoError(t, err) + + // Read it + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, data, err := conn.ReadMessage() + require.NoError(t, err) + + var msg Message + require.NoError(t, json.Unmarshal(data, &msg)) + assert.Equal(t, TypeEvent, msg.Type) + assert.Equal(t, "hello", msg.Data) +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..654435e --- /dev/null +++ b/errors.go @@ -0,0 +1,19 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ws + +import "errors" + +// Authentication errors returned by the built-in APIKeyAuthenticator. +var ( + // ErrMissingAuthHeader is returned when no Authorization header is present. + ErrMissingAuthHeader = errors.New("missing Authorization header") + + // ErrMalformedAuthHeader is returned when the Authorization header is + // not in the expected "Bearer " format. + ErrMalformedAuthHeader = errors.New("malformed Authorization header") + + // ErrInvalidAPIKey is returned when the provided API key does not + // match any known key. + ErrInvalidAPIKey = errors.New("invalid API key") +) diff --git a/ws.go b/ws.go index c8a3a2e..6cb98cb 100644 --- a/ws.go +++ b/ws.go @@ -14,6 +14,18 @@ // // Register HTTP handler // http.HandleFunc("/ws", hub.Handler()) // +// # Authentication +// +// The hub supports optional token-based authentication on upgrade. Supply an +// Authenticator via HubConfig to gate connections: +// +// auth := ws.NewAPIKeyAuth(map[string]string{"secret-key": "user-1"}) +// hub := ws.NewHubWithConfig(ws.HubConfig{Authenticator: auth}) +// go hub.Run(ctx) +// +// When no Authenticator is set (nil), all connections are accepted — preserving +// backward compatibility. +// // # Message Types // // The package defines several message types for different purposes: @@ -104,6 +116,16 @@ type HubConfig struct { // OnDisconnect is called when a client disconnects from the hub. OnDisconnect func(client *Client) + + // Authenticator validates incoming WebSocket connections during the + // HTTP upgrade handshake. When nil, all connections are accepted + // (backward compatible). When set, connections that fail authentication + // receive an HTTP 401 response and are not upgraded. + Authenticator Authenticator + + // OnAuthFailure is called when a connection is rejected by the + // Authenticator. Useful for logging or metrics. Optional. + OnAuthFailure func(r *http.Request, result AuthResult) } // DefaultHubConfig returns a HubConfig with sensible defaults. @@ -153,6 +175,15 @@ type Client struct { send chan []byte subscriptions map[string]bool mu sync.RWMutex + + // UserID is the authenticated user's identifier, set during the + // upgrade handshake when an Authenticator is configured. Empty + // when no Authenticator is set. + UserID string + + // Claims holds arbitrary authentication metadata (e.g. roles, + // scopes). Nil when no Authenticator is set. + Claims map[string]any } // Hub manages WebSocket connections and message broadcasting. @@ -422,6 +453,19 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { // Handler returns an HTTP handler for WebSocket connections. func (h *Hub) Handler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + // Authenticate if an Authenticator is configured. + var authResult AuthResult + if h.config.Authenticator != nil { + authResult = h.config.Authenticator.Authenticate(r) + if !authResult.Valid { + if h.config.OnAuthFailure != nil { + h.config.OnAuthFailure(r, authResult) + } + http.Error(w, "Unauthorised", http.StatusUnauthorized) + return + } + } + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return @@ -434,6 +478,12 @@ func (h *Hub) Handler() http.HandlerFunc { subscriptions: make(map[string]bool), } + // Populate auth fields when authentication succeeded. + if h.config.Authenticator != nil { + client.UserID = authResult.UserID + client.Claims = authResult.Claims + } + h.register <- client go client.writePump()