Merge remote-tracking branch 'origin/main'
# Conflicts: # CLAUDE.md # FINDINGS.md
This commit is contained in:
commit
23aa979271
2 changed files with 450 additions and 0 deletions
65
auth.go
65
auth.go
|
|
@ -3,6 +3,7 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
|
@ -97,3 +98,67 @@ func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
// BearerTokenAuth extracts an Authorization: Bearer <token> header and
|
||||
// validates it using a caller-supplied function. Unlike APIKeyAuthenticator,
|
||||
// this authenticator delegates validation entirely to the caller, making
|
||||
// it suitable for JWT verification, token introspection, or any custom
|
||||
// bearer scheme.
|
||||
type BearerTokenAuth struct {
|
||||
// Validate receives the raw bearer token string and should return
|
||||
// an AuthResult. The caller controls UserID, Claims, and error
|
||||
// semantics.
|
||||
Validate func(token string) AuthResult
|
||||
}
|
||||
|
||||
// Authenticate implements the Authenticator interface for bearer tokens.
|
||||
func (b *BearerTokenAuth) Authenticate(r *http.Request) AuthResult {
|
||||
header := r.Header.Get("Authorization")
|
||||
if header == "" {
|
||||
return AuthResult{
|
||||
Valid: false,
|
||||
Error: ErrMissingAuthHeader,
|
||||
}
|
||||
}
|
||||
|
||||
parts := strings.SplitN(header, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
return AuthResult{
|
||||
Valid: false,
|
||||
Error: ErrMalformedAuthHeader,
|
||||
}
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
return AuthResult{
|
||||
Valid: false,
|
||||
Error: ErrMalformedAuthHeader,
|
||||
}
|
||||
}
|
||||
|
||||
return b.Validate(token)
|
||||
}
|
||||
|
||||
// QueryTokenAuth extracts a token from the ?token= query parameter and
|
||||
// validates it using a caller-supplied function. This is useful for
|
||||
// browser clients that cannot set custom headers on WebSocket connections
|
||||
// (e.g. the browser's native WebSocket API does not support custom headers).
|
||||
type QueryTokenAuth struct {
|
||||
// Validate receives the raw token value from the query string and
|
||||
// should return an AuthResult.
|
||||
Validate func(token string) AuthResult
|
||||
}
|
||||
|
||||
// Authenticate implements the Authenticator interface for query parameter tokens.
|
||||
func (q *QueryTokenAuth) Authenticate(r *http.Request) AuthResult {
|
||||
token := r.URL.Query().Get("token")
|
||||
if token == "" {
|
||||
return AuthResult{
|
||||
Valid: false,
|
||||
Error: fmt.Errorf("missing token query parameter"),
|
||||
}
|
||||
}
|
||||
|
||||
return q.Validate(token)
|
||||
}
|
||||
|
|
|
|||
385
auth_test.go
385
auth_test.go
|
|
@ -447,3 +447,388 @@ func TestIntegration_AuthenticatedClient_ReceivesMessages(t *testing.T) {
|
|||
assert.Equal(t, TypeEvent, msg.Type)
|
||||
assert.Equal(t, "hello", msg.Data)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Unit tests — BearerTokenAuth
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBearerTokenAuth_ValidToken_Good(t *testing.T) {
|
||||
auth := &BearerTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
if token == "jwt-abc-123" {
|
||||
return AuthResult{
|
||||
Valid: true,
|
||||
UserID: "user-42",
|
||||
Claims: map[string]any{"role": "admin", "auth_method": "jwt"},
|
||||
}
|
||||
}
|
||||
return AuthResult{Valid: false, Error: errors.New("invalid token")}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
||||
r.Header.Set("Authorization", "Bearer jwt-abc-123")
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.True(t, result.Valid)
|
||||
assert.Equal(t, "user-42", result.UserID)
|
||||
assert.Equal(t, "admin", result.Claims["role"])
|
||||
assert.Equal(t, "jwt", result.Claims["auth_method"])
|
||||
}
|
||||
|
||||
func TestBearerTokenAuth_InvalidToken_Bad(t *testing.T) {
|
||||
auth := &BearerTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: false, Error: errors.New("token expired")}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
||||
r.Header.Set("Authorization", "Bearer expired-token")
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.False(t, result.Valid)
|
||||
assert.EqualError(t, result.Error, "token expired")
|
||||
}
|
||||
|
||||
func TestBearerTokenAuth_MissingHeader_Bad(t *testing.T) {
|
||||
auth := &BearerTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.False(t, result.Valid)
|
||||
assert.True(t, errors.Is(result.Error, ErrMissingAuthHeader))
|
||||
}
|
||||
|
||||
func TestBearerTokenAuth_MalformedHeader_Bad(t *testing.T) {
|
||||
auth := &BearerTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{"wrong scheme", "Basic dXNlcjpwYXNz"},
|
||||
{"no scheme", "raw-token-value"},
|
||||
{"empty bearer", "Bearer "},
|
||||
{"empty bearer with spaces", "Bearer "},
|
||||
{"bearer only", "Bearer"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
||||
r.Header.Set("Authorization", tt.header)
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.False(t, result.Valid)
|
||||
assert.True(t, errors.Is(result.Error, ErrMalformedAuthHeader))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBearerTokenAuth_CaseInsensitiveScheme_Good(t *testing.T) {
|
||||
auth := &BearerTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: true, UserID: "user-1"}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
||||
r.Header.Set("Authorization", "bearer my-token")
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.True(t, result.Valid)
|
||||
assert.Equal(t, "user-1", result.UserID)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Integration tests — BearerTokenAuth with Hub
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIntegration_BearerTokenAuth_AcceptsValidToken_Good(t *testing.T) {
|
||||
auth := &BearerTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
if token == "valid-jwt" {
|
||||
return AuthResult{
|
||||
Valid: true,
|
||||
UserID: "jwt-user",
|
||||
Claims: map[string]any{"auth_method": "bearer"},
|
||||
}
|
||||
}
|
||||
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
||||
},
|
||||
}
|
||||
|
||||
var connectedClient *Client
|
||||
var mu sync.Mutex
|
||||
|
||||
server, _, _ := startAuthTestHub(t, HubConfig{
|
||||
Authenticator: auth,
|
||||
OnConnect: func(client *Client) {
|
||||
mu.Lock()
|
||||
connectedClient = client
|
||||
mu.Unlock()
|
||||
},
|
||||
})
|
||||
|
||||
header := http.Header{}
|
||||
header.Set("Authorization", "Bearer valid-jwt")
|
||||
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
client := connectedClient
|
||||
mu.Unlock()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, "jwt-user", client.UserID)
|
||||
assert.Equal(t, "bearer", client.Claims["auth_method"])
|
||||
}
|
||||
|
||||
func TestIntegration_BearerTokenAuth_RejectsInvalidToken_Bad(t *testing.T) {
|
||||
auth := &BearerTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
||||
},
|
||||
}
|
||||
|
||||
server, hub, _ := startAuthTestHub(t, HubConfig{
|
||||
Authenticator: auth,
|
||||
})
|
||||
|
||||
header := http.Header{}
|
||||
header.Set("Authorization", "Bearer bad-token")
|
||||
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header)
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, 0, hub.ClientCount())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Unit tests — QueryTokenAuth
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestQueryTokenAuth_ValidToken_Good(t *testing.T) {
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
if token == "browser-token-456" {
|
||||
return AuthResult{
|
||||
Valid: true,
|
||||
UserID: "browser-user",
|
||||
Claims: map[string]any{"auth_method": "query_param"},
|
||||
}
|
||||
}
|
||||
return AuthResult{Valid: false, Error: errors.New("unknown token")}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws?token=browser-token-456", nil)
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.True(t, result.Valid)
|
||||
assert.Equal(t, "browser-user", result.UserID)
|
||||
assert.Equal(t, "query_param", result.Claims["auth_method"])
|
||||
}
|
||||
|
||||
func TestQueryTokenAuth_InvalidToken_Bad(t *testing.T) {
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: false, Error: errors.New("unknown token")}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws?token=bad-token", nil)
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.False(t, result.Valid)
|
||||
assert.EqualError(t, result.Error, "unknown token")
|
||||
}
|
||||
|
||||
func TestQueryTokenAuth_MissingParam_Bad(t *testing.T) {
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws", nil)
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.False(t, result.Valid)
|
||||
assert.Contains(t, result.Error.Error(), "missing token query parameter")
|
||||
}
|
||||
|
||||
func TestQueryTokenAuth_EmptyParam_Bad(t *testing.T) {
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
||||
},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/ws?token=", nil)
|
||||
|
||||
result := auth.Authenticate(r)
|
||||
|
||||
assert.False(t, result.Valid)
|
||||
assert.Contains(t, result.Error.Error(), "missing token query parameter")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Integration tests — QueryTokenAuth with Hub
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIntegration_QueryTokenAuth_AcceptsValidToken_Good(t *testing.T) {
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
if token == "browser-secret" {
|
||||
return AuthResult{
|
||||
Valid: true,
|
||||
UserID: "browser-user-99",
|
||||
Claims: map[string]any{"origin": "browser"},
|
||||
}
|
||||
}
|
||||
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
||||
},
|
||||
}
|
||||
|
||||
var connectedClient *Client
|
||||
var mu sync.Mutex
|
||||
|
||||
server, hub, _ := startAuthTestHub(t, HubConfig{
|
||||
Authenticator: auth,
|
||||
OnConnect: func(client *Client) {
|
||||
mu.Lock()
|
||||
connectedClient = client
|
||||
mu.Unlock()
|
||||
},
|
||||
})
|
||||
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(
|
||||
authWSURL(server)+"?token=browser-secret", nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 1, hub.ClientCount())
|
||||
|
||||
mu.Lock()
|
||||
client := connectedClient
|
||||
mu.Unlock()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, "browser-user-99", client.UserID)
|
||||
assert.Equal(t, "browser", client.Claims["origin"])
|
||||
}
|
||||
|
||||
func TestIntegration_QueryTokenAuth_RejectsInvalidToken_Bad(t *testing.T) {
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
||||
},
|
||||
}
|
||||
|
||||
server, hub, _ := startAuthTestHub(t, HubConfig{
|
||||
Authenticator: auth,
|
||||
})
|
||||
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(
|
||||
authWSURL(server)+"?token=wrong", nil)
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, 0, hub.ClientCount())
|
||||
}
|
||||
|
||||
func TestIntegration_QueryTokenAuth_RejectsMissingToken_Bad(t *testing.T) {
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
return AuthResult{Valid: true, UserID: "should-not-reach"}
|
||||
},
|
||||
}
|
||||
|
||||
server, hub, _ := startAuthTestHub(t, HubConfig{
|
||||
Authenticator: auth,
|
||||
})
|
||||
|
||||
// No ?token= parameter
|
||||
conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil)
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, 0, hub.ClientCount())
|
||||
}
|
||||
|
||||
func TestIntegration_QueryTokenAuth_EndToEnd_Good(t *testing.T) {
|
||||
// Authenticated via query param, then subscribe and receive messages
|
||||
auth := &QueryTokenAuth{
|
||||
Validate: func(token string) AuthResult {
|
||||
if token == "good-token" {
|
||||
return AuthResult{Valid: true, UserID: "alice"}
|
||||
}
|
||||
return AuthResult{Valid: false, Error: errors.New("invalid")}
|
||||
},
|
||||
}
|
||||
|
||||
server, hub, _ := startAuthTestHub(t, HubConfig{
|
||||
Authenticator: auth,
|
||||
})
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(
|
||||
authWSURL(server)+"?token=good-token", nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Subscribe to a channel
|
||||
err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "events"})
|
||||
require.NoError(t, err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, 1, hub.ChannelSubscriberCount("events"))
|
||||
|
||||
// Send a message to the channel
|
||||
err = hub.SendToChannel("events", Message{Type: TypeEvent, Data: "hello alice"})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
var received Message
|
||||
err = conn.ReadJSON(&received)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, TypeEvent, received.Type)
|
||||
assert.Equal(t, "hello alice", received.Data)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue