feat(ws): add authentication on WebSocket upgrade

Phase 2: Authenticator interface, BearerTokenAuth, QueryTokenAuth.
Reject unauthenticated connections before upgrade.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-20 14:34:57 +00:00
parent da3df0077d
commit ce342a1866
No known key found for this signature in database
GPG key ID: AF404715446AEB41
4 changed files with 482 additions and 1 deletions

View file

@ -19,12 +19,14 @@ go test -v -run Name # Run single test
- `HubConfig` provides configurable heartbeat, pong timeout, write timeout, and connection callbacks
- `Authenticator` interface for token-based auth on upgrade — nil means all connections accepted (backward compat)
- `APIKeyAuthenticator` validates `Authorization: Bearer <key>` against a static key→userID map
- `BearerTokenAuth` validates `Authorization: Bearer <token>` via caller-supplied function (JWT, introspection, etc.)
- `QueryTokenAuth` validates `?token=<value>` query parameter (for browser WebSocket clients)
- `AuthenticatorFunc` adapter lets plain functions satisfy the `Authenticator` interface
- `Client.UserID` and `Client.Claims` populated during authenticated upgrade
- `OnAuthFailure` callback on `HubConfig` for logging/metrics on rejected connections
- `ReconnectingClient` provides client-side reconnection with exponential backoff
- `ConnectionState`: `StateDisconnected`, `StateConnecting`, `StateConnected`
- Coverage: 98.5%
- Coverage: 95.9%
## Coding Standards

View file

@ -58,3 +58,32 @@ Extracted from `forge.lthn.ai/core/go` `pkg/ws/` on 19 Feb 2026.
- `ReconnectConfig`, `ReconnectingClient`, `NewReconnectingClient()`
- `DefaultHeartbeatInterval`, `DefaultPongTimeout`, `DefaultWriteTimeout` constants
- `NewHub()` still works unchanged (uses `DefaultHubConfig()` internally)
## 2026-02-20: Phase 2 auth additions — BearerTokenAuth & QueryTokenAuth (Charon)
### Motivation
The existing `APIKeyAuthenticator` validates against a static key-to-userID map.
Two additional built-in authenticators provide more flexible patterns:
- **`BearerTokenAuth`**: extracts `Authorization: Bearer <token>` and delegates validation
to a caller-supplied function. Suitable for JWT verification, token introspection, or
any custom bearer scheme.
- **`QueryTokenAuth`**: extracts `?token=<value>` from the query string. Useful for
browser WebSocket clients that cannot set custom HTTP headers (the browser's native
`WebSocket` API does not support request headers).
Both use the same `AuthResult` return type and sentinel errors as `APIKeyAuthenticator`.
### Test coverage
- 14 new test functions (Good/Bad naming convention):
- `BearerTokenAuth`: valid token, invalid token, missing header, malformed header (5 cases), case-insensitive scheme.
- `QueryTokenAuth`: valid token, invalid token, missing param, empty param.
- Integration: BearerTokenAuth accepts/rejects with Hub, QueryTokenAuth accepts/rejects/missing with Hub, QueryTokenAuth end-to-end subscribe+receive.
- Coverage: 95.9% (across all files). All tests pass with `-race`.
### API surface additions
- `BearerTokenAuth` struct (implements `Authenticator`)
- `QueryTokenAuth` struct (implements `Authenticator`)

65
auth.go
View file

@ -3,6 +3,7 @@
package ws
import (
"fmt"
"net/http"
"strings"
)
@ -97,3 +98,67 @@ func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult {
},
}
}
// BearerTokenAuth extracts an Authorization: Bearer <token> header and
// validates it using a caller-supplied function. Unlike APIKeyAuthenticator,
// this authenticator delegates validation entirely to the caller, making
// it suitable for JWT verification, token introspection, or any custom
// bearer scheme.
type BearerTokenAuth struct {
// Validate receives the raw bearer token string and should return
// an AuthResult. The caller controls UserID, Claims, and error
// semantics.
Validate func(token string) AuthResult
}
// Authenticate implements the Authenticator interface for bearer tokens.
func (b *BearerTokenAuth) Authenticate(r *http.Request) AuthResult {
header := r.Header.Get("Authorization")
if header == "" {
return AuthResult{
Valid: false,
Error: ErrMissingAuthHeader,
}
}
parts := strings.SplitN(header, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return AuthResult{
Valid: false,
Error: ErrMalformedAuthHeader,
}
}
token := strings.TrimSpace(parts[1])
if token == "" {
return AuthResult{
Valid: false,
Error: ErrMalformedAuthHeader,
}
}
return b.Validate(token)
}
// QueryTokenAuth extracts a token from the ?token= query parameter and
// validates it using a caller-supplied function. This is useful for
// browser clients that cannot set custom headers on WebSocket connections
// (e.g. the browser's native WebSocket API does not support custom headers).
type QueryTokenAuth struct {
// Validate receives the raw token value from the query string and
// should return an AuthResult.
Validate func(token string) AuthResult
}
// Authenticate implements the Authenticator interface for query parameter tokens.
func (q *QueryTokenAuth) Authenticate(r *http.Request) AuthResult {
token := r.URL.Query().Get("token")
if token == "" {
return AuthResult{
Valid: false,
Error: fmt.Errorf("missing token query parameter"),
}
}
return q.Validate(token)
}

View file

@ -447,3 +447,388 @@ func TestIntegration_AuthenticatedClient_ReceivesMessages(t *testing.T) {
assert.Equal(t, TypeEvent, msg.Type)
assert.Equal(t, "hello", msg.Data)
}
// ---------------------------------------------------------------------------
// Unit tests — BearerTokenAuth
// ---------------------------------------------------------------------------
func TestBearerTokenAuth_ValidToken_Good(t *testing.T) {
auth := &BearerTokenAuth{
Validate: func(token string) AuthResult {
if token == "jwt-abc-123" {
return AuthResult{
Valid: true,
UserID: "user-42",
Claims: map[string]any{"role": "admin", "auth_method": "jwt"},
}
}
return AuthResult{Valid: false, Error: errors.New("invalid token")}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", "Bearer jwt-abc-123")
result := auth.Authenticate(r)
assert.True(t, result.Valid)
assert.Equal(t, "user-42", result.UserID)
assert.Equal(t, "admin", result.Claims["role"])
assert.Equal(t, "jwt", result.Claims["auth_method"])
}
func TestBearerTokenAuth_InvalidToken_Bad(t *testing.T) {
auth := &BearerTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: false, Error: errors.New("token expired")}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", "Bearer expired-token")
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.EqualError(t, result.Error, "token expired")
}
func TestBearerTokenAuth_MissingHeader_Bad(t *testing.T) {
auth := &BearerTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: true, UserID: "should-not-reach"}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.True(t, errors.Is(result.Error, ErrMissingAuthHeader))
}
func TestBearerTokenAuth_MalformedHeader_Bad(t *testing.T) {
auth := &BearerTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: true, UserID: "should-not-reach"}
},
}
tests := []struct {
name string
header string
}{
{"wrong scheme", "Basic dXNlcjpwYXNz"},
{"no scheme", "raw-token-value"},
{"empty bearer", "Bearer "},
{"empty bearer with spaces", "Bearer "},
{"bearer only", "Bearer"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", tt.header)
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.True(t, errors.Is(result.Error, ErrMalformedAuthHeader))
})
}
}
func TestBearerTokenAuth_CaseInsensitiveScheme_Good(t *testing.T) {
auth := &BearerTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: true, UserID: "user-1"}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
r.Header.Set("Authorization", "bearer my-token")
result := auth.Authenticate(r)
assert.True(t, result.Valid)
assert.Equal(t, "user-1", result.UserID)
}
// ---------------------------------------------------------------------------
// Integration tests — BearerTokenAuth with Hub
// ---------------------------------------------------------------------------
func TestIntegration_BearerTokenAuth_AcceptsValidToken_Good(t *testing.T) {
auth := &BearerTokenAuth{
Validate: func(token string) AuthResult {
if token == "valid-jwt" {
return AuthResult{
Valid: true,
UserID: "jwt-user",
Claims: map[string]any{"auth_method": "bearer"},
}
}
return AuthResult{Valid: false, Error: errors.New("invalid")}
},
}
var connectedClient *Client
var mu sync.Mutex
server, _, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
OnConnect: func(client *Client) {
mu.Lock()
connectedClient = client
mu.Unlock()
},
})
header := http.Header{}
header.Set("Authorization", "Bearer valid-jwt")
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
require.NoError(t, err)
defer conn.Close()
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
time.Sleep(50 * time.Millisecond)
mu.Lock()
client := connectedClient
mu.Unlock()
require.NotNil(t, client)
assert.Equal(t, "jwt-user", client.UserID)
assert.Equal(t, "bearer", client.Claims["auth_method"])
}
func TestIntegration_BearerTokenAuth_RejectsInvalidToken_Bad(t *testing.T) {
auth := &BearerTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: false, Error: errors.New("invalid")}
},
}
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
})
header := http.Header{}
header.Set("Authorization", "Bearer bad-token")
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
if conn != nil {
conn.Close()
}
require.Error(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, 0, hub.ClientCount())
}
// ---------------------------------------------------------------------------
// Unit tests — QueryTokenAuth
// ---------------------------------------------------------------------------
func TestQueryTokenAuth_ValidToken_Good(t *testing.T) {
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
if token == "browser-token-456" {
return AuthResult{
Valid: true,
UserID: "browser-user",
Claims: map[string]any{"auth_method": "query_param"},
}
}
return AuthResult{Valid: false, Error: errors.New("unknown token")}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws?token=browser-token-456", nil)
result := auth.Authenticate(r)
assert.True(t, result.Valid)
assert.Equal(t, "browser-user", result.UserID)
assert.Equal(t, "query_param", result.Claims["auth_method"])
}
func TestQueryTokenAuth_InvalidToken_Bad(t *testing.T) {
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: false, Error: errors.New("unknown token")}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws?token=bad-token", nil)
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.EqualError(t, result.Error, "unknown token")
}
func TestQueryTokenAuth_MissingParam_Bad(t *testing.T) {
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: true, UserID: "should-not-reach"}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.Contains(t, result.Error.Error(), "missing token query parameter")
}
func TestQueryTokenAuth_EmptyParam_Bad(t *testing.T) {
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: true, UserID: "should-not-reach"}
},
}
r := httptest.NewRequest(http.MethodGet, "/ws?token=", nil)
result := auth.Authenticate(r)
assert.False(t, result.Valid)
assert.Contains(t, result.Error.Error(), "missing token query parameter")
}
// ---------------------------------------------------------------------------
// Integration tests — QueryTokenAuth with Hub
// ---------------------------------------------------------------------------
func TestIntegration_QueryTokenAuth_AcceptsValidToken_Good(t *testing.T) {
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
if token == "browser-secret" {
return AuthResult{
Valid: true,
UserID: "browser-user-99",
Claims: map[string]any{"origin": "browser"},
}
}
return AuthResult{Valid: false, Error: errors.New("invalid")}
},
}
var connectedClient *Client
var mu sync.Mutex
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
OnConnect: func(client *Client) {
mu.Lock()
connectedClient = client
mu.Unlock()
},
})
conn, resp, err := websocket.DefaultDialer.Dial(
authWSURL(server)+"?token=browser-secret", nil)
require.NoError(t, err)
defer conn.Close()
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, hub.ClientCount())
mu.Lock()
client := connectedClient
mu.Unlock()
require.NotNil(t, client)
assert.Equal(t, "browser-user-99", client.UserID)
assert.Equal(t, "browser", client.Claims["origin"])
}
func TestIntegration_QueryTokenAuth_RejectsInvalidToken_Bad(t *testing.T) {
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: false, Error: errors.New("invalid")}
},
}
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
})
conn, resp, err := websocket.DefaultDialer.Dial(
authWSURL(server)+"?token=wrong", nil)
if conn != nil {
conn.Close()
}
require.Error(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, 0, hub.ClientCount())
}
func TestIntegration_QueryTokenAuth_RejectsMissingToken_Bad(t *testing.T) {
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
return AuthResult{Valid: true, UserID: "should-not-reach"}
},
}
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
})
// No ?token= parameter
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil)
if conn != nil {
conn.Close()
}
require.Error(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, 0, hub.ClientCount())
}
func TestIntegration_QueryTokenAuth_EndToEnd_Good(t *testing.T) {
// Authenticated via query param, then subscribe and receive messages
auth := &QueryTokenAuth{
Validate: func(token string) AuthResult {
if token == "good-token" {
return AuthResult{Valid: true, UserID: "alice"}
}
return AuthResult{Valid: false, Error: errors.New("invalid")}
},
}
server, hub, _ := startAuthTestHub(t, HubConfig{
Authenticator: auth,
})
conn, _, err := websocket.DefaultDialer.Dial(
authWSURL(server)+"?token=good-token", nil)
require.NoError(t, err)
defer conn.Close()
time.Sleep(50 * time.Millisecond)
// Subscribe to a channel
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "events"})
require.NoError(t, err)
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, hub.ChannelSubscriberCount("events"))
// Send a message to the channel
err = hub.SendToChannel("events", Message{Type: TypeEvent, Data: "hello alice"})
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(time.Second))
var received Message
err = conn.ReadJSON(&received)
require.NoError(t, err)
assert.Equal(t, TypeEvent, received.Type)
assert.Equal(t, "hello alice", received.Data)
}