diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 5b684a0eb..73d979b5d 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -216,6 +216,8 @@ use codex_core::find_thread_name_by_id; use codex_core::find_thread_names_by_ids; use codex_core::find_thread_path_by_id_str; use codex_core::git_info::git_diff_to_remote; +use codex_core::mcp::auth::discover_supported_scopes; +use codex_core::mcp::auth::resolve_oauth_scopes; use codex_core::mcp::collect_mcp_snapshot; use codex_core::mcp::group_tools_by_server; use codex_core::models_manager::collaboration_mode_presets::CollaborationModesConfig; @@ -4554,7 +4556,13 @@ impl CodexMessageProcessor { } }; - let scopes = scopes.or_else(|| server.scopes.clone()); + let discovered_scopes = if scopes.is_none() && server.scopes.is_none() { + discover_supported_scopes(&server.transport).await + } else { + None + }; + let resolved_scopes = + resolve_oauth_scopes(scopes, server.scopes.clone(), discovered_scopes); match perform_oauth_login_return_url( &name, @@ -4562,7 +4570,7 @@ impl CodexMessageProcessor { config.mcp_oauth_credentials_store_mode, http_headers, env_http_headers, - scopes.as_deref().unwrap_or_default(), + &resolved_scopes.scopes, server.oauth_resource.as_deref(), timeout_secs, config.mcp_oauth_callback_port, diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 00a04693f..d4e1888b8 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -14,8 +14,12 @@ use codex_core::config::types::McpServerConfig; use codex_core::config::types::McpServerTransportConfig; use codex_core::mcp::McpManager; use codex_core::mcp::auth::McpOAuthLoginSupport; +use codex_core::mcp::auth::ResolvedMcpOAuthScopes; use codex_core::mcp::auth::compute_auth_statuses; +use codex_core::mcp::auth::discover_supported_scopes; use codex_core::mcp::auth::oauth_login_support; +use codex_core::mcp::auth::resolve_oauth_scopes; +use codex_core::mcp::auth::should_retry_without_scopes; use codex_core::plugins::PluginsManager; use codex_protocol::protocol::McpAuthStatus; use codex_rmcp_client::delete_oauth_tokens; @@ -183,6 +187,54 @@ impl McpCli { } } +/// Preserve compatibility with servers that still expect the legacy empty-scope +/// OAuth request. If a discovered-scope request is rejected by the provider, +/// retry the login flow once without scopes. +#[allow(clippy::too_many_arguments)] +async fn perform_oauth_login_retry_without_scopes( + name: &str, + url: &str, + store_mode: codex_rmcp_client::OAuthCredentialsStoreMode, + http_headers: Option>, + env_http_headers: Option>, + resolved_scopes: &ResolvedMcpOAuthScopes, + oauth_resource: Option<&str>, + callback_port: Option, + callback_url: Option<&str>, +) -> Result<()> { + match perform_oauth_login( + name, + url, + store_mode, + http_headers.clone(), + env_http_headers.clone(), + &resolved_scopes.scopes, + oauth_resource, + callback_port, + callback_url, + ) + .await + { + Ok(()) => Ok(()), + Err(err) if should_retry_without_scopes(resolved_scopes, &err) => { + println!("OAuth provider rejected discovered scopes. Retrying without scopes…"); + perform_oauth_login( + name, + url, + store_mode, + http_headers, + env_http_headers, + &[], + oauth_resource, + callback_port, + callback_url, + ) + .await + } + Err(err) => Err(err), + } +} + async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<()> { // Validate any provided overrides even though they are not currently applied. let overrides = config_overrides @@ -269,13 +321,15 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re match oauth_login_support(&transport).await { McpOAuthLoginSupport::Supported(oauth_config) => { println!("Detected OAuth support. Starting OAuth flow…"); - perform_oauth_login( + let resolved_scopes = + resolve_oauth_scopes(None, None, oauth_config.discovered_scopes.clone()); + perform_oauth_login_retry_without_scopes( &name, &oauth_config.url, config.mcp_oauth_credentials_store_mode, oauth_config.http_headers, oauth_config.env_http_headers, - &Vec::new(), + &resolved_scopes, None, config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), @@ -351,18 +405,22 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) _ => bail!("OAuth login is only supported for streamable HTTP servers."), }; - let mut scopes = scopes; - if scopes.is_empty() { - scopes = server.scopes.clone().unwrap_or_default(); - } + let explicit_scopes = (!scopes.is_empty()).then_some(scopes); + let discovered_scopes = if explicit_scopes.is_none() && server.scopes.is_none() { + discover_supported_scopes(&server.transport).await + } else { + None + }; + let resolved_scopes = + resolve_oauth_scopes(explicit_scopes, server.scopes.clone(), discovered_scopes); - perform_oauth_login( + perform_oauth_login_retry_without_scopes( &name, &url, config.mcp_oauth_credentials_store_mode, http_headers, env_http_headers, - &scopes, + &resolved_scopes, server.oauth_resource.as_deref(), config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), diff --git a/codex-rs/core/src/mcp/auth.rs b/codex-rs/core/src/mcp/auth.rs index f095c930d..06ddbdd51 100644 --- a/codex-rs/core/src/mcp/auth.rs +++ b/codex-rs/core/src/mcp/auth.rs @@ -3,8 +3,9 @@ use std::collections::HashMap; use anyhow::Result; use codex_protocol::protocol::McpAuthStatus; use codex_rmcp_client::OAuthCredentialsStoreMode; +use codex_rmcp_client::OAuthProviderError; use codex_rmcp_client::determine_streamable_http_auth_status; -use codex_rmcp_client::supports_oauth_login; +use codex_rmcp_client::discover_streamable_http_oauth; use futures::future::join_all; use tracing::warn; @@ -16,6 +17,7 @@ pub struct McpOAuthLoginConfig { pub url: String, pub http_headers: Option>, pub env_http_headers: Option>, + pub discovered_scopes: Option>, } #[derive(Debug)] @@ -25,6 +27,20 @@ pub enum McpOAuthLoginSupport { Unknown(anyhow::Error), } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum McpOAuthScopesSource { + Explicit, + Configured, + Discovered, + Empty, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResolvedMcpOAuthScopes { + pub scopes: Vec, + pub source: McpOAuthScopesSource, +} + pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAuthLoginSupport { let McpServerTransportConfig::StreamableHttp { url, @@ -40,17 +56,67 @@ pub async fn oauth_login_support(transport: &McpServerTransportConfig) -> McpOAu return McpOAuthLoginSupport::Unsupported; } - match supports_oauth_login(url).await { - Ok(true) => McpOAuthLoginSupport::Supported(McpOAuthLoginConfig { + match discover_streamable_http_oauth(url, http_headers.clone(), env_http_headers.clone()).await + { + Ok(Some(discovery)) => McpOAuthLoginSupport::Supported(McpOAuthLoginConfig { url: url.clone(), http_headers: http_headers.clone(), env_http_headers: env_http_headers.clone(), + discovered_scopes: discovery.scopes_supported, }), - Ok(false) => McpOAuthLoginSupport::Unsupported, + Ok(None) => McpOAuthLoginSupport::Unsupported, Err(err) => McpOAuthLoginSupport::Unknown(err), } } +pub async fn discover_supported_scopes( + transport: &McpServerTransportConfig, +) -> Option> { + match oauth_login_support(transport).await { + McpOAuthLoginSupport::Supported(config) => config.discovered_scopes, + McpOAuthLoginSupport::Unsupported | McpOAuthLoginSupport::Unknown(_) => None, + } +} + +pub fn resolve_oauth_scopes( + explicit_scopes: Option>, + configured_scopes: Option>, + discovered_scopes: Option>, +) -> ResolvedMcpOAuthScopes { + if let Some(scopes) = explicit_scopes { + return ResolvedMcpOAuthScopes { + scopes, + source: McpOAuthScopesSource::Explicit, + }; + } + + if let Some(scopes) = configured_scopes { + return ResolvedMcpOAuthScopes { + scopes, + source: McpOAuthScopesSource::Configured, + }; + } + + if let Some(scopes) = discovered_scopes + && !scopes.is_empty() + { + return ResolvedMcpOAuthScopes { + scopes, + source: McpOAuthScopesSource::Discovered, + }; + } + + ResolvedMcpOAuthScopes { + scopes: Vec::new(), + source: McpOAuthScopesSource::Empty, + } +} + +pub fn should_retry_without_scopes(scopes: &ResolvedMcpOAuthScopes, error: &anyhow::Error) -> bool { + scopes.source == McpOAuthScopesSource::Discovered + && error.downcast_ref::().is_some() +} + #[derive(Debug, Clone)] pub struct McpAuthStatusEntry { pub config: McpServerConfig, @@ -111,3 +177,112 @@ async fn compute_auth_status( } } } + +#[cfg(test)] +mod tests { + use anyhow::anyhow; + use pretty_assertions::assert_eq; + + use super::McpOAuthScopesSource; + use super::OAuthProviderError; + use super::ResolvedMcpOAuthScopes; + use super::resolve_oauth_scopes; + use super::should_retry_without_scopes; + + #[test] + fn resolve_oauth_scopes_prefers_explicit() { + let resolved = resolve_oauth_scopes( + Some(vec!["explicit".to_string()]), + Some(vec!["configured".to_string()]), + Some(vec!["discovered".to_string()]), + ); + + assert_eq!( + resolved, + ResolvedMcpOAuthScopes { + scopes: vec!["explicit".to_string()], + source: McpOAuthScopesSource::Explicit, + } + ); + } + + #[test] + fn resolve_oauth_scopes_prefers_configured_over_discovered() { + let resolved = resolve_oauth_scopes( + None, + Some(vec!["configured".to_string()]), + Some(vec!["discovered".to_string()]), + ); + + assert_eq!( + resolved, + ResolvedMcpOAuthScopes { + scopes: vec!["configured".to_string()], + source: McpOAuthScopesSource::Configured, + } + ); + } + + #[test] + fn resolve_oauth_scopes_uses_discovered_when_needed() { + let resolved = resolve_oauth_scopes(None, None, Some(vec!["discovered".to_string()])); + + assert_eq!( + resolved, + ResolvedMcpOAuthScopes { + scopes: vec!["discovered".to_string()], + source: McpOAuthScopesSource::Discovered, + } + ); + } + + #[test] + fn resolve_oauth_scopes_preserves_explicitly_empty_configured_scopes() { + let resolved = resolve_oauth_scopes(None, Some(Vec::new()), Some(vec!["ignored".into()])); + + assert_eq!( + resolved, + ResolvedMcpOAuthScopes { + scopes: Vec::new(), + source: McpOAuthScopesSource::Configured, + } + ); + } + + #[test] + fn resolve_oauth_scopes_falls_back_to_empty() { + let resolved = resolve_oauth_scopes(None, None, None); + + assert_eq!( + resolved, + ResolvedMcpOAuthScopes { + scopes: Vec::new(), + source: McpOAuthScopesSource::Empty, + } + ); + } + + #[test] + fn should_retry_without_scopes_only_for_discovered_provider_errors() { + let discovered = ResolvedMcpOAuthScopes { + scopes: vec!["scope".to_string()], + source: McpOAuthScopesSource::Discovered, + }; + let provider_error = anyhow!(OAuthProviderError::new( + Some("invalid_scope".to_string()), + Some("scope rejected".to_string()), + )); + + assert!(should_retry_without_scopes(&discovered, &provider_error)); + + let configured = ResolvedMcpOAuthScopes { + scopes: vec!["scope".to_string()], + source: McpOAuthScopesSource::Configured, + }; + assert!(!should_retry_without_scopes(&configured, &provider_error)); + assert!(!should_retry_without_scopes( + &discovered, + &anyhow!("timed out waiting for OAuth callback"), + )); + } +} diff --git a/codex-rs/core/src/mcp/skill_dependencies.rs b/codex-rs/core/src/mcp/skill_dependencies.rs index e9d77a33f..15f09932a 100644 --- a/codex-rs/core/src/mcp/skill_dependencies.rs +++ b/codex-rs/core/src/mcp/skill_dependencies.rs @@ -13,6 +13,8 @@ use tracing::warn; use super::auth::McpOAuthLoginSupport; use super::auth::oauth_login_support; +use super::auth::resolve_oauth_scopes; +use super::auth::should_retry_without_scopes; use crate::codex::Session; use crate::codex::TurnContext; use crate::config::Config; @@ -236,20 +238,52 @@ pub(crate) async fn maybe_install_mcp_dependencies( ) .await; - if let Err(err) = perform_oauth_login( + let resolved_scopes = resolve_oauth_scopes( + None, + server_config.scopes.clone(), + oauth_config.discovered_scopes.clone(), + ); + let first_attempt = perform_oauth_login( &name, &oauth_config.url, config.mcp_oauth_credentials_store_mode, - oauth_config.http_headers, - oauth_config.env_http_headers, - &[], + oauth_config.http_headers.clone(), + oauth_config.env_http_headers.clone(), + &resolved_scopes.scopes, server_config.oauth_resource.as_deref(), config.mcp_oauth_callback_port, config.mcp_oauth_callback_url.as_deref(), ) - .await - { - warn!("failed to login to MCP dependency {name}: {err}"); + .await; + + if let Err(err) = first_attempt { + if should_retry_without_scopes(&resolved_scopes, &err) { + sess.notify_background_event( + turn_context, + format!( + "Retrying MCP {name} authentication without scopes after provider rejection." + ), + ) + .await; + + if let Err(err) = perform_oauth_login( + &name, + &oauth_config.url, + config.mcp_oauth_credentials_store_mode, + oauth_config.http_headers, + oauth_config.env_http_headers, + &[], + server_config.oauth_resource.as_deref(), + config.mcp_oauth_callback_port, + config.mcp_oauth_callback_url.as_deref(), + ) + .await + { + warn!("failed to login to MCP dependency {name}: {err}"); + } + } else { + warn!("failed to login to MCP dependency {name}: {err}"); + } } } diff --git a/codex-rs/rmcp-client/src/auth_status.rs b/codex-rs/rmcp-client/src/auth_status.rs index 7ab72088b..67ef0f756 100644 --- a/codex-rs/rmcp-client/src/auth_status.rs +++ b/codex-rs/rmcp-client/src/auth_status.rs @@ -21,6 +21,11 @@ const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5); const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version"; const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05"; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamableHttpOAuthDiscovery { + pub scopes_supported: Option>, +} + /// Determine the authentication status for a streamable HTTP MCP server. pub async fn determine_streamable_http_auth_status( server_name: &str, @@ -43,9 +48,9 @@ pub async fn determine_streamable_http_auth_status( return Ok(McpAuthStatus::OAuth); } - match supports_oauth_login_with_headers(url, &default_headers).await { - Ok(true) => Ok(McpAuthStatus::NotLoggedIn), - Ok(false) => Ok(McpAuthStatus::Unsupported), + match discover_streamable_http_oauth_with_headers(url, &default_headers).await { + Ok(Some(_)) => Ok(McpAuthStatus::NotLoggedIn), + Ok(None) => Ok(McpAuthStatus::Unsupported), Err(error) => { debug!( "failed to detect OAuth support for MCP server `{server_name}` at {url}: {error:?}" @@ -57,10 +62,24 @@ pub async fn determine_streamable_http_auth_status( /// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login. pub async fn supports_oauth_login(url: &str) -> Result { - supports_oauth_login_with_headers(url, &HeaderMap::new()).await + Ok(discover_streamable_http_oauth(url, None, None) + .await? + .is_some()) } -async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result { +pub async fn discover_streamable_http_oauth( + url: &str, + http_headers: Option>, + env_http_headers: Option>, +) -> Result> { + let default_headers = build_default_headers(http_headers, env_http_headers)?; + discover_streamable_http_oauth_with_headers(url, &default_headers).await +} + +async fn discover_streamable_http_oauth_with_headers( + url: &str, + default_headers: &HeaderMap, +) -> Result> { let base_url = Url::parse(url)?; // Use no_proxy to avoid a bug in the system-configuration crate that @@ -99,7 +118,9 @@ async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMa }; if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() { - return Ok(true); + return Ok(Some(StreamableHttpOAuthDiscovery { + scopes_supported: normalize_scopes(metadata.scopes_supported), + })); } } @@ -107,7 +128,7 @@ async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMa debug!("OAuth discovery requests failed for {url}: {err:?}"); } - Ok(false) + Ok(None) } #[derive(Debug, Deserialize)] @@ -116,6 +137,30 @@ struct OAuthDiscoveryMetadata { authorization_endpoint: Option, #[serde(default)] token_endpoint: Option, + #[serde(default)] + scopes_supported: Option>, +} + +fn normalize_scopes(scopes_supported: Option>) -> Option> { + let scopes_supported = scopes_supported?; + + let mut normalized = Vec::new(); + for scope in scopes_supported { + let scope = scope.trim(); + if scope.is_empty() { + continue; + } + let scope = scope.to_string(); + if !normalized.contains(&scope) { + normalized.push(scope); + } + } + + if normalized.is_empty() { + None + } else { + Some(normalized) + } } /// Implements RFC 8414 section 3.1 for discovering well-known oauth endpoints. @@ -147,10 +192,50 @@ fn discovery_paths(base_path: &str) -> Vec { #[cfg(test)] mod tests { use super::*; + use axum::Json; + use axum::Router; + use axum::routing::get; use pretty_assertions::assert_eq; use serial_test::serial; use std::collections::HashMap; use std::ffi::OsString; + use tokio::task::JoinHandle; + + struct TestServer { + url: String, + handle: JoinHandle<()>, + } + + impl Drop for TestServer { + fn drop(&mut self) { + self.handle.abort(); + } + } + + async fn spawn_oauth_discovery_server(metadata: serde_json::Value) -> TestServer { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let address = listener.local_addr().expect("listener should have address"); + let app = Router::new().route( + "/.well-known/oauth-authorization-server/mcp", + get({ + let metadata = metadata.clone(); + move || { + let metadata = metadata.clone(); + async move { Json(metadata) } + } + }), + ); + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.expect("server should run"); + }); + + TestServer { + url: format!("http://{address}/mcp"), + handle, + } + } struct EnvVarGuard { key: String, @@ -223,4 +308,56 @@ mod tests { assert_eq!(status, McpAuthStatus::BearerToken); } + + #[tokio::test] + async fn discover_streamable_http_oauth_returns_normalized_scopes() { + let server = spawn_oauth_discovery_server(serde_json::json!({ + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + "scopes_supported": ["profile", " email ", "profile", "", " "], + })) + .await; + + let discovery = discover_streamable_http_oauth(&server.url, None, None) + .await + .expect("discovery should succeed") + .expect("oauth support should be detected"); + + assert_eq!( + discovery.scopes_supported, + Some(vec!["profile".to_string(), "email".to_string()]) + ); + } + + #[tokio::test] + async fn discover_streamable_http_oauth_ignores_empty_scopes() { + let server = spawn_oauth_discovery_server(serde_json::json!({ + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + "scopes_supported": ["", " "], + })) + .await; + + let discovery = discover_streamable_http_oauth(&server.url, None, None) + .await + .expect("discovery should succeed") + .expect("oauth support should be detected"); + + assert_eq!(discovery.scopes_supported, None); + } + + #[tokio::test] + async fn supports_oauth_login_does_not_require_scopes_supported() { + let server = spawn_oauth_discovery_server(serde_json::json!({ + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + })) + .await; + + let supported = supports_oauth_login(&server.url) + .await + .expect("support check should succeed"); + + assert!(supported); + } } diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index a10d3b29a..0edd0f152 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -6,7 +6,9 @@ mod program_resolver; mod rmcp_client; mod utils; +pub use auth_status::StreamableHttpOAuthDiscovery; pub use auth_status::determine_streamable_http_auth_status; +pub use auth_status::discover_streamable_http_oauth; pub use auth_status::supports_oauth_login; pub use codex_protocol::protocol::McpAuthStatus; pub use oauth::OAuthCredentialsStoreMode; @@ -15,6 +17,7 @@ pub use oauth::WrappedOAuthTokenResponse; pub use oauth::delete_oauth_tokens; pub(crate) use oauth::load_oauth_tokens; pub use oauth::save_oauth_tokens; +pub use perform_oauth_login::OAuthProviderError; pub use perform_oauth_login::OauthLoginHandle; pub use perform_oauth_login::perform_oauth_login; pub use perform_oauth_login::perform_oauth_login_return_url; diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index c71799c62..71ae01396 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -39,6 +39,36 @@ impl Drop for CallbackServerGuard { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OAuthProviderError { + error: Option, + error_description: Option, +} + +impl OAuthProviderError { + pub fn new(error: Option, error_description: Option) -> Self { + Self { + error, + error_description, + } + } +} + +impl std::fmt::Display for OAuthProviderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match (self.error.as_deref(), self.error_description.as_deref()) { + (Some(error), Some(error_description)) => { + write!(f, "OAuth provider returned `{error}`: {error_description}") + } + (Some(error), None) => write!(f, "OAuth provider returned `{error}`"), + (None, Some(error_description)) => write!(f, "OAuth error: {error_description}"), + (None, None) => write!(f, "OAuth provider returned an error"), + } + } +} + +impl std::error::Error for OAuthProviderError {} + #[allow(clippy::too_many_arguments)] pub async fn perform_oauth_login( server_name: &str, @@ -111,7 +141,7 @@ pub async fn perform_oauth_login_return_url( fn spawn_callback_server( server: Arc, - tx: oneshot::Sender<(String, String)>, + tx: oneshot::Sender, expected_callback_path: String, ) { tokio::task::spawn_blocking(move || { @@ -125,17 +155,22 @@ fn spawn_callback_server( if let Err(err) = request.respond(response) { eprintln!("Failed to respond to OAuth callback: {err}"); } - if let Err(err) = tx.send((code, state)) { + if let Err(err) = + tx.send(CallbackResult::Success(OauthCallbackResult { code, state })) + { eprintln!("Failed to send OAuth callback: {err:?}"); } break; } - CallbackOutcome::Error(description) => { - let response = Response::from_string(format!("OAuth error: {description}")) - .with_status_code(400); + CallbackOutcome::Error(error) => { + let response = Response::from_string(error.to_string()).with_status_code(400); if let Err(err) = request.respond(response) { eprintln!("Failed to respond to OAuth callback: {err}"); } + if let Err(err) = tx.send(CallbackResult::Error(error)) { + eprintln!("Failed to send OAuth callback error: {err:?}"); + } + break; } CallbackOutcome::Invalid => { let response = @@ -149,14 +184,22 @@ fn spawn_callback_server( }); } +#[derive(Debug, Clone, PartialEq, Eq)] struct OauthCallbackResult { code: String, state: String, } +#[derive(Debug)] +enum CallbackResult { + Success(OauthCallbackResult), + Error(OAuthProviderError), +} + +#[derive(Debug, PartialEq, Eq)] enum CallbackOutcome { Success(OauthCallbackResult), - Error(String), + Error(OAuthProviderError), Invalid, } @@ -170,6 +213,7 @@ fn parse_oauth_callback(path: &str, expected_callback_path: &str) -> CallbackOut let mut code = None; let mut state = None; + let mut error = None; let mut error_description = None; for pair in query.split('&') { @@ -183,6 +227,7 @@ fn parse_oauth_callback(path: &str, expected_callback_path: &str) -> CallbackOut match key { "code" => code = Some(decoded), "state" => state = Some(decoded), + "error" => error = Some(decoded), "error_description" => error_description = Some(decoded), _ => {} } @@ -192,8 +237,8 @@ fn parse_oauth_callback(path: &str, expected_callback_path: &str) -> CallbackOut return CallbackOutcome::Success(OauthCallbackResult { code, state }); } - if let Some(description) = error_description { - return CallbackOutcome::Error(description); + if error.is_some() || error_description.is_some() { + return CallbackOutcome::Error(OAuthProviderError::new(error, error_description)); } CallbackOutcome::Invalid @@ -230,7 +275,7 @@ impl OauthLoginHandle { struct OauthLoginFlow { auth_url: String, oauth_state: OAuthState, - rx: oneshot::Receiver<(String, String)>, + rx: oneshot::Receiver, guard: CallbackServerGuard, server_name: String, server_url: String, @@ -384,10 +429,17 @@ impl OauthLoginFlow { } let result = async { - let (code, csrf_state) = timeout(self.timeout, &mut self.rx) + let callback = timeout(self.timeout, &mut self.rx) .await .context("timed out waiting for OAuth callback")? .context("OAuth callback was cancelled")?; + let OauthCallbackResult { + code, + state: csrf_state, + } = match callback { + CallbackResult::Success(callback) => callback, + CallbackResult::Error(error) => return Err(anyhow!(error)), + }; self.oauth_state .handle_callback(&code, &csrf_state) @@ -462,6 +514,7 @@ mod tests { use pretty_assertions::assert_eq; use super::CallbackOutcome; + use super::OAuthProviderError; use super::append_query_param; use super::callback_path_from_redirect_uri; use super::parse_oauth_callback; @@ -484,6 +537,22 @@ mod tests { assert!(matches!(parsed, CallbackOutcome::Invalid)); } + #[test] + fn parse_oauth_callback_returns_provider_error() { + let parsed = parse_oauth_callback( + "/callback?error=invalid_scope&error_description=scope%20rejected", + "/callback", + ); + + assert_eq!( + parsed, + CallbackOutcome::Error(OAuthProviderError::new( + Some("invalid_scope".to_string()), + Some("scope rejected".to_string()), + )) + ); + } + #[test] fn callback_path_comes_from_redirect_uri() { let path = callback_path_from_redirect_uri("https://example.com/oauth/callback")