feat: support mcp in-session login (#7751)

### Summary
* Added `mcpServer/oauthLogin` in app server for supporting in session
MCP server login
* Added `McpServerOauthLoginParams` and `McpServerOauthLoginResponse` to
support above method with response returning the auth URL for consumer
to open browser or display accordingly.
* Added `McpServerOauthLoginCompletedNotification` which the app server
would emit on MCP server login success or failure (i.e. timeout).
* Refactored rmcp-client oath_login to have the ability on starting a
auth server which the codex_message_processor uses for in-session auth.
This commit is contained in:
Shijie Rao 2025-12-09 17:43:53 -08:00 committed by GitHub
parent fa4cac1e6b
commit 893f5261eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 392 additions and 59 deletions

1
codex-rs/Cargo.lock generated
View file

@ -887,6 +887,7 @@ dependencies = [
"codex-file-search",
"codex-login",
"codex-protocol",
"codex-rmcp-client",
"codex-utils-json-to-toml",
"core_test_support",
"mcp-types",

View file

@ -139,6 +139,11 @@ client_request_definitions! {
response: v2::ModelListResponse,
},
McpServerOauthLogin => "mcpServer/oauth/login" {
params: v2::McpServerOauthLoginParams,
response: v2::McpServerOauthLoginResponse,
},
McpServersList => "mcpServers/list" {
params: v2::ListMcpServersParams,
response: v2::ListMcpServersResponse,
@ -524,6 +529,7 @@ server_notification_definitions! {
CommandExecutionOutputDelta => "item/commandExecution/outputDelta" (v2::CommandExecutionOutputDeltaNotification),
FileChangeOutputDelta => "item/fileChange/outputDelta" (v2::FileChangeOutputDeltaNotification),
McpToolCallProgress => "item/mcpToolCall/progress" (v2::McpToolCallProgressNotification),
McpServerOauthLoginCompleted => "mcpServer/oauthLogin/completed" (v2::McpServerOauthLoginCompletedNotification),
AccountUpdated => "account/updated" (v2::AccountUpdatedNotification),
AccountRateLimitsUpdated => "account/rateLimits/updated" (v2::AccountRateLimitsUpdatedNotification),
ReasoningSummaryTextDelta => "item/reasoning/summaryTextDelta" (v2::ReasoningSummaryTextDeltaNotification),

View file

@ -688,6 +688,26 @@ pub struct ListMcpServersResponse {
pub next_cursor: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct McpServerOauthLoginParams {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub scopes: Option<Vec<String>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub timeout_secs: Option<i64>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct McpServerOauthLoginResponse {
pub authorization_url: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
@ -1467,6 +1487,17 @@ pub struct McpToolCallProgressNotification {
pub message: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct McpServerOauthLoginCompletedNotification {
pub name: String,
pub success: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub error: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]

View file

@ -26,6 +26,7 @@ codex-login = { workspace = true }
codex-protocol = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-feedback = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-utils-json-to-toml = { workspace = true }
chrono = { workspace = true }
serde = { workspace = true, features = ["derive"] }

View file

@ -55,6 +55,9 @@ use codex_app_server_protocol::LoginChatGptResponse;
use codex_app_server_protocol::LogoutAccountResponse;
use codex_app_server_protocol::LogoutChatGptResponse;
use codex_app_server_protocol::McpServer;
use codex_app_server_protocol::McpServerOauthLoginCompletedNotification;
use codex_app_server_protocol::McpServerOauthLoginParams;
use codex_app_server_protocol::McpServerOauthLoginResponse;
use codex_app_server_protocol::ModelListParams;
use codex_app_server_protocol::ModelListResponse;
use codex_app_server_protocol::NewConversationParams;
@ -115,6 +118,7 @@ use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
use codex_core::config::ConfigToml;
use codex_core::config::edit::ConfigEditsBuilder;
use codex_core::config::types::McpServerTransportConfig;
use codex_core::config_loader::load_config_as_toml;
use codex_core::default_client::get_codex_user_agent;
use codex_core::exec::ExecParams;
@ -147,6 +151,7 @@ use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
use codex_protocol::user_input::UserInput as CoreInputItem;
use codex_rmcp_client::perform_oauth_login_return_url;
use codex_utils_json_to_toml::json_to_toml;
use std::collections::HashMap;
use std::collections::HashSet;
@ -161,6 +166,7 @@ use std::time::Duration;
use tokio::select;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use toml::Value as TomlValue;
use tracing::error;
use tracing::info;
use tracing::warn;
@ -198,6 +204,7 @@ pub(crate) struct CodexMessageProcessor {
outgoing: Arc<OutgoingMessageSender>,
codex_linux_sandbox_exe: Option<PathBuf>,
config: Arc<Config>,
cli_overrides: Vec<(String, TomlValue)>,
conversation_listeners: HashMap<Uuid, oneshot::Sender<()>>,
active_login: Arc<Mutex<Option<ActiveLogin>>>,
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
@ -244,6 +251,7 @@ impl CodexMessageProcessor {
outgoing: Arc<OutgoingMessageSender>,
codex_linux_sandbox_exe: Option<PathBuf>,
config: Arc<Config>,
cli_overrides: Vec<(String, TomlValue)>,
feedback: CodexFeedback,
) -> Self {
Self {
@ -252,6 +260,7 @@ impl CodexMessageProcessor {
outgoing,
codex_linux_sandbox_exe,
config,
cli_overrides,
conversation_listeners: HashMap::new(),
active_login: Arc::new(Mutex::new(None)),
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
@ -261,6 +270,16 @@ impl CodexMessageProcessor {
}
}
async fn load_latest_config(&self) -> Result<Config, JSONRPCErrorError> {
Config::load_with_cli_overrides(self.cli_overrides.clone(), ConfigOverrides::default())
.await
.map_err(|err| JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("failed to reload config: {err}"),
data: None,
})
}
fn review_request_from_target(
target: ApiReviewTarget,
) -> Result<(ReviewRequest, String), JSONRPCErrorError> {
@ -369,6 +388,9 @@ impl CodexMessageProcessor {
ClientRequest::ModelList { request_id, params } => {
self.list_models(request_id, params).await;
}
ClientRequest::McpServerOauthLogin { request_id, params } => {
self.mcp_server_oauth_login(request_id, params).await;
}
ClientRequest::McpServersList { request_id, params } => {
self.list_mcp_servers(request_id, params).await;
}
@ -1916,6 +1938,110 @@ impl CodexMessageProcessor {
self.outgoing.send_response(request_id, response).await;
}
async fn mcp_server_oauth_login(
&self,
request_id: RequestId,
params: McpServerOauthLoginParams,
) {
let config = match self.load_latest_config().await {
Ok(config) => config,
Err(error) => {
self.outgoing.send_error(request_id, error).await;
return;
}
};
if !config.features.enabled(Feature::RmcpClient) {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: "OAuth login is only supported when [features].rmcp_client is true in config.toml".to_string(),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
let McpServerOauthLoginParams {
name,
scopes,
timeout_secs,
} = params;
let Some(server) = config.mcp_servers.get(&name) else {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("No MCP server named '{name}' found."),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
};
let (url, http_headers, env_http_headers) = match &server.transport {
McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} => (url.clone(), http_headers.clone(), env_http_headers.clone()),
_ => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: "OAuth login is only supported for streamable HTTP servers."
.to_string(),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
match perform_oauth_login_return_url(
&name,
&url,
config.mcp_oauth_credentials_store_mode,
http_headers,
env_http_headers,
scopes.as_deref().unwrap_or_default(),
timeout_secs,
)
.await
{
Ok(handle) => {
let authorization_url = handle.authorization_url().to_string();
let notification_name = name.clone();
let outgoing = Arc::clone(&self.outgoing);
tokio::spawn(async move {
let (success, error) = match handle.wait().await {
Ok(()) => (true, None),
Err(err) => (false, Some(err.to_string())),
};
let notification = ServerNotification::McpServerOauthLoginCompleted(
McpServerOauthLoginCompletedNotification {
name: notification_name,
success,
error,
},
);
outgoing.send_server_notification(notification).await;
});
let response = McpServerOauthLoginResponse { authorization_url };
self.outgoing.send_response(request_id, response).await;
}
Err(err) => {
let error = JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("failed to login to MCP server '{name}': {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
}
}
}
async fn list_mcp_servers(&self, request_id: RequestId, params: ListMcpServersParams) {
let snapshot = collect_mcp_snapshot(self.config.as_ref()).await;

View file

@ -59,6 +59,7 @@ impl MessageProcessor {
outgoing.clone(),
codex_linux_sandbox_exe,
Arc::clone(&config),
cli_overrides.clone(),
feedback,
);
let config_api = ConfigApi::new(config.codex_home.clone(), cli_overrides);

View file

@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse;
pub use oauth::delete_oauth_tokens;
pub(crate) use oauth::load_oauth_tokens;
pub use oauth::save_oauth_tokens;
pub use perform_oauth_login::OauthLoginHandle;
pub use perform_oauth_login::perform_oauth_login;
pub use perform_oauth_login::perform_oauth_login_return_url;
pub use rmcp::model::ElicitationAction;
pub use rmcp_client::Elicitation;
pub use rmcp_client::ElicitationResponse;

View file

@ -22,6 +22,11 @@ use crate::save_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
struct OauthHeaders {
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
}
struct CallbackServerGuard {
server: Arc<Server>,
}
@ -40,70 +45,52 @@ pub async fn perform_oauth_login(
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
) -> Result<()> {
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
let guard = CallbackServerGuard {
server: Arc::clone(&server),
let headers = OauthHeaders {
http_headers,
env_http_headers,
};
OauthLoginFlow::new(
server_name,
server_url,
store_mode,
headers,
scopes,
true,
None,
)
.await?
.finish()
.await
}
let redirect_uri = match server.server_addr() {
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
format!("http://{}:{}/callback", addr.ip(), addr.port())
}
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
}
#[cfg(not(target_os = "windows"))]
_ => return Err(anyhow!("unable to determine callback address")),
pub async fn perform_oauth_login_return_url(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
timeout_secs: Option<i64>,
) -> Result<OauthLoginHandle> {
let headers = OauthHeaders {
http_headers,
env_http_headers,
};
let flow = OauthLoginFlow::new(
server_name,
server_url,
store_mode,
headers,
scopes,
false,
timeout_secs,
)
.await?;
let (tx, rx) = oneshot::channel();
spawn_callback_server(server, tx);
let authorization_url = flow.authorization_url();
let completion = flow.spawn();
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
if webbrowser::open(&auth_url).is_err() {
println!("(Browser launch failed; please copy the URL above manually.)");
}
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
.await
.context("timed out waiting for OAuth callback")?
.context("OAuth callback was cancelled")?;
oauth_state
.handle_callback(&code, &csrf_state)
.await
.context("failed to handle OAuth callback")?;
let (client_id, credentials_opt) = oauth_state
.get_credentials()
.await
.context("failed to retrieve OAuth credentials")?;
let credentials =
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: server_name.to_string(),
url: server_url.to_string(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,
};
save_oauth_tokens(server_name, &stored, store_mode)?;
drop(guard);
Ok(())
Ok(OauthLoginHandle::new(authorization_url, completion))
}
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
state: state?,
})
}
pub struct OauthLoginHandle {
authorization_url: String,
completion: oneshot::Receiver<Result<()>>,
}
impl OauthLoginHandle {
fn new(authorization_url: String, completion: oneshot::Receiver<Result<()>>) -> Self {
Self {
authorization_url,
completion,
}
}
pub fn authorization_url(&self) -> &str {
&self.authorization_url
}
pub fn into_parts(self) -> (String, oneshot::Receiver<Result<()>>) {
(self.authorization_url, self.completion)
}
pub async fn wait(self) -> Result<()> {
self.completion
.await
.map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))?
}
}
struct OauthLoginFlow {
auth_url: String,
oauth_state: OAuthState,
rx: oneshot::Receiver<(String, String)>,
guard: CallbackServerGuard,
server_name: String,
server_url: String,
store_mode: OAuthCredentialsStoreMode,
launch_browser: bool,
timeout: Duration,
}
impl OauthLoginFlow {
async fn new(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
headers: OauthHeaders,
scopes: &[String],
launch_browser: bool,
timeout_secs: Option<i64>,
) -> Result<Self> {
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
let guard = CallbackServerGuard {
server: Arc::clone(&server),
};
let redirect_uri = match server.server_addr() {
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
let ip = addr.ip();
let port = addr.port();
format!("http://{ip}:{port}/callback")
}
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
let ip = addr.ip();
let port = addr.port();
format!("http://[{ip}]:{port}/callback")
}
#[cfg(not(target_os = "windows"))]
_ => return Err(anyhow!("unable to determine callback address")),
};
let (tx, rx) = oneshot::channel();
spawn_callback_server(server, tx);
let OauthHeaders {
http_headers,
env_http_headers,
} = headers;
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
let timeout = Duration::from_secs(timeout_secs as u64);
Ok(Self {
auth_url,
oauth_state,
rx,
guard,
server_name: server_name.to_string(),
server_url: server_url.to_string(),
store_mode,
launch_browser,
timeout,
})
}
fn authorization_url(&self) -> String {
self.auth_url.clone()
}
async fn finish(mut self) -> Result<()> {
if self.launch_browser {
let server_name = &self.server_name;
let auth_url = &self.auth_url;
println!(
"Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"
);
if webbrowser::open(auth_url).is_err() {
println!("(Browser launch failed; please copy the URL above manually.)");
}
}
let result = async {
let (code, csrf_state) = timeout(self.timeout, &mut self.rx)
.await
.context("timed out waiting for OAuth callback")?
.context("OAuth callback was cancelled")?;
self.oauth_state
.handle_callback(&code, &csrf_state)
.await
.context("failed to handle OAuth callback")?;
let (client_id, credentials_opt) = self
.oauth_state
.get_credentials()
.await
.context("failed to retrieve OAuth credentials")?;
let credentials = credentials_opt
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: self.server_name.clone(),
url: self.server_url.clone(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,
};
save_oauth_tokens(&self.server_name, &stored, self.store_mode)?;
Ok(())
}
.await;
drop(self.guard);
result
}
fn spawn(self) -> oneshot::Receiver<Result<()>> {
let server_name_for_logging = self.server_name.clone();
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
let result = self.finish().await;
if let Err(err) = &result {
eprintln!(
"Failed to complete OAuth login for '{server_name_for_logging}': {err:#}"
);
}
let _ = tx.send(result);
});
rx
}
}