From 893f5261eb620b9fd36ec61cfcae929ceb11b1cd Mon Sep 17 00:00:00 2001 From: Shijie Rao Date: Tue, 9 Dec 2025 17:43:53 -0800 Subject: [PATCH] feat: support mcp in-session login (#7751) ### Summary * Added `mcpServer/oauthLogin` in app server for supporting in session MCP server login * Added `McpServerOauthLoginParams` and `McpServerOauthLoginResponse` to support above method with response returning the auth URL for consumer to open browser or display accordingly. * Added `McpServerOauthLoginCompletedNotification` which the app server would emit on MCP server login success or failure (i.e. timeout). * Refactored rmcp-client oath_login to have the ability on starting a auth server which the codex_message_processor uses for in-session auth. --- codex-rs/Cargo.lock | 1 + .../src/protocol/common.rs | 6 + .../app-server-protocol/src/protocol/v2.rs | 31 ++ codex-rs/app-server/Cargo.toml | 1 + .../app-server/src/codex_message_processor.rs | 126 ++++++++ codex-rs/app-server/src/message_processor.rs | 1 + codex-rs/rmcp-client/src/lib.rs | 2 + .../rmcp-client/src/perform_oauth_login.rs | 283 ++++++++++++++---- 8 files changed, 392 insertions(+), 59 deletions(-) diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 58ba4f2a9..9a3cd95df 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -887,6 +887,7 @@ dependencies = [ "codex-file-search", "codex-login", "codex-protocol", + "codex-rmcp-client", "codex-utils-json-to-toml", "core_test_support", "mcp-types", diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index 285836673..c62acc883 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -139,6 +139,11 @@ client_request_definitions! { response: v2::ModelListResponse, }, + McpServerOauthLogin => "mcpServer/oauth/login" { + params: v2::McpServerOauthLoginParams, + response: v2::McpServerOauthLoginResponse, + }, + McpServersList => "mcpServers/list" { params: v2::ListMcpServersParams, response: v2::ListMcpServersResponse, @@ -524,6 +529,7 @@ server_notification_definitions! { CommandExecutionOutputDelta => "item/commandExecution/outputDelta" (v2::CommandExecutionOutputDeltaNotification), FileChangeOutputDelta => "item/fileChange/outputDelta" (v2::FileChangeOutputDeltaNotification), McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification), + McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification), AccountUpdated => "account/updated" (v2::AccountUpdatedNotification), AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification), ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification), diff --git a/codex-rs/app-server-protocol/src/protocol/v2.rs b/codex-rs/app-server-protocol/src/protocol/v2.rs index ea70b805b..dbef55ed1 100644 --- a/codex-rs/app-server-protocol/src/protocol/v2.rs +++ b/codex-rs/app-server-protocol/src/protocol/v2.rs @@ -688,6 +688,26 @@ pub struct ListMcpServersResponse { pub next_cursor: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct McpServerOauthLoginParams { + pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub scopes: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub timeout_secs: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct McpServerOauthLoginResponse { + pub authorization_url: String, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] @@ -1467,6 +1487,17 @@ pub struct McpToolCallProgressNotification { pub message: String, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export_to = "v2/")] +pub struct McpServerOauthLoginCompletedNotification { + pub name: String, + pub success: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[ts(optional)] + pub error: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] #[serde(rename_all = "camelCase")] #[ts(export_to = "v2/")] diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index 99d5a7a14..e4a326a2c 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -26,6 +26,7 @@ codex-login = { workspace = true } codex-protocol = { workspace = true } codex-app-server-protocol = { workspace = true } codex-feedback = { workspace = true } +codex-rmcp-client = { workspace = true } codex-utils-json-to-toml = { workspace = true } chrono = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 65721a698..0a8445055 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -55,6 +55,9 @@ use codex_app_server_protocol::LoginChatGptResponse; use codex_app_server_protocol::LogoutAccountResponse; use codex_app_server_protocol::LogoutChatGptResponse; use codex_app_server_protocol::McpServer; +use codex_app_server_protocol::McpServerOauthLoginCompletedNotification; +use codex_app_server_protocol::McpServerOauthLoginParams; +use codex_app_server_protocol::McpServerOauthLoginResponse; use codex_app_server_protocol::ModelListParams; use codex_app_server_protocol::ModelListResponse; use codex_app_server_protocol::NewConversationParams; @@ -115,6 +118,7 @@ use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::ConfigToml; use codex_core::config::edit::ConfigEditsBuilder; +use codex_core::config::types::McpServerTransportConfig; use codex_core::config_loader::load_config_as_toml; use codex_core::default_client::get_codex_user_agent; use codex_core::exec::ExecParams; @@ -147,6 +151,7 @@ use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::SessionMetaLine; use codex_protocol::protocol::USER_MESSAGE_BEGIN; use codex_protocol::user_input::UserInput as CoreInputItem; +use codex_rmcp_client::perform_oauth_login_return_url; use codex_utils_json_to_toml::json_to_toml; use std::collections::HashMap; use std::collections::HashSet; @@ -161,6 +166,7 @@ use std::time::Duration; use tokio::select; use tokio::sync::Mutex; use tokio::sync::oneshot; +use toml::Value as TomlValue; use tracing::error; use tracing::info; use tracing::warn; @@ -198,6 +204,7 @@ pub(crate) struct CodexMessageProcessor { outgoing: Arc, codex_linux_sandbox_exe: Option, config: Arc, + cli_overrides: Vec<(String, TomlValue)>, conversation_listeners: HashMap>, active_login: Arc>>, // Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives. @@ -244,6 +251,7 @@ impl CodexMessageProcessor { outgoing: Arc, codex_linux_sandbox_exe: Option, config: Arc, + cli_overrides: Vec<(String, TomlValue)>, feedback: CodexFeedback, ) -> Self { Self { @@ -252,6 +260,7 @@ impl CodexMessageProcessor { outgoing, codex_linux_sandbox_exe, config, + cli_overrides, conversation_listeners: HashMap::new(), active_login: Arc::new(Mutex::new(None)), pending_interrupts: Arc::new(Mutex::new(HashMap::new())), @@ -261,6 +270,16 @@ impl CodexMessageProcessor { } } + async fn load_latest_config(&self) -> Result { + Config::load_with_cli_overrides(self.cli_overrides.clone(), ConfigOverrides::default()) + .await + .map_err(|err| JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to reload config: {err}"), + data: None, + }) + } + fn review_request_from_target( target: ApiReviewTarget, ) -> Result<(ReviewRequest, String), JSONRPCErrorError> { @@ -369,6 +388,9 @@ impl CodexMessageProcessor { ClientRequest::ModelList { request_id, params } => { self.list_models(request_id, params).await; } + ClientRequest::McpServerOauthLogin { request_id, params } => { + self.mcp_server_oauth_login(request_id, params).await; + } ClientRequest::McpServersList { request_id, params } => { self.list_mcp_servers(request_id, params).await; } @@ -1916,6 +1938,110 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } + async fn mcp_server_oauth_login( + &self, + request_id: RequestId, + params: McpServerOauthLoginParams, + ) { + let config = match self.load_latest_config().await { + Ok(config) => config, + Err(error) => { + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + if !config.features.enabled(Feature::RmcpClient) { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: "OAuth login is only supported when [features].rmcp_client is true in config.toml".to_string(), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + + let McpServerOauthLoginParams { + name, + scopes, + timeout_secs, + } = params; + + let Some(server) = config.mcp_servers.get(&name) else { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("No MCP server named '{name}' found."), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + }; + + let (url, http_headers, env_http_headers) = match &server.transport { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => (url.clone(), http_headers.clone(), env_http_headers.clone()), + _ => { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: "OAuth login is only supported for streamable HTTP servers." + .to_string(), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + match perform_oauth_login_return_url( + &name, + &url, + config.mcp_oauth_credentials_store_mode, + http_headers, + env_http_headers, + scopes.as_deref().unwrap_or_default(), + timeout_secs, + ) + .await + { + Ok(handle) => { + let authorization_url = handle.authorization_url().to_string(); + let notification_name = name.clone(); + let outgoing = Arc::clone(&self.outgoing); + + tokio::spawn(async move { + let (success, error) = match handle.wait().await { + Ok(()) => (true, None), + Err(err) => (false, Some(err.to_string())), + }; + + let notification = ServerNotification::McpServerOauthLoginCompleted( + McpServerOauthLoginCompletedNotification { + name: notification_name, + success, + error, + }, + ); + outgoing.send_server_notification(notification).await; + }); + + let response = McpServerOauthLoginResponse { authorization_url }; + self.outgoing.send_response(request_id, response).await; + } + Err(err) => { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to login to MCP server '{name}': {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + } + } + } + async fn list_mcp_servers(&self, request_id: RequestId, params: ListMcpServersParams) { let snapshot = collect_mcp_snapshot(self.config.as_ref()).await; diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 90560e9b3..6a6cf5edb 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -59,6 +59,7 @@ impl MessageProcessor { outgoing.clone(), codex_linux_sandbox_exe, Arc::clone(&config), + cli_overrides.clone(), feedback, ); let config_api = ConfigApi::new(config.codex_home.clone(), cli_overrides); diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index ac617f3d2..954898cea 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse; pub use oauth::delete_oauth_tokens; pub(crate) use oauth::load_oauth_tokens; pub use oauth::save_oauth_tokens; +pub use perform_oauth_login::OauthLoginHandle; pub use perform_oauth_login::perform_oauth_login; +pub use perform_oauth_login::perform_oauth_login_return_url; pub use rmcp::model::ElicitationAction; pub use rmcp_client::Elicitation; pub use rmcp_client::ElicitationResponse; diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index d8ffdd394..9815a3a22 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -22,6 +22,11 @@ use crate::save_oauth_tokens; use crate::utils::apply_default_headers; use crate::utils::build_default_headers; +struct OauthHeaders { + http_headers: Option>, + env_http_headers: Option>, +} + struct CallbackServerGuard { server: Arc, } @@ -40,70 +45,52 @@ pub async fn perform_oauth_login( env_http_headers: Option>, scopes: &[String], ) -> Result<()> { - let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?); - let guard = CallbackServerGuard { - server: Arc::clone(&server), + let headers = OauthHeaders { + http_headers, + env_http_headers, }; + OauthLoginFlow::new( + server_name, + server_url, + store_mode, + headers, + scopes, + true, + None, + ) + .await? + .finish() + .await +} - let redirect_uri = match server.server_addr() { - tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => { - format!("http://{}:{}/callback", addr.ip(), addr.port()) - } - tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => { - format!("http://[{}]:{}/callback", addr.ip(), addr.port()) - } - #[cfg(not(target_os = "windows"))] - _ => return Err(anyhow!("unable to determine callback address")), +pub async fn perform_oauth_login_return_url( + server_name: &str, + server_url: &str, + store_mode: OAuthCredentialsStoreMode, + http_headers: Option>, + env_http_headers: Option>, + scopes: &[String], + timeout_secs: Option, +) -> Result { + let headers = OauthHeaders { + http_headers, + env_http_headers, }; + let flow = OauthLoginFlow::new( + server_name, + server_url, + store_mode, + headers, + scopes, + false, + timeout_secs, + ) + .await?; - let (tx, rx) = oneshot::channel(); - spawn_callback_server(server, tx); + let authorization_url = flow.authorization_url(); + let completion = flow.spawn(); - let default_headers = build_default_headers(http_headers, env_http_headers)?; - let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?; - - let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?; - let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect(); - oauth_state - .start_authorization(&scope_refs, &redirect_uri, Some("Codex")) - .await?; - let auth_url = oauth_state.get_authorization_url().await?; - - println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"); - - if webbrowser::open(&auth_url).is_err() { - println!("(Browser launch failed; please copy the URL above manually.)"); - } - - let (code, csrf_state) = timeout(Duration::from_secs(300), rx) - .await - .context("timed out waiting for OAuth callback")? - .context("OAuth callback was cancelled")?; - - oauth_state - .handle_callback(&code, &csrf_state) - .await - .context("failed to handle OAuth callback")?; - - let (client_id, credentials_opt) = oauth_state - .get_credentials() - .await - .context("failed to retrieve OAuth credentials")?; - let credentials = - credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?; - - let expires_at = compute_expires_at_millis(&credentials); - let stored = StoredOAuthTokens { - server_name: server_name.to_string(), - url: server_url.to_string(), - client_id, - token_response: WrappedOAuthTokenResponse(credentials), - expires_at, - }; - save_oauth_tokens(server_name, &stored, store_mode)?; - - drop(guard); - Ok(()) + Ok(OauthLoginHandle::new(authorization_url, completion)) } fn spawn_callback_server(server: Arc, tx: oneshot::Sender<(String, String)>) { @@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option { state: state?, }) } + +pub struct OauthLoginHandle { + authorization_url: String, + completion: oneshot::Receiver>, +} + +impl OauthLoginHandle { + fn new(authorization_url: String, completion: oneshot::Receiver>) -> Self { + Self { + authorization_url, + completion, + } + } + + pub fn authorization_url(&self) -> &str { + &self.authorization_url + } + + pub fn into_parts(self) -> (String, oneshot::Receiver>) { + (self.authorization_url, self.completion) + } + + pub async fn wait(self) -> Result<()> { + self.completion + .await + .map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))? + } +} + +struct OauthLoginFlow { + auth_url: String, + oauth_state: OAuthState, + rx: oneshot::Receiver<(String, String)>, + guard: CallbackServerGuard, + server_name: String, + server_url: String, + store_mode: OAuthCredentialsStoreMode, + launch_browser: bool, + timeout: Duration, +} + +impl OauthLoginFlow { + async fn new( + server_name: &str, + server_url: &str, + store_mode: OAuthCredentialsStoreMode, + headers: OauthHeaders, + scopes: &[String], + launch_browser: bool, + timeout_secs: Option, + ) -> Result { + const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300; + + let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?); + let guard = CallbackServerGuard { + server: Arc::clone(&server), + }; + + let redirect_uri = match server.server_addr() { + tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => { + let ip = addr.ip(); + let port = addr.port(); + format!("http://{ip}:{port}/callback") + } + tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => { + let ip = addr.ip(); + let port = addr.port(); + format!("http://[{ip}]:{port}/callback") + } + #[cfg(not(target_os = "windows"))] + _ => return Err(anyhow!("unable to determine callback address")), + }; + + let (tx, rx) = oneshot::channel(); + spawn_callback_server(server, tx); + + let OauthHeaders { + http_headers, + env_http_headers, + } = headers; + let default_headers = build_default_headers(http_headers, env_http_headers)?; + let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?; + + let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?; + let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect(); + oauth_state + .start_authorization(&scope_refs, &redirect_uri, Some("Codex")) + .await?; + let auth_url = oauth_state.get_authorization_url().await?; + let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1); + let timeout = Duration::from_secs(timeout_secs as u64); + + Ok(Self { + auth_url, + oauth_state, + rx, + guard, + server_name: server_name.to_string(), + server_url: server_url.to_string(), + store_mode, + launch_browser, + timeout, + }) + } + + fn authorization_url(&self) -> String { + self.auth_url.clone() + } + + async fn finish(mut self) -> Result<()> { + if self.launch_browser { + let server_name = &self.server_name; + let auth_url = &self.auth_url; + println!( + "Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n" + ); + + if webbrowser::open(auth_url).is_err() { + println!("(Browser launch failed; please copy the URL above manually.)"); + } + } + + let result = async { + let (code, csrf_state) = timeout(self.timeout, &mut self.rx) + .await + .context("timed out waiting for OAuth callback")? + .context("OAuth callback was cancelled")?; + + self.oauth_state + .handle_callback(&code, &csrf_state) + .await + .context("failed to handle OAuth callback")?; + + let (client_id, credentials_opt) = self + .oauth_state + .get_credentials() + .await + .context("failed to retrieve OAuth credentials")?; + let credentials = credentials_opt + .ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?; + + let expires_at = compute_expires_at_millis(&credentials); + let stored = StoredOAuthTokens { + server_name: self.server_name.clone(), + url: self.server_url.clone(), + client_id, + token_response: WrappedOAuthTokenResponse(credentials), + expires_at, + }; + save_oauth_tokens(&self.server_name, &stored, self.store_mode)?; + + Ok(()) + } + .await; + + drop(self.guard); + result + } + + fn spawn(self) -> oneshot::Receiver> { + let server_name_for_logging = self.server_name.clone(); + let (tx, rx) = oneshot::channel(); + + tokio::spawn(async move { + let result = self.finish().await; + + if let Err(err) = &result { + eprintln!( + "Failed to complete OAuth login for '{server_name_for_logging}': {err:#}" + ); + } + + let _ = tx.send(result); + }); + + rx + } +}