diff --git a/codex-rs/rmcp-client/src/auth_status.rs b/codex-rs/rmcp-client/src/auth_status.rs index c752aec78..7ab72088b 100644 --- a/codex-rs/rmcp-client/src/auth_status.rs +++ b/codex-rs/rmcp-client/src/auth_status.rs @@ -7,6 +7,7 @@ use codex_protocol::protocol::McpAuthStatus; use reqwest::Client; use reqwest::StatusCode; use reqwest::Url; +use reqwest::header::AUTHORIZATION; use reqwest::header::HeaderMap; use serde::Deserialize; use tracing::debug; @@ -33,12 +34,15 @@ pub async fn determine_streamable_http_auth_status( return Ok(McpAuthStatus::BearerToken); } + let default_headers = build_default_headers(http_headers, env_http_headers)?; + if default_headers.contains_key(AUTHORIZATION) { + return Ok(McpAuthStatus::BearerToken); + } + if has_oauth_tokens(server_name, url, store_mode)? { return Ok(McpAuthStatus::OAuth); } - let default_headers = build_default_headers(http_headers, env_http_headers)?; - match supports_oauth_login_with_headers(url, &default_headers).await { Ok(true) => Ok(McpAuthStatus::NotLoggedIn), Ok(false) => Ok(McpAuthStatus::Unsupported), @@ -139,3 +143,84 @@ fn discovery_paths(base_path: &str) -> Vec { candidates } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use serial_test::serial; + use std::collections::HashMap; + use std::ffi::OsString; + + struct EnvVarGuard { + key: String, + original: Option, + } + + impl EnvVarGuard { + fn set(key: &str, value: &str) -> Self { + let original = std::env::var_os(key); + unsafe { + std::env::set_var(key, value); + } + Self { + key: key.to_string(), + original, + } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + if let Some(value) = &self.original { + unsafe { + std::env::set_var(&self.key, value); + } + } else { + unsafe { + std::env::remove_var(&self.key); + } + } + } + } + + #[tokio::test] + async fn determine_auth_status_uses_bearer_token_when_authorization_header_present() { + let status = determine_streamable_http_auth_status( + "server", + "not-a-url", + None, + Some(HashMap::from([( + "Authorization".to_string(), + "Bearer token".to_string(), + )])), + None, + OAuthCredentialsStoreMode::Keyring, + ) + .await + .expect("status should compute"); + + assert_eq!(status, McpAuthStatus::BearerToken); + } + + #[tokio::test] + #[serial(auth_status_env)] + async fn determine_auth_status_uses_bearer_token_when_env_authorization_header_present() { + let _guard = EnvVarGuard::set("CODEX_RMCP_CLIENT_AUTH_STATUS_TEST_TOKEN", "Bearer token"); + let status = determine_streamable_http_auth_status( + "server", + "not-a-url", + None, + None, + Some(HashMap::from([( + "Authorization".to_string(), + "CODEX_RMCP_CLIENT_AUTH_STATUS_TEST_TOKEN".to_string(), + )])), + OAuthCredentialsStoreMode::Keyring, + ) + .await + .expect("status should compute"); + + assert_eq!(status, McpAuthStatus::BearerToken); + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 0deccf391..e1d704596 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -11,6 +11,7 @@ use anyhow::anyhow; use futures::FutureExt; use futures::future::BoxFuture; use oauth2::TokenResponse; +use reqwest::header::AUTHORIZATION; use reqwest::header::HeaderMap; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; @@ -244,16 +245,18 @@ impl RmcpClient { ) -> Result { let default_headers = build_default_headers(http_headers, env_http_headers)?; - let initial_oauth_tokens = match bearer_token { - Some(_) => None, - None => match load_oauth_tokens(server_name, url, store_mode) { - Ok(tokens) => tokens, - Err(err) => { - warn!("failed to read tokens for server `{server_name}`: {err}"); - None + let initial_oauth_tokens = + if bearer_token.is_none() && !default_headers.contains_key(AUTHORIZATION) { + match load_oauth_tokens(server_name, url, store_mode) { + Ok(tokens) => tokens, + Err(err) => { + warn!("failed to read tokens for server `{server_name}`: {err}"); + None + } } - }, - }; + } else { + None + }; let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() { match create_oauth_transport_and_runtime(