feat(auth): Phase 2 — token-based authentication on WebSocket upgrade

Add Authenticator interface with AuthenticatorFunc adapter and built-in
APIKeyAuthenticator for Bearer token validation. Hub.Handler() now gates
connections when an Authenticator is configured on HubConfig, responding
HTTP 401 for failed auth. Client.UserID and Client.Claims are populated
on successful upgrade. OnAuthFailure callback enables logging/metrics.

Nil authenticator preserves full backward compatibility — all existing
tests pass unchanged. 18 new tests (unit + integration) cover valid/
invalid/missing/malformed headers, func adapter, multi-client auth,
message delivery post-auth, and the OnAuthFailure callback.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-20 08:12:45 +00:00
parent 534bbe571f
commit 9e48f0b60d
6 changed files with 631 additions and 9 deletions

View file

@ -17,6 +17,11 @@ go test -v -run Name # Run single test
- Messages types: `process_output`, `process_status`, `event`, `error`, `ping/pong`, `subscribe/unsubscribe`
- `hub.SendProcessOutput(id, line)` broadcasts to subscribers
- `HubConfig` provides configurable heartbeat, pong timeout, write timeout, and connection callbacks
- `Authenticator` interface for token-based auth on upgrade — nil means all connections accepted (backward compat)
- `APIKeyAuthenticator` validates `Authorization: Bearer <key>` against a static key→userID map
- `AuthenticatorFunc` adapter lets plain functions satisfy the `Authenticator` interface
- `Client.UserID` and `Client.Claims` populated during authenticated upgrade
- `OnAuthFailure` callback on `HubConfig` for logging/metrics on rejected connections
- `ReconnectingClient` provides client-side reconnection with exponential backoff
- `ConnectionState`: `StateDisconnected`, `StateConnecting`, `StateConnected`
- Coverage: 98.5%

18
TODO.md
View file

@ -18,13 +18,13 @@ Dispatched from core/go orchestration. Pick up tasks in order.
- [x] Tune heartbeat interval and pong timeout for flaky networks — `HubConfig` with `HeartbeatInterval`, `PongTimeout`, `WriteTimeout`. `NewHubWithConfig()` constructor. Defaults: 30s heartbeat, 60s pong timeout, 10s write timeout.
- [x] Add connection state callbacks (onConnect, onDisconnect, onReconnect) — Hub-level `OnConnect`/`OnDisconnect` callbacks in `HubConfig`. Client-level `OnConnect`/`OnDisconnect`/`OnReconnect` callbacks in `ReconnectConfig`. `ConnectionState` enum: `StateDisconnected`, `StateConnecting`, `StateConnected`.
## Phase 2: Auth
## Phase 2: Auth (complete)
Token-based authentication on WebSocket upgrade handshake. Pure Go, no JWT library dependency — consumers bring their own validation logic via an interface.
### 2.1 Authenticator Interface
- [ ] **Create `auth.go`** — Define the auth abstraction:
- [x] **Create `auth.go`** — Define the auth abstraction:
- `type AuthResult struct { Valid bool; UserID string; Claims map[string]any; Error error }` — result of authentication
- `type Authenticator interface { Authenticate(r *http.Request) AuthResult }` — validates the HTTP request during upgrade. Implementations can check headers (`Authorization: Bearer <token>`), query params (`?token=xxx`), or cookies.
- `type AuthenticatorFunc func(r *http.Request) AuthResult` — adapter for using functions as Authenticators (implements the interface)
@ -33,22 +33,22 @@ Token-based authentication on WebSocket upgrade handshake. Pure Go, no JWT libra
### 2.2 Wire Into Hub
- [ ] **Add `Authenticator` to `HubConfig`** — Optional field. When nil, all connections are accepted (backward compatible). When set, `Handler()` calls `Authenticate(r)` before upgrading.
- [ ] **Update `Handler()`** — If `h.config.Authenticator != nil`, call `Authenticate(r)`. If `!result.Valid`, respond with `http.StatusUnauthorized` (or `http.StatusForbidden` if `result.Error` indicates a different status) and return without upgrading. If valid, store `result.UserID` and `result.Claims` on the `Client` struct.
- [ ] **Add auth fields to `Client`**`UserID string` and `Claims map[string]any` fields. Set during authenticated upgrade. Empty for unauthenticated hubs (nil authenticator).
- [ ] **Expose `OnAuthFailure` callback** — Optional `OnAuthFailure func(r *http.Request, result AuthResult)` on `HubConfig` for logging/metrics on rejected connections.
- [x] **Add `Authenticator` to `HubConfig`** — Optional field. When nil, all connections are accepted (backward compatible). When set, `Handler()` calls `Authenticate(r)` before upgrading.
- [x] **Update `Handler()`** — If `h.config.Authenticator != nil`, call `Authenticate(r)`. If `!result.Valid`, respond with `http.StatusUnauthorized` (or `http.StatusForbidden` if `result.Error` indicates a different status) and return without upgrading. If valid, store `result.UserID` and `result.Claims` on the `Client` struct.
- [x] **Add auth fields to `Client`**`UserID string` and `Claims map[string]any` fields. Set during authenticated upgrade. Empty for unauthenticated hubs (nil authenticator).
- [x] **Expose `OnAuthFailure` callback** — Optional `OnAuthFailure func(r *http.Request, result AuthResult)` on `HubConfig` for logging/metrics on rejected connections.
### 2.3 Tests
- [ ] **Unit tests** — (a) APIKeyAuthenticator valid key, (b) invalid key, (c) missing header, (d) malformed header ("Bearer" without token, wrong scheme), (e) AuthenticatorFunc adapter, (f) nil Authenticator (backward compat — all connections accepted)
- [ ] **Integration tests** — Using httptest + gorilla/websocket Dial:
- [x] **Unit tests** — (a) APIKeyAuthenticator valid key, (b) invalid key, (c) missing header, (d) malformed header ("Bearer" without token, wrong scheme), (e) AuthenticatorFunc adapter, (f) nil Authenticator (backward compat — all connections accepted)
- [x] **Integration tests** — Using httptest + gorilla/websocket Dial:
- (a) Authenticated connect with valid API key → upgrade succeeds, client.UserID set
- (b) Rejected connect with invalid key → HTTP 401, no WebSocket upgrade
- (c) Rejected connect with no auth header → HTTP 401
- (d) Nil authenticator → all connections accepted (existing behaviour preserved)
- (e) OnAuthFailure callback fires on rejection
- (f) Multiple clients with different API keys → each gets correct UserID
- [ ] **Existing tests still pass** — No authenticator set = backward compatible
- [x] **Existing tests still pass** — No authenticator set = backward compatible
## Phase 3: Scaling

99
auth.go Normal file
View file

@ -0,0 +1,99 @@
// SPDX-Licence-Identifier: EUPL-1.2
package ws
import (
"net/http"
"strings"
)
// AuthResult holds the outcome of an authentication attempt.
type AuthResult struct {
// Valid indicates whether authentication succeeded.
Valid bool
// UserID is the authenticated user's identifier.
UserID string
// Claims holds arbitrary metadata from the authentication source
// (e.g. roles, scopes, tenant ID).
Claims map[string]any
// Error holds the reason for authentication failure, if any.
Error error
}
// Authenticator validates an HTTP request during the WebSocket upgrade
// handshake. Implementations may inspect headers, query parameters,
// cookies, or any other request attribute.
type Authenticator interface {
Authenticate(r *http.Request) AuthResult
}
// AuthenticatorFunc is an adapter that allows ordinary functions to be
// used as Authenticators. If f is a function with the appropriate
// signature, AuthenticatorFunc(f) is an Authenticator that calls f.
type AuthenticatorFunc func(r *http.Request) AuthResult
// Authenticate calls f(r).
func (f AuthenticatorFunc) Authenticate(r *http.Request) AuthResult {
return f(r)
}
// APIKeyAuthenticator validates requests against a static map of API
// keys. It expects the key in the Authorization header as a Bearer
// token: `Authorization: Bearer <key>`. Each key maps to a user ID.
type APIKeyAuthenticator struct {
// Keys maps API key values to user IDs.
Keys map[string]string
}
// NewAPIKeyAuth creates an APIKeyAuthenticator from the given key→userID
// mapping. The returned authenticator validates `Authorization: Bearer <key>`
// headers against the provided keys.
func NewAPIKeyAuth(keys map[string]string) *APIKeyAuthenticator {
return &APIKeyAuthenticator{Keys: keys}
}
// Authenticate checks the Authorization header for a valid Bearer token.
func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult {
header := r.Header.Get("Authorization")
if header == "" {
return AuthResult{
Valid: false,
Error: ErrMissingAuthHeader,
}
}
parts := strings.SplitN(header, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return AuthResult{
Valid: false,
Error: ErrMalformedAuthHeader,
}
}
token := strings.TrimSpace(parts[1])
if token == "" {
return AuthResult{
Valid: false,
Error: ErrMalformedAuthHeader,
}
}
userID, ok := a.Keys[token]
if !ok {
return AuthResult{
Valid: false,
Error: ErrInvalidAPIKey,
}
}
return AuthResult{
Valid: true,
UserID: userID,
Claims: map[string]any{
"auth_method": "api_key",
},
}
}

449
auth_test.go Normal file
View file

@ -0,0 +1,449 @@
// SPDX-Licence-Identifier: EUPL-1.2
package ws
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Unit tests — APIKeyAuthenticator
// ---------------------------------------------------------------------------
func TestAPIKeyAuthenticator_ValidKey(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"key-abc": "user-1",
"key-def": "user-2",
})
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", "Bearer key-abc")
result := auth.Authenticate(r)
assert.True(t, result.Valid)
assert.Equal(t, "user-1", result.UserID)
assert.Equal(t, "api_key", result.Claims["auth_method"])
assert.NoError(t, result.Error)
}
func TestAPIKeyAuthenticator_InvalidKey(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"key-abc": "user-1",
})
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", "Bearer wrong-key")
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.Empty(t, result.UserID)
assert.True(t, errors.Is(result.Error, ErrInvalidAPIKey))
}
func TestAPIKeyAuthenticator_MissingHeader(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"key-abc": "user-1",
})
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
// No Authorization header set
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.True(t, errors.Is(result.Error, ErrMissingAuthHeader))
}
func TestAPIKeyAuthenticator_MalformedHeader(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"key-abc": "user-1",
})
tests := []struct {
name string
header string
}{
{"bearer without token", "Bearer "},
{"bearer only", "Bearer"},
{"wrong scheme", "Basic key-abc"},
{"no scheme", "key-abc"},
{"empty bearer with spaces", "Bearer "},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", tt.header)
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.True(t, errors.Is(result.Error, ErrMalformedAuthHeader))
})
}
}
func TestAPIKeyAuthenticator_CaseInsensitiveScheme(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"key-abc": "user-1",
})
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", "bearer key-abc")
result := auth.Authenticate(r)
assert.True(t, result.Valid)
assert.Equal(t, "user-1", result.UserID)
}
func TestAPIKeyAuthenticator_SecondKey(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"key-abc": "user-1",
"key-def": "user-2",
})
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", "Bearer key-def")
result := auth.Authenticate(r)
assert.True(t, result.Valid)
assert.Equal(t, "user-2", result.UserID)
}
// ---------------------------------------------------------------------------
// Unit tests — AuthenticatorFunc adapter
// ---------------------------------------------------------------------------
func TestAuthenticatorFunc_Adapter(t *testing.T) {
called := false
fn := AuthenticatorFunc(func(r *http.Request) AuthResult {
called = true
return AuthResult{Valid: true, UserID: "func-user"}
})
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
result := fn.Authenticate(r)
assert.True(t, called)
assert.True(t, result.Valid)
assert.Equal(t, "func-user", result.UserID)
}
func TestAuthenticatorFunc_Rejection(t *testing.T) {
fn := AuthenticatorFunc(func(r *http.Request) AuthResult {
return AuthResult{Valid: false, Error: errors.New("custom rejection")}
})
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
result := fn.Authenticate(r)
assert.False(t, result.Valid)
assert.EqualError(t, result.Error, "custom rejection")
}
// ---------------------------------------------------------------------------
// Unit tests — nil Authenticator (backward compat)
// ---------------------------------------------------------------------------
func TestNilAuthenticator_AllConnectionsAccepted(t *testing.T) {
hub := NewHub() // No authenticator set
assert.Nil(t, hub.config.Authenticator)
}
// ---------------------------------------------------------------------------
// Integration tests — httptest + gorilla/websocket Dial
// ---------------------------------------------------------------------------
// helper: start a hub with the given config, return server + cleanup
func startAuthTestHub(t *testing.T, config HubConfig) (*httptest.Server, *Hub, context.CancelFunc) {
t.Helper()
hub := NewHubWithConfig(config)
ctx, cancel := context.WithCancel(context.Background())
go hub.Run(ctx)
server := httptest.NewServer(hub.Handler())
t.Cleanup(func() {
server.Close()
cancel()
})
return server, hub, cancel
}
func authWSURL(server *httptest.Server) string {
return "ws" + strings.TrimPrefix(server.URL, "http")
}
func TestIntegration_AuthenticatedConnect(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"valid-key": "user-42",
})
var connectedClient *Client
var mu sync.Mutex
server, _, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
OnConnect: func(client *Client) {
mu.Lock()
connectedClient = client
mu.Unlock()
},
})
header := http.Header{}
header.Set("Authorization", "Bearer valid-key")
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
require.NoError(t, err)
defer conn.Close()
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
// Give the hub a moment to process registration
time.Sleep(50 * time.Millisecond)
mu.Lock()
client := connectedClient
mu.Unlock()
require.NotNil(t, client, "OnConnect should have fired")
assert.Equal(t, "user-42", client.UserID)
assert.Equal(t, "api_key", client.Claims["auth_method"])
}
func TestIntegration_RejectedConnect_InvalidKey(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"valid-key": "user-42",
})
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
})
header := http.Header{}
header.Set("Authorization", "Bearer wrong-key")
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
if conn != nil {
conn.Close()
}
require.Error(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, 0, hub.ClientCount())
}
func TestIntegration_RejectedConnect_NoAuthHeader(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"valid-key": "user-42",
})
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
})
// No Authorization header
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil)
if conn != nil {
conn.Close()
}
require.Error(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, 0, hub.ClientCount())
}
func TestIntegration_NilAuthenticator_BackwardCompat(t *testing.T) {
// No authenticator — all connections should be accepted
server, hub, _ := startAuthTestHub(t, HubConfig{})
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil)
require.NoError(t, err)
defer conn.Close()
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, hub.ClientCount())
}
func TestIntegration_OnAuthFailure_Callback(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"valid-key": "user-42",
})
var failureMu sync.Mutex
var failureResult AuthResult
var failureRequest *http.Request
failureCalled := false
server, _, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
OnAuthFailure: func(r *http.Request, result AuthResult) {
failureMu.Lock()
failureCalled = true
failureResult = result
failureRequest = r
failureMu.Unlock()
},
})
header := http.Header{}
header.Set("Authorization", "Bearer bad-key")
conn, _, _ := websocket.DefaultDialer.Dial(authWSURL(server), header)
if conn != nil {
conn.Close()
}
// Give callback time to execute
time.Sleep(50 * time.Millisecond)
failureMu.Lock()
defer failureMu.Unlock()
assert.True(t, failureCalled, "OnAuthFailure should have been called")
assert.False(t, failureResult.Valid)
assert.True(t, errors.Is(failureResult.Error, ErrInvalidAPIKey))
assert.NotNil(t, failureRequest)
}
func TestIntegration_MultipleClients_DifferentKeys(t *testing.T) {
auth := NewAPIKeyAuth(map[string]string{
"key-alpha": "user-alpha",
"key-beta": "user-beta",
"key-gamma": "user-gamma",
})
var mu sync.Mutex
connectedClients := make(map[string]*Client)
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
OnConnect: func(client *Client) {
mu.Lock()
connectedClients[client.UserID] = client
mu.Unlock()
},
})
keys := []struct {
key string
userID string
}{
{"key-alpha", "user-alpha"},
{"key-beta", "user-beta"},
{"key-gamma", "user-gamma"},
}
conns := make([]*websocket.Conn, 0, len(keys))
for _, k := range keys {
header := http.Header{}
header.Set("Authorization", "Bearer "+k.key)
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
require.NoError(t, err, "key %s should connect", k.key)
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
conns = append(conns, conn)
}
defer func() {
for _, c := range conns {
c.Close()
}
}()
time.Sleep(100 * time.Millisecond)
assert.Equal(t, 3, hub.ClientCount())
mu.Lock()
defer mu.Unlock()
for _, k := range keys {
client, ok := connectedClients[k.userID]
require.True(t, ok, "should have client for %s", k.userID)
assert.Equal(t, k.userID, client.UserID)
}
}
func TestIntegration_AuthenticatorFunc_WithHub(t *testing.T) {
// Use AuthenticatorFunc as the hub's authenticator
fn := AuthenticatorFunc(func(r *http.Request) AuthResult {
token := r.URL.Query().Get("token")
if token == "magic" {
return AuthResult{
Valid: true,
UserID: "magic-user",
Claims: map[string]any{"source": "query_param"},
}
}
return AuthResult{Valid: false, Error: errors.New("bad token")}
})
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: fn,
})
// Valid token via query parameter
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=magic", nil)
require.NoError(t, err)
defer conn.Close()
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, hub.ClientCount())
// Invalid token
conn2, resp2, _ := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=wrong", nil)
if conn2 != nil {
conn2.Close()
}
assert.Equal(t, http.StatusUnauthorized, resp2.StatusCode)
}
func TestIntegration_AuthenticatedClient_ReceivesMessages(t *testing.T) {
// Verify that an authenticated client can still receive messages normally
auth := NewAPIKeyAuth(map[string]string{
"key-1": "user-1",
})
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
})
header := http.Header{}
header.Set("Authorization", "Bearer key-1")
conn, _, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
require.NoError(t, err)
defer conn.Close()
time.Sleep(50 * time.Millisecond)
// Broadcast a message
err = hub.Broadcast(Message{Type: TypeEvent, Data: "hello"})
require.NoError(t, err)
// Read it
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, data, err := conn.ReadMessage()
require.NoError(t, err)
var msg Message
require.NoError(t, json.Unmarshal(data, &msg))
assert.Equal(t, TypeEvent, msg.Type)
assert.Equal(t, "hello", msg.Data)
}

19
errors.go Normal file
View file

@ -0,0 +1,19 @@
// SPDX-Licence-Identifier: EUPL-1.2
package ws
import "errors"
// Authentication errors returned by the built-in APIKeyAuthenticator.
var (
// ErrMissingAuthHeader is returned when no Authorization header is present.
ErrMissingAuthHeader = errors.New("missing Authorization header")
// ErrMalformedAuthHeader is returned when the Authorization header is
// not in the expected "Bearer <token>" format.
ErrMalformedAuthHeader = errors.New("malformed Authorization header")
// ErrInvalidAPIKey is returned when the provided API key does not
// match any known key.
ErrInvalidAPIKey = errors.New("invalid API key")
)

50
ws.go
View file

@ -14,6 +14,18 @@
// // Register HTTP handler
// http.HandleFunc("/ws", hub.Handler())
//
// # Authentication
//
// The hub supports optional token-based authentication on upgrade. Supply an
// Authenticator via HubConfig to gate connections:
//
// auth := ws.NewAPIKeyAuth(map[string]string{"secret-key": "user-1"})
// hub := ws.NewHubWithConfig(ws.HubConfig{Authenticator: auth})
// go hub.Run(ctx)
//
// When no Authenticator is set (nil), all connections are accepted — preserving
// backward compatibility.
//
// # Message Types
//
// The package defines several message types for different purposes:
@ -104,6 +116,16 @@ type HubConfig struct {
// OnDisconnect is called when a client disconnects from the hub.
OnDisconnect func(client *Client)
// Authenticator validates incoming WebSocket connections during the
// HTTP upgrade handshake. When nil, all connections are accepted
// (backward compatible). When set, connections that fail authentication
// receive an HTTP 401 response and are not upgraded.
Authenticator Authenticator
// OnAuthFailure is called when a connection is rejected by the
// Authenticator. Useful for logging or metrics. Optional.
OnAuthFailure func(r *http.Request, result AuthResult)
}
// DefaultHubConfig returns a HubConfig with sensible defaults.
@ -153,6 +175,15 @@ type Client struct {
send chan []byte
subscriptions map[string]bool
mu sync.RWMutex
// UserID is the authenticated user's identifier, set during the
// upgrade handshake when an Authenticator is configured. Empty
// when no Authenticator is set.
UserID string
// Claims holds arbitrary authentication metadata (e.g. roles,
// scopes). Nil when no Authenticator is set.
Claims map[string]any
}
// Hub manages WebSocket connections and message broadcasting.
@ -422,6 +453,19 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
// Handler returns an HTTP handler for WebSocket connections.
func (h *Hub) Handler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Authenticate if an Authenticator is configured.
var authResult AuthResult
if h.config.Authenticator != nil {
authResult = h.config.Authenticator.Authenticate(r)
if !authResult.Valid {
if h.config.OnAuthFailure != nil {
h.config.OnAuthFailure(r, authResult)
}
http.Error(w, "Unauthorised", http.StatusUnauthorized)
return
}
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
@ -434,6 +478,12 @@ func (h *Hub) Handler() http.HandlerFunc {
subscriptions: make(map[string]bool),
}
// Populate auth fields when authentication succeeded.
if h.config.Authenticator != nil {
client.UserID = authResult.UserID
client.Claims = authResult.Claims
}
h.register <- client
go client.writePump()