Phase 2: Authenticator interface, BearerTokenAuth, QueryTokenAuth. Reject unauthenticated connections before upgrade. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
834 lines
22 KiB
Go
834 lines
22 KiB
Go
// SPDX-Licence-Identifier: EUPL-1.2
|
|
|
|
package ws
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Unit tests — APIKeyAuthenticator
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestAPIKeyAuthenticator_ValidKey(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-abc": "user-1",
|
|
"key-def": "user-2",
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", "Bearer key-abc")
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.True(t, result.Valid)
|
|
assert.Equal(t, "user-1", result.UserID)
|
|
assert.Equal(t, "api_key", result.Claims["auth_method"])
|
|
assert.NoError(t, result.Error)
|
|
}
|
|
|
|
func TestAPIKeyAuthenticator_InvalidKey(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-abc": "user-1",
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", "Bearer wrong-key")
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.Empty(t, result.UserID)
|
|
assert.True(t, errors.Is(result.Error, ErrInvalidAPIKey))
|
|
}
|
|
|
|
func TestAPIKeyAuthenticator_MissingHeader(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-abc": "user-1",
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
// No Authorization header set
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.True(t, errors.Is(result.Error, ErrMissingAuthHeader))
|
|
}
|
|
|
|
func TestAPIKeyAuthenticator_MalformedHeader(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-abc": "user-1",
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
header string
|
|
}{
|
|
{"bearer without token", "Bearer "},
|
|
{"bearer only", "Bearer"},
|
|
{"wrong scheme", "Basic key-abc"},
|
|
{"no scheme", "key-abc"},
|
|
{"empty bearer with spaces", "Bearer "},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", tt.header)
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.True(t, errors.Is(result.Error, ErrMalformedAuthHeader))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAPIKeyAuthenticator_CaseInsensitiveScheme(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-abc": "user-1",
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", "bearer key-abc")
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.True(t, result.Valid)
|
|
assert.Equal(t, "user-1", result.UserID)
|
|
}
|
|
|
|
func TestAPIKeyAuthenticator_SecondKey(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-abc": "user-1",
|
|
"key-def": "user-2",
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", "Bearer key-def")
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.True(t, result.Valid)
|
|
assert.Equal(t, "user-2", result.UserID)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Unit tests — AuthenticatorFunc adapter
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestAuthenticatorFunc_Adapter(t *testing.T) {
|
|
called := false
|
|
fn := AuthenticatorFunc(func(r *http.Request) AuthResult {
|
|
called = true
|
|
return AuthResult{Valid: true, UserID: "func-user"}
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
result := fn.Authenticate(r)
|
|
|
|
assert.True(t, called)
|
|
assert.True(t, result.Valid)
|
|
assert.Equal(t, "func-user", result.UserID)
|
|
}
|
|
|
|
func TestAuthenticatorFunc_Rejection(t *testing.T) {
|
|
fn := AuthenticatorFunc(func(r *http.Request) AuthResult {
|
|
return AuthResult{Valid: false, Error: errors.New("custom rejection")}
|
|
})
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
result := fn.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.EqualError(t, result.Error, "custom rejection")
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Unit tests — nil Authenticator (backward compat)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestNilAuthenticator_AllConnectionsAccepted(t *testing.T) {
|
|
hub := NewHub() // No authenticator set
|
|
assert.Nil(t, hub.config.Authenticator)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Integration tests — httptest + gorilla/websocket Dial
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// helper: start a hub with the given config, return server + cleanup
|
|
func startAuthTestHub(t *testing.T, config HubConfig) (*httptest.Server, *Hub, context.CancelFunc) {
|
|
t.Helper()
|
|
hub := NewHubWithConfig(config)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go hub.Run(ctx)
|
|
|
|
server := httptest.NewServer(hub.Handler())
|
|
t.Cleanup(func() {
|
|
server.Close()
|
|
cancel()
|
|
})
|
|
return server, hub, cancel
|
|
}
|
|
|
|
func authWSURL(server *httptest.Server) string {
|
|
return "ws" + strings.TrimPrefix(server.URL, "http")
|
|
}
|
|
|
|
func TestIntegration_AuthenticatedConnect(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"valid-key": "user-42",
|
|
})
|
|
|
|
var connectedClient *Client
|
|
var mu sync.Mutex
|
|
|
|
server, _, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
OnConnect: func(client *Client) {
|
|
mu.Lock()
|
|
connectedClient = client
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer valid-key")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
|
|
// Give the hub a moment to process registration
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
mu.Lock()
|
|
client := connectedClient
|
|
mu.Unlock()
|
|
|
|
require.NotNil(t, client, "OnConnect should have fired")
|
|
assert.Equal(t, "user-42", client.UserID)
|
|
assert.Equal(t, "api_key", client.Claims["auth_method"])
|
|
}
|
|
|
|
func TestIntegration_RejectedConnect_InvalidKey(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"valid-key": "user-42",
|
|
})
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
})
|
|
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer wrong-key")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
}
|
|
|
|
func TestIntegration_RejectedConnect_NoAuthHeader(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"valid-key": "user-42",
|
|
})
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
})
|
|
|
|
// No Authorization header
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil)
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
}
|
|
|
|
func TestIntegration_NilAuthenticator_BackwardCompat(t *testing.T) {
|
|
// No authenticator — all connections should be accepted
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{})
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 1, hub.ClientCount())
|
|
}
|
|
|
|
func TestIntegration_OnAuthFailure_Callback(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"valid-key": "user-42",
|
|
})
|
|
|
|
var failureMu sync.Mutex
|
|
var failureResult AuthResult
|
|
var failureRequest *http.Request
|
|
failureCalled := false
|
|
|
|
server, _, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
OnAuthFailure: func(r *http.Request, result AuthResult) {
|
|
failureMu.Lock()
|
|
failureCalled = true
|
|
failureResult = result
|
|
failureRequest = r
|
|
failureMu.Unlock()
|
|
},
|
|
})
|
|
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer bad-key")
|
|
|
|
conn, _, _ := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
|
|
// Give callback time to execute
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
failureMu.Lock()
|
|
defer failureMu.Unlock()
|
|
|
|
assert.True(t, failureCalled, "OnAuthFailure should have been called")
|
|
assert.False(t, failureResult.Valid)
|
|
assert.True(t, errors.Is(failureResult.Error, ErrInvalidAPIKey))
|
|
assert.NotNil(t, failureRequest)
|
|
}
|
|
|
|
func TestIntegration_MultipleClients_DifferentKeys(t *testing.T) {
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-alpha": "user-alpha",
|
|
"key-beta": "user-beta",
|
|
"key-gamma": "user-gamma",
|
|
})
|
|
|
|
var mu sync.Mutex
|
|
connectedClients := make(map[string]*Client)
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
OnConnect: func(client *Client) {
|
|
mu.Lock()
|
|
connectedClients[client.UserID] = client
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
|
|
keys := []struct {
|
|
key string
|
|
userID string
|
|
}{
|
|
{"key-alpha", "user-alpha"},
|
|
{"key-beta", "user-beta"},
|
|
{"key-gamma", "user-gamma"},
|
|
}
|
|
|
|
conns := make([]*websocket.Conn, 0, len(keys))
|
|
for _, k := range keys {
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer "+k.key)
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
|
require.NoError(t, err, "key %s should connect", k.key)
|
|
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
conns = append(conns, conn)
|
|
}
|
|
defer func() {
|
|
for _, c := range conns {
|
|
c.Close()
|
|
}
|
|
}()
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
assert.Equal(t, 3, hub.ClientCount())
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
for _, k := range keys {
|
|
client, ok := connectedClients[k.userID]
|
|
require.True(t, ok, "should have client for %s", k.userID)
|
|
assert.Equal(t, k.userID, client.UserID)
|
|
}
|
|
}
|
|
|
|
func TestIntegration_AuthenticatorFunc_WithHub(t *testing.T) {
|
|
// Use AuthenticatorFunc as the hub's authenticator
|
|
fn := AuthenticatorFunc(func(r *http.Request) AuthResult {
|
|
token := r.URL.Query().Get("token")
|
|
if token == "magic" {
|
|
return AuthResult{
|
|
Valid: true,
|
|
UserID: "magic-user",
|
|
Claims: map[string]any{"source": "query_param"},
|
|
}
|
|
}
|
|
return AuthResult{Valid: false, Error: errors.New("bad token")}
|
|
})
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: fn,
|
|
})
|
|
|
|
// Valid token via query parameter
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=magic", nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 1, hub.ClientCount())
|
|
|
|
// Invalid token
|
|
conn2, resp2, _ := websocket.DefaultDialer.Dial(authWSURL(server)+"?token=wrong", nil)
|
|
if conn2 != nil {
|
|
conn2.Close()
|
|
}
|
|
assert.Equal(t, http.StatusUnauthorized, resp2.StatusCode)
|
|
}
|
|
|
|
func TestIntegration_AuthenticatedClient_ReceivesMessages(t *testing.T) {
|
|
// Verify that an authenticated client can still receive messages normally
|
|
auth := NewAPIKeyAuth(map[string]string{
|
|
"key-1": "user-1",
|
|
})
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
})
|
|
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer key-1")
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Broadcast a message
|
|
err = hub.Broadcast(Message{Type: TypeEvent, Data: "hello"})
|
|
require.NoError(t, err)
|
|
|
|
// Read it
|
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
_, data, err := conn.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
var msg Message
|
|
require.NoError(t, json.Unmarshal(data, &msg))
|
|
assert.Equal(t, TypeEvent, msg.Type)
|
|
assert.Equal(t, "hello", msg.Data)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Unit tests — BearerTokenAuth
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestBearerTokenAuth_ValidToken_Good(t *testing.T) {
|
|
auth := &BearerTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
if token == "jwt-abc-123" {
|
|
return AuthResult{
|
|
Valid: true,
|
|
UserID: "user-42",
|
|
Claims: map[string]any{"role": "admin", "auth_method": "jwt"},
|
|
}
|
|
}
|
|
return AuthResult{Valid: false, Error: errors.New("invalid token")}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", "Bearer jwt-abc-123")
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.True(t, result.Valid)
|
|
assert.Equal(t, "user-42", result.UserID)
|
|
assert.Equal(t, "admin", result.Claims["role"])
|
|
assert.Equal(t, "jwt", result.Claims["auth_method"])
|
|
}
|
|
|
|
func TestBearerTokenAuth_InvalidToken_Bad(t *testing.T) {
|
|
auth := &BearerTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: false, Error: errors.New("token expired")}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", "Bearer expired-token")
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.EqualError(t, result.Error, "token expired")
|
|
}
|
|
|
|
func TestBearerTokenAuth_MissingHeader_Bad(t *testing.T) {
|
|
auth := &BearerTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.True(t, errors.Is(result.Error, ErrMissingAuthHeader))
|
|
}
|
|
|
|
func TestBearerTokenAuth_MalformedHeader_Bad(t *testing.T) {
|
|
auth := &BearerTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
header string
|
|
}{
|
|
{"wrong scheme", "Basic dXNlcjpwYXNz"},
|
|
{"no scheme", "raw-token-value"},
|
|
{"empty bearer", "Bearer "},
|
|
{"empty bearer with spaces", "Bearer "},
|
|
{"bearer only", "Bearer"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", tt.header)
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.True(t, errors.Is(result.Error, ErrMalformedAuthHeader))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBearerTokenAuth_CaseInsensitiveScheme_Good(t *testing.T) {
|
|
auth := &BearerTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: true, UserID: "user-1"}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
r.Header.Set("Authorization", "bearer my-token")
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.True(t, result.Valid)
|
|
assert.Equal(t, "user-1", result.UserID)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Integration tests — BearerTokenAuth with Hub
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestIntegration_BearerTokenAuth_AcceptsValidToken_Good(t *testing.T) {
|
|
auth := &BearerTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
if token == "valid-jwt" {
|
|
return AuthResult{
|
|
Valid: true,
|
|
UserID: "jwt-user",
|
|
Claims: map[string]any{"auth_method": "bearer"},
|
|
}
|
|
}
|
|
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
|
},
|
|
}
|
|
|
|
var connectedClient *Client
|
|
var mu sync.Mutex
|
|
|
|
server, _, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
OnConnect: func(client *Client) {
|
|
mu.Lock()
|
|
connectedClient = client
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer valid-jwt")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
mu.Lock()
|
|
client := connectedClient
|
|
mu.Unlock()
|
|
|
|
require.NotNil(t, client)
|
|
assert.Equal(t, "jwt-user", client.UserID)
|
|
assert.Equal(t, "bearer", client.Claims["auth_method"])
|
|
}
|
|
|
|
func TestIntegration_BearerTokenAuth_RejectsInvalidToken_Bad(t *testing.T) {
|
|
auth := &BearerTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
|
},
|
|
}
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
})
|
|
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer bad-token")
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Unit tests — QueryTokenAuth
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestQueryTokenAuth_ValidToken_Good(t *testing.T) {
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
if token == "browser-token-456" {
|
|
return AuthResult{
|
|
Valid: true,
|
|
UserID: "browser-user",
|
|
Claims: map[string]any{"auth_method": "query_param"},
|
|
}
|
|
}
|
|
return AuthResult{Valid: false, Error: errors.New("unknown token")}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws?token=browser-token-456", nil)
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.True(t, result.Valid)
|
|
assert.Equal(t, "browser-user", result.UserID)
|
|
assert.Equal(t, "query_param", result.Claims["auth_method"])
|
|
}
|
|
|
|
func TestQueryTokenAuth_InvalidToken_Bad(t *testing.T) {
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: false, Error: errors.New("unknown token")}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws?token=bad-token", nil)
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.EqualError(t, result.Error, "unknown token")
|
|
}
|
|
|
|
func TestQueryTokenAuth_MissingParam_Bad(t *testing.T) {
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.Contains(t, result.Error.Error(), "missing token query parameter")
|
|
}
|
|
|
|
func TestQueryTokenAuth_EmptyParam_Bad(t *testing.T) {
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
|
},
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, "/ws?token=", nil)
|
|
|
|
result := auth.Authenticate(r)
|
|
|
|
assert.False(t, result.Valid)
|
|
assert.Contains(t, result.Error.Error(), "missing token query parameter")
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Integration tests — QueryTokenAuth with Hub
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestIntegration_QueryTokenAuth_AcceptsValidToken_Good(t *testing.T) {
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
if token == "browser-secret" {
|
|
return AuthResult{
|
|
Valid: true,
|
|
UserID: "browser-user-99",
|
|
Claims: map[string]any{"origin": "browser"},
|
|
}
|
|
}
|
|
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
|
},
|
|
}
|
|
|
|
var connectedClient *Client
|
|
var mu sync.Mutex
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
OnConnect: func(client *Client) {
|
|
mu.Lock()
|
|
connectedClient = client
|
|
mu.Unlock()
|
|
},
|
|
})
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(
|
|
authWSURL(server)+"?token=browser-secret", nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
assert.Equal(t, 1, hub.ClientCount())
|
|
|
|
mu.Lock()
|
|
client := connectedClient
|
|
mu.Unlock()
|
|
|
|
require.NotNil(t, client)
|
|
assert.Equal(t, "browser-user-99", client.UserID)
|
|
assert.Equal(t, "browser", client.Claims["origin"])
|
|
}
|
|
|
|
func TestIntegration_QueryTokenAuth_RejectsInvalidToken_Bad(t *testing.T) {
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
|
},
|
|
}
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
})
|
|
|
|
conn, resp, err := websocket.DefaultDialer.Dial(
|
|
authWSURL(server)+"?token=wrong", nil)
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
}
|
|
|
|
func TestIntegration_QueryTokenAuth_RejectsMissingToken_Bad(t *testing.T) {
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
|
},
|
|
}
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
})
|
|
|
|
// No ?token= parameter
|
|
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil)
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
|
|
require.Error(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, 0, hub.ClientCount())
|
|
}
|
|
|
|
func TestIntegration_QueryTokenAuth_EndToEnd_Good(t *testing.T) {
|
|
// Authenticated via query param, then subscribe and receive messages
|
|
auth := &QueryTokenAuth{
|
|
Validate: func(token string) AuthResult {
|
|
if token == "good-token" {
|
|
return AuthResult{Valid: true, UserID: "alice"}
|
|
}
|
|
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
|
},
|
|
}
|
|
|
|
server, hub, _ := startAuthTestHub(t, HubConfig{
|
|
Authenticator: auth,
|
|
})
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(
|
|
authWSURL(server)+"?token=good-token", nil)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Subscribe to a channel
|
|
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "events"})
|
|
require.NoError(t, err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
assert.Equal(t, 1, hub.ChannelSubscriberCount("events"))
|
|
|
|
// Send a message to the channel
|
|
err = hub.SendToChannel("events", Message{Type: TypeEvent, Data: "hello alice"})
|
|
require.NoError(t, err)
|
|
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var received Message
|
|
err = conn.ReadJSON(&received)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, TypeEvent, received.Type)
|
|
assert.Equal(t, "hello alice", received.Data)
|
|
}
|