diff --git a/auth.go b/auth.go index 0513a81..c35b4f6 100644 --- a/auth.go +++ b/auth.go @@ -3,6 +3,7 @@ package ws import ( + "fmt" "net/http" "strings" ) @@ -97,3 +98,67 @@ func (a *APIKeyAuthenticator) Authenticate(r *http.Request) AuthResult { }, } } + +// BearerTokenAuth extracts an Authorization: Bearer header and +// validates it using a caller-supplied function. Unlike APIKeyAuthenticator, +// this authenticator delegates validation entirely to the caller, making +// it suitable for JWT verification, token introspection, or any custom +// bearer scheme. +type BearerTokenAuth struct { + // Validate receives the raw bearer token string and should return + // an AuthResult. The caller controls UserID, Claims, and error + // semantics. + Validate func(token string) AuthResult +} + +// Authenticate implements the Authenticator interface for bearer tokens. +func (b *BearerTokenAuth) Authenticate(r *http.Request) AuthResult { + header := r.Header.Get("Authorization") + if header == "" { + return AuthResult{ + Valid: false, + Error: ErrMissingAuthHeader, + } + } + + parts := strings.SplitN(header, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return AuthResult{ + Valid: false, + Error: ErrMalformedAuthHeader, + } + } + + token := strings.TrimSpace(parts[1]) + if token == "" { + return AuthResult{ + Valid: false, + Error: ErrMalformedAuthHeader, + } + } + + return b.Validate(token) +} + +// QueryTokenAuth extracts a token from the ?token= query parameter and +// validates it using a caller-supplied function. This is useful for +// browser clients that cannot set custom headers on WebSocket connections +// (e.g. the browser's native WebSocket API does not support custom headers). +type QueryTokenAuth struct { + // Validate receives the raw token value from the query string and + // should return an AuthResult. + Validate func(token string) AuthResult +} + +// Authenticate implements the Authenticator interface for query parameter tokens. +func (q *QueryTokenAuth) Authenticate(r *http.Request) AuthResult { + token := r.URL.Query().Get("token") + if token == "" { + return AuthResult{ + Valid: false, + Error: fmt.Errorf("missing token query parameter"), + } + } + + return q.Validate(token) +} diff --git a/auth_test.go b/auth_test.go index d2a460d..9c4a7a4 100644 --- a/auth_test.go +++ b/auth_test.go @@ -447,3 +447,388 @@ func TestIntegration_AuthenticatedClient_ReceivesMessages(t *testing.T) { assert.Equal(t, TypeEvent, msg.Type) assert.Equal(t, "hello", msg.Data) } + +// --------------------------------------------------------------------------- +// Unit tests — BearerTokenAuth +// --------------------------------------------------------------------------- + +func TestBearerTokenAuth_ValidToken_Good(t *testing.T) { + auth := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + if token == "jwt-abc-123" { + return AuthResult{ + Valid: true, + UserID: "user-42", + Claims: map[string]any{"role": "admin", "auth_method": "jwt"}, + } + } + return AuthResult{Valid: false, Error: errors.New("invalid token")} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer jwt-abc-123") + + result := auth.Authenticate(r) + + assert.True(t, result.Valid) + assert.Equal(t, "user-42", result.UserID) + assert.Equal(t, "admin", result.Claims["role"]) + assert.Equal(t, "jwt", result.Claims["auth_method"]) +} + +func TestBearerTokenAuth_InvalidToken_Bad(t *testing.T) { + auth := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: false, Error: errors.New("token expired")} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "Bearer expired-token") + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.EqualError(t, result.Error, "token expired") +} + +func TestBearerTokenAuth_MissingHeader_Bad(t *testing.T) { + auth := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: true, UserID: "should-not-reach"} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.True(t, errors.Is(result.Error, ErrMissingAuthHeader)) +} + +func TestBearerTokenAuth_MalformedHeader_Bad(t *testing.T) { + auth := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: true, UserID: "should-not-reach"} + }, + } + + tests := []struct { + name string + header string + }{ + {"wrong scheme", "Basic dXNlcjpwYXNz"}, + {"no scheme", "raw-token-value"}, + {"empty bearer", "Bearer "}, + {"empty bearer with spaces", "Bearer "}, + {"bearer only", "Bearer"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", tt.header) + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.True(t, errors.Is(result.Error, ErrMalformedAuthHeader)) + }) + } +} + +func TestBearerTokenAuth_CaseInsensitiveScheme_Good(t *testing.T) { + auth := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: true, UserID: "user-1"} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + r.Header.Set("Authorization", "bearer my-token") + + result := auth.Authenticate(r) + + assert.True(t, result.Valid) + assert.Equal(t, "user-1", result.UserID) +} + +// --------------------------------------------------------------------------- +// Integration tests — BearerTokenAuth with Hub +// --------------------------------------------------------------------------- + +func TestIntegration_BearerTokenAuth_AcceptsValidToken_Good(t *testing.T) { + auth := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + if token == "valid-jwt" { + return AuthResult{ + Valid: true, + UserID: "jwt-user", + Claims: map[string]any{"auth_method": "bearer"}, + } + } + return AuthResult{Valid: false, Error: errors.New("invalid")} + }, + } + + var connectedClient *Client + var mu sync.Mutex + + server, _, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + OnConnect: func(client *Client) { + mu.Lock() + connectedClient = client + mu.Unlock() + }, + }) + + header := http.Header{} + header.Set("Authorization", "Bearer valid-jwt") + + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) + require.NoError(t, err) + defer conn.Close() + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + time.Sleep(50 * time.Millisecond) + + mu.Lock() + client := connectedClient + mu.Unlock() + + require.NotNil(t, client) + assert.Equal(t, "jwt-user", client.UserID) + assert.Equal(t, "bearer", client.Claims["auth_method"]) +} + +func TestIntegration_BearerTokenAuth_RejectsInvalidToken_Bad(t *testing.T) { + auth := &BearerTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: false, Error: errors.New("invalid")} + }, + } + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + }) + + header := http.Header{} + header.Set("Authorization", "Bearer bad-token") + + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), header) + if conn != nil { + conn.Close() + } + + require.Error(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, 0, hub.ClientCount()) +} + +// --------------------------------------------------------------------------- +// Unit tests — QueryTokenAuth +// --------------------------------------------------------------------------- + +func TestQueryTokenAuth_ValidToken_Good(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + if token == "browser-token-456" { + return AuthResult{ + Valid: true, + UserID: "browser-user", + Claims: map[string]any{"auth_method": "query_param"}, + } + } + return AuthResult{Valid: false, Error: errors.New("unknown token")} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws?token=browser-token-456", nil) + + result := auth.Authenticate(r) + + assert.True(t, result.Valid) + assert.Equal(t, "browser-user", result.UserID) + assert.Equal(t, "query_param", result.Claims["auth_method"]) +} + +func TestQueryTokenAuth_InvalidToken_Bad(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: false, Error: errors.New("unknown token")} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws?token=bad-token", nil) + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.EqualError(t, result.Error, "unknown token") +} + +func TestQueryTokenAuth_MissingParam_Bad(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: true, UserID: "should-not-reach"} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws", nil) + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.Contains(t, result.Error.Error(), "missing token query parameter") +} + +func TestQueryTokenAuth_EmptyParam_Bad(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: true, UserID: "should-not-reach"} + }, + } + + r := httptest.NewRequest(http.MethodGet, "/ws?token=", nil) + + result := auth.Authenticate(r) + + assert.False(t, result.Valid) + assert.Contains(t, result.Error.Error(), "missing token query parameter") +} + +// --------------------------------------------------------------------------- +// Integration tests — QueryTokenAuth with Hub +// --------------------------------------------------------------------------- + +func TestIntegration_QueryTokenAuth_AcceptsValidToken_Good(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + if token == "browser-secret" { + return AuthResult{ + Valid: true, + UserID: "browser-user-99", + Claims: map[string]any{"origin": "browser"}, + } + } + return AuthResult{Valid: false, Error: errors.New("invalid")} + }, + } + + var connectedClient *Client + var mu sync.Mutex + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + OnConnect: func(client *Client) { + mu.Lock() + connectedClient = client + mu.Unlock() + }, + }) + + conn, resp, err := websocket.DefaultDialer.Dial( + authWSURL(server)+"?token=browser-secret", nil) + require.NoError(t, err) + defer conn.Close() + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, hub.ClientCount()) + + mu.Lock() + client := connectedClient + mu.Unlock() + + require.NotNil(t, client) + assert.Equal(t, "browser-user-99", client.UserID) + assert.Equal(t, "browser", client.Claims["origin"]) +} + +func TestIntegration_QueryTokenAuth_RejectsInvalidToken_Bad(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: false, Error: errors.New("invalid")} + }, + } + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + }) + + conn, resp, err := websocket.DefaultDialer.Dial( + authWSURL(server)+"?token=wrong", nil) + if conn != nil { + conn.Close() + } + + require.Error(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, 0, hub.ClientCount()) +} + +func TestIntegration_QueryTokenAuth_RejectsMissingToken_Bad(t *testing.T) { + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + return AuthResult{Valid: true, UserID: "should-not-reach"} + }, + } + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + }) + + // No ?token= parameter + conn, resp, err := websocket.DefaultDialer.Dial(authWSURL(server), nil) + if conn != nil { + conn.Close() + } + + require.Error(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, 0, hub.ClientCount()) +} + +func TestIntegration_QueryTokenAuth_EndToEnd_Good(t *testing.T) { + // Authenticated via query param, then subscribe and receive messages + auth := &QueryTokenAuth{ + Validate: func(token string) AuthResult { + if token == "good-token" { + return AuthResult{Valid: true, UserID: "alice"} + } + return AuthResult{Valid: false, Error: errors.New("invalid")} + }, + } + + server, hub, _ := startAuthTestHub(t, HubConfig{ + Authenticator: auth, + }) + + conn, _, err := websocket.DefaultDialer.Dial( + authWSURL(server)+"?token=good-token", nil) + require.NoError(t, err) + defer conn.Close() + + time.Sleep(50 * time.Millisecond) + + // Subscribe to a channel + err = conn.WriteJSON(Message{Type: TypeSubscribe, Data: "events"}) + require.NoError(t, err) + time.Sleep(50 * time.Millisecond) + + assert.Equal(t, 1, hub.ChannelSubscriberCount("events")) + + // Send a message to the channel + err = hub.SendToChannel("events", Message{Type: TypeEvent, Data: "hello alice"}) + require.NoError(t, err) + + conn.SetReadDeadline(time.Now().Add(time.Second)) + var received Message + err = conn.ReadJSON(&received) + require.NoError(t, err) + assert.Equal(t, TypeEvent, received.Type) + assert.Equal(t, "hello alice", received.Data) +}