diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index b9226121f..e1428d134 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -2316,7 +2316,7 @@ impl CodexMessageProcessor { } }; - let mcp_servers = match serde_json::to_value(&config.mcp_servers) { + let mcp_servers = match serde_json::to_value(config.mcp_servers.get()) { Ok(value) => value, Err(err) => { let error = JSONRPCErrorError { @@ -2377,7 +2377,7 @@ impl CodexMessageProcessor { timeout_secs, } = params; - let Some(server) = config.mcp_servers.get(&name) else { + let Some(server) = config.mcp_servers.get().get(&name) else { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message: format!("No MCP server named '{name}' found."), diff --git a/codex-rs/app-server/src/config_api.rs b/codex-rs/app-server/src/config_api.rs index 25434ce92..dd525cb61 100644 --- a/codex-rs/app-server/src/config_api.rs +++ b/codex-rs/app-server/src/config_api.rs @@ -135,6 +135,7 @@ mod tests { CoreSandboxModeRequirement::ReadOnly, CoreSandboxModeRequirement::ExternalSandbox, ]), + mcp_server_requirements: None, }; let mapped = map_requirements_toml_to_api(requirements); diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 497ac8397..30c6fa21f 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -332,7 +332,7 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) let LoginArgs { name, scopes } = login_args; - let Some(server) = config.mcp_servers.get(&name) else { + let Some(server) = config.mcp_servers.get().get(&name) else { bail!("No MCP server named '{name}' found."); }; @@ -372,6 +372,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr let server = config .mcp_servers + .get() .get(&name) .ok_or_else(|| anyhow!("No MCP server named '{name}' found in configuration."))?; @@ -654,7 +655,7 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re .await .context("failed to load configuration")?; - let Some(server) = config.mcp_servers.get(&get_args.name) else { + let Some(server) = config.mcp_servers.get().get(&get_args.name) else { bail!("No MCP server named '{name}' found.", name = get_args.name); }; diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index be362e287..d913e1d3c 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -755,7 +755,7 @@ impl Session { .write() .await .initialize( - config.mcp_servers.clone(), + &config.mcp_servers, config.mcp_oauth_credentials_store_mode, auth_statuses.clone(), tx_event.clone(), @@ -1739,7 +1739,7 @@ impl Session { let mut refreshed_manager = McpConnectionManager::default(); refreshed_manager .initialize( - mcp_servers, + &mcp_servers, store_mode, auth_statuses, self.get_tx_event(), diff --git a/codex-rs/core/src/config/constraint.rs b/codex-rs/core/src/config/constraint.rs index 5a412a0d0..fa431a6eb 100644 --- a/codex-rs/core/src/config/constraint.rs +++ b/codex-rs/core/src/config/constraint.rs @@ -37,11 +37,15 @@ impl From for std::io::Error { } type ConstraintValidator = dyn Fn(&T) -> ConstraintResult<()> + Send + Sync; +/// A ConstraintNormalizer is a function which transforms a value into another of the same type. +/// `Constrained` uses normalizers to transform values to satisfy constraints or enforce values. +type ConstraintNormalizer = dyn Fn(T) -> T + Send + Sync; #[derive(Clone)] pub struct Constrained { value: T, validator: Arc>, + normalizer: Option>>, } impl Constrained { @@ -54,6 +58,23 @@ impl Constrained { Ok(Self { value: initial_value, validator, + normalizer: None, + }) + } + + /// normalized creates a `Constrained` value with a normalizer function and a validator that allows any value. + pub fn normalized( + initial_value: T, + normalizer: impl Fn(T) -> T + Send + Sync + 'static, + ) -> ConstraintResult { + let validator: Arc> = Arc::new(|_| Ok(())); + let normalizer: Arc> = Arc::new(normalizer); + let normalized = normalizer(initial_value); + validator(&normalized)?; + Ok(Self { + value: normalized, + validator, + normalizer: Some(normalizer), }) } @@ -61,6 +82,7 @@ impl Constrained { Self { value: initial_value, validator: Arc::new(|_| Ok(())), + normalizer: None, } } @@ -88,6 +110,11 @@ impl Constrained { } pub fn set(&mut self, value: T) -> ConstraintResult<()> { + let value = if let Some(normalizer) = &self.normalizer { + normalizer(value) + } else { + value + }; (self.validator)(&value)?; self.value = value; Ok(()) @@ -143,6 +170,17 @@ mod tests { assert_eq!(constrained.value(), 0); } + #[test] + fn constrained_normalizer_applies_on_init_and_set() -> anyhow::Result<()> { + let mut constrained = Constrained::normalized(-1, |value| value.max(0))?; + assert_eq!(constrained.value(), 0); + constrained.set(-5)?; + assert_eq!(constrained.value(), 0); + constrained.set(10)?; + assert_eq!(constrained.value(), 10); + Ok(()) + } + #[test] fn constrained_new_rejects_invalid_initial_value() { let result = Constrained::new(0, |value| { diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 143426825..fa1fee6c0 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -2,6 +2,7 @@ use crate::auth::AuthCredentialsStoreMode; use crate::config::types::DEFAULT_OTEL_ENVIRONMENT; use crate::config::types::History; use crate::config::types::McpServerConfig; +use crate::config::types::McpServerTransportConfig; use crate::config::types::Notice; use crate::config::types::Notifications; use crate::config::types::OtelConfig; @@ -16,6 +17,8 @@ use crate::config::types::UriBasedFileOpener; use crate::config_loader::ConfigLayerStack; use crate::config_loader::ConfigRequirements; use crate::config_loader::LoaderOverrides; +use crate::config_loader::McpServerIdentity; +use crate::config_loader::McpServerRequirement; use crate::config_loader::load_config_layers_state; use crate::features::Feature; use crate::features::FeatureOverrides; @@ -260,7 +263,7 @@ pub struct Config { pub cli_auth_credentials_store_mode: AuthCredentialsStoreMode, /// Definition for MCP servers that Codex can reach out to for tool calls. - pub mcp_servers: HashMap, + pub mcp_servers: Constrained>, /// Preferred store for MCP OAuth credentials. /// keyring: Use an OS-specific keyring service. @@ -513,6 +516,59 @@ fn deserialize_config_toml_with_base( .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e)) } +fn filter_mcp_servers_by_requirements( + mcp_servers: &mut HashMap, + mcp_requirements: Option<&BTreeMap>, +) { + let Some(allowlist) = mcp_requirements else { + return; + }; + + for (name, server) in mcp_servers.iter_mut() { + let allowed = allowlist + .get(name) + .is_some_and(|requirement| mcp_server_matches_requirement(requirement, server)); + if !allowed { + server.enabled = false; + } + } +} + +fn constrain_mcp_servers( + mcp_servers: HashMap, + mcp_requirements: Option<&BTreeMap>, +) -> ConstraintResult>> { + if mcp_requirements.is_none() { + return Ok(Constrained::allow_any(mcp_servers)); + } + + let mcp_requirements = mcp_requirements.cloned(); + Constrained::normalized(mcp_servers, move |mut servers| { + filter_mcp_servers_by_requirements(&mut servers, mcp_requirements.as_ref()); + servers + }) +} + +fn mcp_server_matches_requirement( + requirement: &McpServerRequirement, + server: &McpServerConfig, +) -> bool { + match &requirement.identity { + McpServerIdentity::Command { + command: want_command, + } => matches!( + &server.transport, + McpServerTransportConfig::Stdio { command: got_command, .. } + if got_command == want_command + ), + McpServerIdentity::Url { url: want_url } => matches!( + &server.transport, + McpServerTransportConfig::StreamableHttp { url: got_url, .. } + if got_url == want_url + ), + } +} + pub async fn load_global_mcp_servers( codex_home: &Path, ) -> std::io::Result> { @@ -1347,6 +1403,7 @@ impl Config { let ConfigRequirements { approval_policy: mut constrained_approval_policy, sandbox_policy: mut constrained_sandbox_policy, + mcp_server_requirements, } = requirements; constrained_approval_policy @@ -1356,6 +1413,12 @@ impl Config { .set(sandbox_policy) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, format!("{e}")))?; + let mcp_servers = + constrain_mcp_servers(cfg.mcp_servers.clone(), mcp_server_requirements.as_ref()) + .map_err(|e| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, format!("{e}")) + })?; + let config = Self { model, review_model, @@ -1377,7 +1440,7 @@ impl Config { // The config.toml omits "_mode" because it's a config file. However, "_mode" // is important in code to differentiate the mode from the store implementation. cli_auth_credentials_store_mode: cfg.cli_auth_credentials_store.unwrap_or_default(), - mcp_servers: cfg.mcp_servers, + mcp_servers, // The config.toml omits "_mode" because it's a config file. However, "_mode" // is important in code to differentiate the mode from the store implementation. mcp_oauth_credentials_store_mode: cfg.mcp_oauth_credentials_store.unwrap_or_default(), @@ -1616,9 +1679,44 @@ mod tests { use core_test_support::test_absolute_path; use pretty_assertions::assert_eq; + use std::collections::BTreeMap; + use std::collections::HashMap; use std::time::Duration; use tempfile::TempDir; + fn stdio_mcp(command: &str) -> McpServerConfig { + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: command.to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + } + } + + fn http_mcp(url: &str) -> McpServerConfig { + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: url.to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + } + } + #[test] fn test_toml_parsing() { let history_with_persistence = r#" @@ -1823,6 +1921,122 @@ trust_level = "trusted" } } + #[test] + fn filter_mcp_servers_by_allowlist_enforces_identity_rules() { + const MISMATCHED_COMMAND_SERVER: &str = "mismatched-command-should-disable"; + const MISMATCHED_URL_SERVER: &str = "mismatched-url-should-disable"; + const MATCHED_COMMAND_SERVER: &str = "matched-command-should-allow"; + const MATCHED_URL_SERVER: &str = "matched-url-should-allow"; + const DIFFERENT_NAME_SERVER: &str = "different-name-should-disable"; + + const GOOD_CMD: &str = "good-cmd"; + const GOOD_URL: &str = "https://example.com/good"; + + let mut servers = HashMap::from([ + (MISMATCHED_COMMAND_SERVER.to_string(), stdio_mcp("docs-cmd")), + ( + MISMATCHED_URL_SERVER.to_string(), + http_mcp("https://example.com/mcp"), + ), + (MATCHED_COMMAND_SERVER.to_string(), stdio_mcp(GOOD_CMD)), + (MATCHED_URL_SERVER.to_string(), http_mcp(GOOD_URL)), + (DIFFERENT_NAME_SERVER.to_string(), stdio_mcp("same-cmd")), + ]); + filter_mcp_servers_by_requirements( + &mut servers, + Some(&BTreeMap::from([ + ( + MISMATCHED_URL_SERVER.to_string(), + McpServerRequirement { + identity: McpServerIdentity::Url { + url: "https://example.com/other".to_string(), + }, + }, + ), + ( + MISMATCHED_COMMAND_SERVER.to_string(), + McpServerRequirement { + identity: McpServerIdentity::Command { + command: "other-cmd".to_string(), + }, + }, + ), + ( + MATCHED_URL_SERVER.to_string(), + McpServerRequirement { + identity: McpServerIdentity::Url { + url: GOOD_URL.to_string(), + }, + }, + ), + ( + MATCHED_COMMAND_SERVER.to_string(), + McpServerRequirement { + identity: McpServerIdentity::Command { + command: GOOD_CMD.to_string(), + }, + }, + ), + ])), + ); + + assert_eq!( + servers + .iter() + .map(|(name, server)| (name.clone(), server.enabled)) + .collect::>(), + HashMap::from([ + (MISMATCHED_URL_SERVER.to_string(), false), + (MISMATCHED_COMMAND_SERVER.to_string(), false), + (MATCHED_URL_SERVER.to_string(), true), + (MATCHED_COMMAND_SERVER.to_string(), true), + (DIFFERENT_NAME_SERVER.to_string(), false), + ]) + ); + } + + #[test] + fn filter_mcp_servers_by_allowlist_allows_all_when_unset() { + let mut servers = HashMap::from([ + ("server-a".to_string(), stdio_mcp("cmd-a")), + ("server-b".to_string(), http_mcp("https://example.com/b")), + ]); + + filter_mcp_servers_by_requirements(&mut servers, None); + + assert_eq!( + servers + .iter() + .map(|(name, server)| (name.clone(), server.enabled)) + .collect::>(), + HashMap::from([ + ("server-a".to_string(), true), + ("server-b".to_string(), true), + ]) + ); + } + + #[test] + fn filter_mcp_servers_by_allowlist_blocks_all_when_empty() { + let mut servers = HashMap::from([ + ("server-a".to_string(), stdio_mcp("cmd-a")), + ("server-b".to_string(), http_mcp("https://example.com/b")), + ]); + + filter_mcp_servers_by_requirements(&mut servers, Some(&BTreeMap::new())); + + assert_eq!( + servers + .iter() + .map(|(name, server)| (name.clone(), server.enabled)) + .collect::>(), + HashMap::from([ + ("server-a".to_string(), false), + ("server-b".to_string(), false), + ]) + ); + } + #[test] fn add_dir_override_extends_workspace_writable_roots() -> std::io::Result<()> { let temp_dir = TempDir::new()?; @@ -3264,7 +3478,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), - mcp_servers: HashMap::new(), + mcp_servers: Constrained::allow_any(HashMap::new()), mcp_oauth_credentials_store_mode: Default::default(), mcp_oauth_callback_port: None, model_providers: fixture.model_provider_map.clone(), @@ -3351,7 +3565,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), - mcp_servers: HashMap::new(), + mcp_servers: Constrained::allow_any(HashMap::new()), mcp_oauth_credentials_store_mode: Default::default(), mcp_oauth_callback_port: None, model_providers: fixture.model_provider_map.clone(), @@ -3453,7 +3667,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), - mcp_servers: HashMap::new(), + mcp_servers: Constrained::allow_any(HashMap::new()), mcp_oauth_credentials_store_mode: Default::default(), mcp_oauth_callback_port: None, model_providers: fixture.model_provider_map.clone(), @@ -3541,7 +3755,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), cli_auth_credentials_store_mode: Default::default(), - mcp_servers: HashMap::new(), + mcp_servers: Constrained::allow_any(HashMap::new()), mcp_oauth_credentials_store_mode: Default::default(), mcp_oauth_callback_port: None, model_providers: fixture.model_provider_map.clone(), diff --git a/codex-rs/core/src/config_loader/config_requirements.rs b/codex-rs/core/src/config_loader/config_requirements.rs index dd001e417..731ff7d79 100644 --- a/codex-rs/core/src/config_loader/config_requirements.rs +++ b/codex-rs/core/src/config_loader/config_requirements.rs @@ -3,6 +3,7 @@ use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::SandboxPolicy; use codex_utils_absolute_path::AbsolutePathBuf; use serde::Deserialize; +use std::collections::BTreeMap; use std::fmt; use crate::config::Constrained; @@ -43,6 +44,7 @@ impl fmt::Display for RequirementSource { pub struct ConfigRequirements { pub approval_policy: Constrained, pub sandbox_policy: Constrained, + pub mcp_server_requirements: Option>, } impl Default for ConfigRequirements { @@ -50,15 +52,29 @@ impl Default for ConfigRequirements { Self { approval_policy: Constrained::allow_any_from_default(), sandbox_policy: Constrained::allow_any(SandboxPolicy::ReadOnly), + mcp_server_requirements: None, } } } +#[derive(Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum McpServerIdentity { + Command { command: String }, + Url { url: String }, +} + +#[derive(Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct McpServerRequirement { + pub identity: McpServerIdentity, +} + /// Base config deserialized from /etc/codex/requirements.toml or MDM. #[derive(Deserialize, Debug, Clone, Default, PartialEq)] pub struct ConfigRequirementsToml { pub allowed_approval_policies: Option>, pub allowed_sandbox_modes: Option>, + pub mcp_server_requirements: Option>, } /// Value paired with the requirement source it came from, for better error @@ -87,6 +103,7 @@ impl std::ops::Deref for Sourced { pub struct ConfigRequirementsWithSources { pub allowed_approval_policies: Option>>, pub allowed_sandbox_modes: Option>>, + pub mcp_server_requirements: Option>>, } impl ConfigRequirementsWithSources { @@ -114,7 +131,11 @@ impl ConfigRequirementsWithSources { self, other, source, - { allowed_approval_policies, allowed_sandbox_modes } + { + allowed_approval_policies, + allowed_sandbox_modes, + mcp_server_requirements, + } ); } @@ -122,10 +143,12 @@ impl ConfigRequirementsWithSources { let ConfigRequirementsWithSources { allowed_approval_policies, allowed_sandbox_modes, + mcp_server_requirements, } = self; ConfigRequirementsToml { allowed_approval_policies: allowed_approval_policies.map(|sourced| sourced.value), allowed_sandbox_modes: allowed_sandbox_modes.map(|sourced| sourced.value), + mcp_server_requirements: mcp_server_requirements.map(|sourced| sourced.value), } } } @@ -159,7 +182,9 @@ impl From for SandboxModeRequirement { impl ConfigRequirementsToml { pub fn is_empty(&self) -> bool { - self.allowed_approval_policies.is_none() && self.allowed_sandbox_modes.is_none() + self.allowed_approval_policies.is_none() + && self.allowed_sandbox_modes.is_none() + && self.mcp_server_requirements.is_none() } } @@ -170,6 +195,7 @@ impl TryFrom for ConfigRequirements { let ConfigRequirementsWithSources { allowed_approval_policies, allowed_sandbox_modes, + mcp_server_requirements, } = toml; let approval_policy: Constrained = match allowed_approval_policies { @@ -247,6 +273,7 @@ impl TryFrom for ConfigRequirements { Ok(ConfigRequirements { approval_policy, sandbox_policy, + mcp_server_requirements: mcp_server_requirements.map(|sourced| sourced.value), }) } } @@ -264,12 +291,15 @@ mod tests { let ConfigRequirementsToml { allowed_approval_policies, allowed_sandbox_modes, + mcp_server_requirements, } = toml; ConfigRequirementsWithSources { allowed_approval_policies: allowed_approval_policies .map(|value| Sourced::new(value, RequirementSource::Unknown)), allowed_sandbox_modes: allowed_sandbox_modes .map(|value| Sourced::new(value, RequirementSource::Unknown)), + mcp_server_requirements: mcp_server_requirements + .map(|value| Sourced::new(value, RequirementSource::Unknown)), } } @@ -289,6 +319,7 @@ mod tests { let other = ConfigRequirementsToml { allowed_approval_policies: Some(allowed_approval_policies.clone()), allowed_sandbox_modes: Some(allowed_sandbox_modes.clone()), + mcp_server_requirements: None, }; target.merge_unset_fields(source.clone(), other); @@ -301,6 +332,7 @@ mod tests { source.clone() )), allowed_sandbox_modes: Some(Sourced::new(allowed_sandbox_modes, source)), + mcp_server_requirements: None, } ); } @@ -328,6 +360,7 @@ mod tests { source_location, )), allowed_sandbox_modes: None, + mcp_server_requirements: None, } ); Ok(()) @@ -363,6 +396,7 @@ mod tests { existing_source, )), allowed_sandbox_modes: None, + mcp_server_requirements: None, } ); Ok(()) @@ -523,4 +557,40 @@ mod tests { Ok(()) } + + #[test] + fn deserialize_mcp_server_requirements() -> Result<()> { + let toml_str = r#" + [mcp_server_requirements.docs.identity] + command = "codex-mcp" + + [mcp_server_requirements.remote.identity] + url = "https://example.com/mcp" + "#; + let requirements: ConfigRequirements = + with_unknown_source(from_str(toml_str)?).try_into()?; + + assert_eq!( + requirements.mcp_server_requirements, + Some(BTreeMap::from([ + ( + "docs".to_string(), + McpServerRequirement { + identity: McpServerIdentity::Command { + command: "codex-mcp".to_string(), + }, + }, + ), + ( + "remote".to_string(), + McpServerRequirement { + identity: McpServerIdentity::Url { + url: "https://example.com/mcp".to_string(), + }, + }, + ), + ])) + ); + Ok(()) + } } diff --git a/codex-rs/core/src/config_loader/mod.rs b/codex-rs/core/src/config_loader/mod.rs index 1710ec12c..a793aa223 100644 --- a/codex-rs/core/src/config_loader/mod.rs +++ b/codex-rs/core/src/config_loader/mod.rs @@ -26,6 +26,8 @@ use toml::Value as TomlValue; pub use config_requirements::ConfigRequirements; pub use config_requirements::ConfigRequirementsToml; +pub use config_requirements::McpServerIdentity; +pub use config_requirements::McpServerRequirement; pub use config_requirements::RequirementSource; pub use config_requirements::SandboxModeRequirement; pub use merge::merge_toml_values; diff --git a/codex-rs/core/src/mcp/mod.rs b/codex-rs/core/src/mcp/mod.rs index 677483646..9e5446a74 100644 --- a/codex-rs/core/src/mcp/mod.rs +++ b/codex-rs/core/src/mcp/mod.rs @@ -47,7 +47,7 @@ pub async fn collect_mcp_snapshot(config: &Config) -> McpListToolsResponseEvent mcp_connection_manager .initialize( - config.mcp_servers.clone(), + &config.mcp_servers, config.mcp_oauth_credentials_store_mode, auth_status_entries.clone(), tx_event, diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index dcd1edf80..6574437bd 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -312,7 +312,7 @@ pub(crate) struct McpConnectionManager { impl McpConnectionManager { pub async fn initialize( &mut self, - mcp_servers: HashMap, + mcp_servers: &HashMap, store_mode: OAuthCredentialsStoreMode, auth_entries: HashMap, tx_event: Sender, @@ -325,6 +325,7 @@ impl McpConnectionManager { let mut clients = HashMap::new(); let mut join_set = JoinSet::new(); let elicitation_requests = ElicitationRequestManager::default(); + let mcp_servers = mcp_servers.clone(); for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) { let cancel_token = cancel_token.child_token(); let _ = emit_update( diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index dc6d47fe7..0d47e296c 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -73,7 +73,8 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { let fixture = test_codex() .with_config(move |config| { - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), McpServerConfig { transport: McpServerTransportConfig::Stdio { @@ -93,6 +94,10 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }) .build(&server) .await?; @@ -204,7 +209,8 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> { let fixture = test_codex() .with_config(move |config| { - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), McpServerConfig { transport: McpServerTransportConfig::Stdio { @@ -224,6 +230,10 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }) .build(&server) .await?; @@ -393,7 +403,8 @@ async fn stdio_image_completions_round_trip() -> anyhow::Result<()> { let fixture = test_codex() .with_config(move |config| { config.model_provider.wire_api = codex_core::WireApi::Chat; - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), McpServerConfig { transport: McpServerTransportConfig::Stdio { @@ -413,6 +424,10 @@ async fn stdio_image_completions_round_trip() -> anyhow::Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }) .build(&server) .await?; @@ -533,7 +548,8 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> { let fixture = test_codex() .with_config(move |config| { - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), McpServerConfig { transport: McpServerTransportConfig::Stdio { @@ -550,6 +566,10 @@ async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }) .build(&server) .await?; @@ -676,7 +696,8 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { let fixture = test_codex() .with_config(move |config| { - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { @@ -692,6 +713,10 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }) .build(&server) .await?; @@ -850,7 +875,8 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { let fixture = test_codex() .with_config(move |config| { - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { @@ -866,6 +892,10 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }) .build(&server) .await?; diff --git a/codex-rs/core/tests/suite/truncation.rs b/codex-rs/core/tests/suite/truncation.rs index c2bbd2d53..80204e8e1 100644 --- a/codex-rs/core/tests/suite/truncation.rs +++ b/codex-rs/core/tests/suite/truncation.rs @@ -414,7 +414,8 @@ async fn mcp_tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> let rmcp_test_server_bin = stdio_server_bin()?; let mut builder = test_codex().with_config(move |config| { - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), codex_core::config::types::McpServerConfig { transport: codex_core::config::types::McpServerTransportConfig::Stdio { @@ -431,6 +432,10 @@ async fn mcp_tool_call_output_exceeds_limit_truncated_for_model() -> Result<()> disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); config.tool_output_token_limit = Some(500); }); let fixture = builder.build(&server).await?; @@ -497,7 +502,8 @@ async fn mcp_image_output_preserves_image_and_no_text_summary() -> Result<()> { let openai_png = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMB/ee9bQAAAABJRU5ErkJggg=="; let mut builder = test_codex().with_config(move |config| { - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), McpServerConfig { transport: McpServerTransportConfig::Stdio { @@ -517,6 +523,10 @@ async fn mcp_image_output_preserves_image_and_no_text_summary() -> Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }); let fixture = builder.build(&server).await?; let session_model = fixture.session_configured.model.clone(); @@ -754,7 +764,8 @@ async fn mcp_tool_call_output_not_truncated_with_custom_limit() -> Result<()> { let mut builder = test_codex().with_config(move |config| { config.tool_output_token_limit = Some(50_000); - config.mcp_servers.insert( + let mut servers = config.mcp_servers.get().clone(); + servers.insert( server_name.to_string(), codex_core::config::types::McpServerConfig { transport: codex_core::config::types::McpServerTransportConfig::Stdio { @@ -771,6 +782,10 @@ async fn mcp_tool_call_output_not_truncated_with_custom_limit() -> Result<()> { disabled_tools: None, }, ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); }); let fixture = builder.build(&server).await?; diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index 3f14d84fa..34925ee50 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -1923,7 +1923,8 @@ mod tests { enabled_tools: None, disabled_tools: None, }; - config.mcp_servers.insert("docs".to_string(), stdio_config); + let mut servers = config.mcp_servers.get().clone(); + servers.insert("docs".to_string(), stdio_config); let mut headers = HashMap::new(); headers.insert("Authorization".to_string(), "Bearer secret".to_string()); @@ -1942,7 +1943,11 @@ mod tests { enabled_tools: None, disabled_tools: None, }; - config.mcp_servers.insert("http".to_string(), http_config); + servers.insert("http".to_string(), http_config); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); let mut tools: HashMap = HashMap::new(); tools.insert( diff --git a/codex-rs/tui2/src/history_cell.rs b/codex-rs/tui2/src/history_cell.rs index 46e7bed34..f8a15c50e 100644 --- a/codex-rs/tui2/src/history_cell.rs +++ b/codex-rs/tui2/src/history_cell.rs @@ -1962,7 +1962,8 @@ mod tests { enabled_tools: None, disabled_tools: None, }; - config.mcp_servers.insert("docs".to_string(), stdio_config); + let mut servers = config.mcp_servers.get().clone(); + servers.insert("docs".to_string(), stdio_config); let mut headers = HashMap::new(); headers.insert("Authorization".to_string(), "Bearer secret".to_string()); @@ -1981,7 +1982,11 @@ mod tests { enabled_tools: None, disabled_tools: None, }; - config.mcp_servers.insert("http".to_string(), http_config); + servers.insert("http".to_string(), http_config); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); let mut tools: HashMap = HashMap::new(); tools.insert(