diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 737b13024..831cc6b3a 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -14,6 +14,7 @@ use codex_protocol::mcp::CallToolResult; use codex_protocol::models::FunctionCallOutputBody; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; +use codex_protocol::openai_models::InputModality; use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::ReviewDecision; use codex_protocol::protocol::SandboxPolicy; @@ -75,10 +76,17 @@ pub(crate) async fn handle_mcp_tool_call( .await; let start = Instant::now(); - let result: Result = sess + let result = sess .call_tool(&server, &tool_name, arguments_value.clone()) .await .map_err(|e| format!("tool call error: {e:?}")); + let result = sanitize_mcp_tool_result_for_model( + turn_context + .model_info + .input_modalities + .contains(&InputModality::Image), + result, + ); if let Err(e) = &result { tracing::warn!("MCP tool call error: {e:?}"); } @@ -136,10 +144,17 @@ pub(crate) async fn handle_mcp_tool_call( let start = Instant::now(); // Perform the tool call. - let result: Result = sess + let result = sess .call_tool(&server, &tool_name, arguments_value.clone()) .await .map_err(|e| format!("tool call error: {e:?}")); + let result = sanitize_mcp_tool_result_for_model( + turn_context + .model_info + .input_modalities + .contains(&InputModality::Image), + result, + ); if let Err(e) = &result { tracing::warn!("MCP tool call error: {e:?}"); } @@ -160,6 +175,37 @@ pub(crate) async fn handle_mcp_tool_call( ResponseInputItem::McpToolCallOutput { call_id, result } } +fn sanitize_mcp_tool_result_for_model( + supports_image_input: bool, + result: Result, +) -> Result { + if supports_image_input { + return result; + } + + result.map(|call_tool_result| CallToolResult { + content: call_tool_result + .content + .iter() + .map(|block| { + if let Some(content_type) = block.get("type").and_then(serde_json::Value::as_str) + && content_type == "image" + { + return serde_json::json!({ + "type": "text", + "text": "", + }); + } + + block.clone() + }) + .collect::>(), + structured_content: call_tool_result.structured_content, + is_error: call_tool_result.is_error, + meta: call_tool_result.meta, + }) +} + async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext, event: EventMsg) { sess.send_event(turn_context, event).await; } @@ -450,4 +496,59 @@ mod tests { let annotations = annotations(Some(true), Some(true), Some(true)); assert_eq!(requires_mcp_tool_approval(&annotations), false); } + + #[test] + fn sanitize_mcp_tool_result_for_model_rewrites_image_content() { + let result = Ok(CallToolResult { + content: vec![ + serde_json::json!({ + "type": "image", + "data": "Zm9v", + "mimeType": "image/png", + }), + serde_json::json!({ + "type": "text", + "text": "hello", + }), + ], + structured_content: None, + is_error: Some(false), + meta: None, + }); + + let got = sanitize_mcp_tool_result_for_model(false, result).expect("sanitized result"); + + assert_eq!( + got.content, + vec![ + serde_json::json!({ + "type": "text", + "text": "", + }), + serde_json::json!({ + "type": "text", + "text": "hello", + }), + ] + ); + } + + #[test] + fn sanitize_mcp_tool_result_for_model_preserves_image_when_supported() { + let original = CallToolResult { + content: vec![serde_json::json!({ + "type": "image", + "data": "Zm9v", + "mimeType": "image/png", + })], + structured_content: Some(serde_json::json!({"x": 1})), + is_error: Some(false), + meta: Some(serde_json::json!({"k": "v"})), + }; + + let got = sanitize_mcp_tool_result_for_model(true, Ok(original.clone())) + .expect("unsanitized result"); + + assert_eq!(got, original); + } } diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 0b3c9dd78..82d4d13a6 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -8,8 +8,11 @@ use std::time::Duration; use std::time::SystemTime; use std::time::UNIX_EPOCH; +use codex_core::CodexAuth; use codex_core::config::types::McpServerConfig; use codex_core::config::types::McpServerTransportConfig; +use codex_core::features::Feature; +use codex_core::models_manager::manager::RefreshStrategy; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; @@ -18,9 +21,17 @@ use codex_core::protocol::McpToolCallBeginEvent; use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::openai_models::ConfigShellToolType; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelVisibility; +use codex_protocol::openai_models::ModelsResponse; +use codex_protocol::openai_models::ReasoningEffortPreset; +use codex_protocol::openai_models::TruncationPolicyConfig; use codex_protocol::user_input::UserInput; use codex_utils_cargo_bin::cargo_bin; use core_test_support::responses; +use core_test_support::responses::mount_models_once; use core_test_support::responses::mount_sse_once; use core_test_support::skip_if_no_network; use core_test_support::stdio_server_bin; @@ -356,6 +367,166 @@ async fn stdio_image_responses_round_trip() -> anyhow::Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[serial(mcp_test_value)] +async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + + let call_id = "img-text-only-1"; + let server_name = "rmcp"; + let tool_name = format!("mcp__{server_name}__image"); + let text_only_model_slug = "rmcp-text-only-model"; + + let models_mock = mount_models_once( + &server, + ModelsResponse { + models: vec![ModelInfo { + slug: text_only_model_slug.to_string(), + display_name: "RMCP Text Only".to_string(), + description: Some("Test model without image input support".to_string()), + default_reasoning_level: None, + supported_reasoning_levels: vec![ReasoningEffortPreset { + effort: codex_protocol::openai_models::ReasoningEffort::Medium, + description: "Medium".to_string(), + }], + shell_type: ConfigShellToolType::Default, + visibility: ModelVisibility::List, + supported_in_api: true, + priority: 1, + upgrade: None, + base_instructions: "base instructions".to_string(), + model_messages: None, + supports_reasoning_summaries: false, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: None, + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window: Some(272_000), + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities: vec![InputModality::Text], + }], + }, + ) + .await; + + // First stream: model decides to call the image tool. + mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, &tool_name, "{}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + // Second stream: after tool execution, assistant emits a message and completes. + let final_mock = mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp image tool completed successfully."), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let rmcp_test_server_bin = stdio_server_bin()?; + + let fixture = test_codex() + .with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) + .with_config(move |config| { + config.features.enable(Feature::RemoteModels); + + let mut servers = config.mcp_servers.get().clone(); + servers.insert( + server_name.to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin, + args: Vec::new(), + env: Some(HashMap::from([( + "MCP_TEST_IMAGE_DATA_URL".to_string(), + OPENAI_PNG.to_string(), + )])), + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + }, + ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); + }) + .build(&server) + .await?; + + fixture + .thread_manager + .get_models_manager() + .list_models(&fixture.config, RefreshStrategy::Online) + .await; + assert_eq!(models_mock.requests().len(), 1); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "call the rmcp image tool".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, + model: text_only_model_slug.to_string(), + effort: None, + summary: ReasoningSummary::Auto, + collaboration_mode: None, + personality: None, + }) + .await?; + + wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallBegin(_)) + }) + .await; + wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let output_item = final_mock.single_request().function_call_output(call_id); + let output_text = output_item + .get("output") + .and_then(Value::as_str) + .expect("function_call_output output should be a JSON string"); + let output_json: Value = serde_json::from_str(output_text) + .expect("function_call_output output should be valid JSON"); + assert_eq!( + output_json, + json!([{ + "type": "text", + "text": "" + }]) + ); + server.verify().await; + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[serial(mcp_test_value)] async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> {