From dadffd27d45dd3b330e7b71094b828ce2c1a2d84 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Thu, 12 Mar 2026 13:38:52 -0700 Subject: [PATCH] Fix MCP tool calling (#14491) Properly escape mcp tool names and make tools only available via imports. --- codex-rs/core/src/tools/code_mode/bridge.js | 110 +++++------ .../core/src/tools/code_mode/description.md | 1 - codex-rs/core/src/tools/code_mode/mod.rs | 4 +- codex-rs/core/src/tools/code_mode/protocol.rs | 1 + codex-rs/core/src/tools/code_mode/runner.cjs | 32 +++- .../core/src/tools/code_mode_description.rs | 20 +- .../src/tools/code_mode_description_tests.rs | 29 +++ codex-rs/core/tests/suite/code_mode.rs | 179 ++++++++++++++++++ .../rmcp-client/src/bin/test_stdio_server.rs | 23 ++- 9 files changed, 317 insertions(+), 82 deletions(-) diff --git a/codex-rs/core/src/tools/code_mode/bridge.js b/codex-rs/core/src/tools/code_mode/bridge.js index 435e94e74..d7967faab 100644 --- a/codex-rs/core/src/tools/code_mode/bridge.js +++ b/codex-rs/core/src/tools/code_mode/bridge.js @@ -1,43 +1,8 @@ const __codexEnabledTools = __CODE_MODE_ENABLED_TOOLS_PLACEHOLDER__; -const __codexEnabledToolNames = __codexEnabledTools.map((tool) => tool.tool_name); const __codexContentItems = Array.isArray(globalThis.__codexContentItems) ? globalThis.__codexContentItems : []; -function __codexCloneContentItem(item) { - if (!item || typeof item !== 'object') { - throw new TypeError('content item must be an object'); - } - switch (item.type) { - case 'input_text': - if (typeof item.text !== 'string') { - throw new TypeError('content item "input_text" requires a string text field'); - } - return { type: 'input_text', text: item.text }; - case 'input_image': - if (typeof item.image_url !== 'string') { - throw new TypeError('content item "input_image" requires a string image_url field'); - } - return { type: 'input_image', image_url: item.image_url }; - default: - throw new TypeError(`unsupported content item type "${item.type}"`); - } -} - -function __codexNormalizeRawContentItems(value) { - if (Array.isArray(value)) { - return value.flatMap((entry) => __codexNormalizeRawContentItems(entry)); - } - return [__codexCloneContentItem(value)]; -} - -function __codexNormalizeContentItems(value) { - if (typeof value === 'string') { - return [{ type: 'input_text', text: value }]; - } - return __codexNormalizeRawContentItems(value); -} - Object.defineProperty(globalThis, '__codexContentItems', { value: __codexContentItems, configurable: true, @@ -45,33 +10,54 @@ Object.defineProperty(globalThis, '__codexContentItems', { writable: false, }); -globalThis.codex = { - enabledTools: Object.freeze(__codexEnabledToolNames.slice()), -}; - -globalThis.add_content = (value) => { - const contentItems = __codexNormalizeContentItems(value); - __codexContentItems.push(...contentItems); - return contentItems; -}; - -globalThis.console = Object.freeze({ - log() {}, - info() {}, - warn() {}, - error() {}, - debug() {}, -}); - -for (const name of __codexEnabledToolNames) { - if (!(name in globalThis)) { - Object.defineProperty(globalThis, name, { - value: async (args) => __codex_tool_call(name, args), - configurable: true, - enumerable: false, - writable: false, - }); +(() => { + function cloneContentItem(item) { + if (!item || typeof item !== 'object') { + throw new TypeError('content item must be an object'); + } + switch (item.type) { + case 'input_text': + if (typeof item.text !== 'string') { + throw new TypeError('content item "input_text" requires a string text field'); + } + return { type: 'input_text', text: item.text }; + case 'input_image': + if (typeof item.image_url !== 'string') { + throw new TypeError('content item "input_image" requires a string image_url field'); + } + return { type: 'input_image', image_url: item.image_url }; + default: + throw new TypeError(`unsupported content item type "${item.type}"`); + } } -} + + function normalizeRawContentItems(value) { + if (Array.isArray(value)) { + return value.flatMap((entry) => normalizeRawContentItems(entry)); + } + return [cloneContentItem(value)]; + } + + function normalizeContentItems(value) { + if (typeof value === 'string') { + return [{ type: 'input_text', text: value }]; + } + return normalizeRawContentItems(value); + } + + globalThis.add_content = (value) => { + const contentItems = normalizeContentItems(value); + __codexContentItems.push(...contentItems); + return contentItems; + }; + + globalThis.console = Object.freeze({ + log() {}, + info() {}, + warn() {}, + error() {}, + debug() {}, + }); +})(); __CODE_MODE_USER_CODE_PLACEHOLDER__ diff --git a/codex-rs/core/src/tools/code_mode/description.md b/codex-rs/core/src/tools/code_mode/description.md index b494ef52b..482e07afe 100644 --- a/codex-rs/core/src/tools/code_mode/description.md +++ b/codex-rs/core/src/tools/code_mode/description.md @@ -16,4 +16,3 @@ - `set_max_output_tokens_per_exec_call(value)`: sets the token budget for direct `exec` results. By default the result is truncated to 10000 tokens. - `set_yield_time(value)`: asks `exec` to yield early after that many milliseconds if the script is still running. - `yield_control()`: yields the accumulated output to the model immediately while the script keeps running. - diff --git a/codex-rs/core/src/tools/code_mode/mod.rs b/codex-rs/core/src/tools/code_mode/mod.rs index 50e08fa70..ce72f7ba1 100644 --- a/codex-rs/core/src/tools/code_mode/mod.rs +++ b/codex-rs/core/src/tools/code_mode/mod.rs @@ -17,6 +17,7 @@ use crate::codex::TurnContext; use crate::tools::ToolRouter; use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; use crate::tools::code_mode_description::code_mode_tool_reference; +use crate::tools::code_mode_description::normalize_code_mode_identifier; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolPayload; use crate::tools::parallel::ToolCallRuntime; @@ -233,10 +234,11 @@ fn enabled_tool_from_spec(spec: ToolSpec) -> Option { }; Some(protocol::EnabledTool { + global_name: normalize_code_mode_identifier(&tool_name), tool_name, module_path: reference.module_path, namespace: reference.namespace, - name: reference.tool_key, + name: normalize_code_mode_identifier(&reference.tool_key), description, kind, }) diff --git a/codex-rs/core/src/tools/code_mode/protocol.rs b/codex-rs/core/src/tools/code_mode/protocol.rs index ee5220982..6cd50d3f9 100644 --- a/codex-rs/core/src/tools/code_mode/protocol.rs +++ b/codex-rs/core/src/tools/code_mode/protocol.rs @@ -17,6 +17,7 @@ pub(super) enum CodeModeToolKind { #[derive(Clone, Debug, Serialize)] pub(super) struct EnabledTool { pub(super) tool_name: String, + pub(super) global_name: String, #[serde(rename = "module")] pub(super) module_path: String, pub(super) namespace: Vec, diff --git a/codex-rs/core/src/tools/code_mode/runner.cjs b/codex-rs/core/src/tools/code_mode/runner.cjs index bc6afe561..7668eb2ef 100644 --- a/codex-rs/core/src/tools/code_mode/runner.cjs +++ b/codex-rs/core/src/tools/code_mode/runner.cjs @@ -134,8 +134,8 @@ function codeModeWorkerMain() { function createToolsNamespace(callTool, enabledTools) { const tools = Object.create(null); - for (const { tool_name } of enabledTools) { - Object.defineProperty(tools, tool_name, { + for (const { tool_name, global_name } of enabledTools) { + Object.defineProperty(tools, global_name, { value: async (args) => callTool(tool_name, args), configurable: false, enumerable: true, @@ -163,9 +163,9 @@ function codeModeWorkerMain() { const allTools = createAllToolsMetadata(enabledTools); const exportNames = ['ALL_TOOLS']; - for (const { tool_name } of enabledTools) { - if (tool_name !== 'ALL_TOOLS') { - exportNames.push(tool_name); + for (const { global_name } of enabledTools) { + if (global_name !== 'ALL_TOOLS') { + exportNames.push(global_name); } } @@ -382,6 +382,24 @@ function codeModeWorkerMain() { }; } + async function resolveDynamicModule(specifier, resolveModule) { + const module = resolveModule(specifier); + + if (module.status === 'unlinked') { + await module.link(resolveModule); + } + + if (module.status === 'linked' || module.status === 'evaluating') { + await module.evaluate(); + } + + if (module.status === 'errored') { + throw module.error; + } + + return module; + } + async function runModule(context, start, state, callTool) { const resolveModule = createModuleResolver( context, @@ -392,7 +410,8 @@ function codeModeWorkerMain() { const mainModule = new SourceTextModule(start.source, { context, identifier: 'exec_main.mjs', - importModuleDynamically: async (specifier) => resolveModule(specifier), + importModuleDynamically: async (specifier) => + resolveDynamicModule(specifier, resolveModule), }); await mainModule.link(resolveModule); @@ -408,7 +427,6 @@ function codeModeWorkerMain() { const callTool = createToolCaller(); const context = vm.createContext({ __codexContentItems: createContentItems(), - __codex_tool_call: callTool, }); try { diff --git a/codex-rs/core/src/tools/code_mode_description.rs b/codex-rs/core/src/tools/code_mode_description.rs index 8ed9fc6f5..318e6f495 100644 --- a/codex-rs/core/src/tools/code_mode_description.rs +++ b/codex-rs/core/src/tools/code_mode_description.rs @@ -75,13 +75,15 @@ fn append_code_mode_sample( output_type: String, ) -> String { let reference = code_mode_tool_reference(tool_name); - format!( - "{description}\n\nCode mode declaration:\n```ts\nimport {{ {} }} from \"{}\";\ndeclare function {}({input_name}: {input_type}): Promise<{output_type}>;\n```", - reference.tool_key, reference.module_path, reference.tool_key - ) + let local_name = normalize_code_mode_identifier(&reference.tool_key); + let declaration = format!( + "import {{ {local_name} }} from \"{}\";\ndeclare function {local_name}({input_name}: {input_type}): Promise<{output_type}>;", + reference.module_path + ); + format!("{description}\n\nCode mode declaration:\n```ts\n{declaration}\n```") } -fn code_mode_local_name(tool_key: &str) -> String { +pub(crate) fn normalize_code_mode_identifier(tool_key: &str) -> String { let mut identifier = String::new(); for (index, ch) in tool_key.chars().enumerate() { @@ -98,7 +100,11 @@ fn code_mode_local_name(tool_key: &str) -> String { } } - identifier + if identifier.is_empty() { + "_".to_string() + } else { + identifier + } } fn render_json_schema_to_typescript(schema: &JsonValue) -> String { @@ -279,7 +285,7 @@ fn render_json_schema_object(map: &serde_json::Map, indent: u } fn render_json_schema_property_name(name: &str) -> String { - if code_mode_local_name(name) == name { + if normalize_code_mode_identifier(name) == name { name.to_string() } else { serde_json::to_string(name).unwrap_or_else(|_| format!("\"{}\"", name.replace('"', "\\\""))) diff --git a/codex-rs/core/src/tools/code_mode_description_tests.rs b/codex-rs/core/src/tools/code_mode_description_tests.rs index 500d7bf67..f5b4f8820 100644 --- a/codex-rs/core/src/tools/code_mode_description_tests.rs +++ b/codex-rs/core/src/tools/code_mode_description_tests.rs @@ -1,3 +1,4 @@ +use super::append_code_mode_sample; use super::render_json_schema_to_typescript; use pretty_assertions::assert_eq; use serde_json::json; @@ -73,3 +74,31 @@ fn render_json_schema_to_typescript_sorts_object_properties() { "{\n _meta?: string;\n content: Array;\n isError?: boolean;\n structuredContent?: string;\n}" ); } + +#[test] +fn append_code_mode_sample_uses_static_import_for_valid_identifiers() { + assert_eq!( + append_code_mode_sample( + "desc", + "mcp__ologs__get_profile", + "args", + "{ foo: string }".to_string(), + "unknown".to_string(), + ), + "desc\n\nCode mode declaration:\n```ts\nimport { get_profile } from \"tools/mcp/ologs.js\";\ndeclare function get_profile(args: { foo: string }): Promise;\n```" + ); +} + +#[test] +fn append_code_mode_sample_normalizes_non_identifier_tool_names() { + assert_eq!( + append_code_mode_sample( + "desc", + "mcp__rmcp__echo-tool", + "args", + "{ foo: string }".to_string(), + "unknown".to_string(), + ), + "desc\n\nCode mode declaration:\n```ts\nimport { echo_tool } from \"tools/mcp/rmcp.js\";\ndeclare function echo_tool(args: { foo: string }): Promise;\n```" + ); +} diff --git a/codex-rs/core/tests/suite/code_mode.rs b/codex-rs/core/tests/suite/code_mode.rs index 05fa28751..6baa50ada 100644 --- a/codex-rs/core/tests/suite/code_mode.rs +++ b/codex-rs/core/tests/suite/code_mode.rs @@ -20,6 +20,7 @@ use core_test_support::test_codex::test_codex; use pretty_assertions::assert_eq; use serde_json::Value; use std::collections::HashMap; +use std::collections::HashSet; use std::fs; use std::path::Path; use std::time::Duration; @@ -1584,6 +1585,184 @@ contentLength=0" Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_can_dynamically_import_namespaced_mcp_tools() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let code = r#" +const rmcp = await import("tools/mcp/rmcp.js"); +const { content, structuredContent, isError } = await rmcp.echo({ + message: "ping", +}); +add_content( + `hasEcho=${String(Object.keys(rmcp).includes("echo"))}\n` + + `echoType=${typeof rmcp.echo}\n` + + `echo=${structuredContent?.echo ?? "missing"}\n` + + `isError=${String(isError)}\n` + + `contentLength=${content.length}` +); +"#; + + let (_test, second_mock) = run_code_mode_turn_with_rmcp( + &server, + "use exec to dynamically import the rmcp module", + code, + ) + .await?; + + let req = second_mock.single_request(); + let (output, success) = custom_tool_output_body_and_success(&req, "call-1"); + assert_ne!( + success, + Some(false), + "exec dynamic rmcp import failed unexpectedly: {output}" + ); + assert_eq!( + output, + "hasEcho=true +echoType=function +echo=ECHOING: ping +isError=false +contentLength=0" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_normalizes_illegal_namespaced_mcp_tool_identifiers() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let code = r#" +import { echo_tool } from "tools/mcp/rmcp.js"; + +const result = await echo_tool({ message: "ping" }); +add_content(`echo=${result.structuredContent.echo}`); +"#; + + let (_test, second_mock) = run_code_mode_turn_with_rmcp( + &server, + "use exec to import a normalized rmcp tool name", + code, + ) + .await?; + + let req = second_mock.single_request(); + let (output, success) = custom_tool_output_body_and_success(&req, "call-1"); + assert_ne!( + success, + Some(false), + "exec normalized rmcp import failed unexpectedly: {output}" + ); + assert_eq!(output, "echo=ECHOING: ping"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_lists_global_scope_items() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let code = r#" +add_content(JSON.stringify(Object.getOwnPropertyNames(globalThis).sort())); +"#; + + let (_test, second_mock) = + run_code_mode_turn_with_rmcp(&server, "use exec to inspect global scope", code).await?; + + let req = second_mock.single_request(); + let (output, success) = custom_tool_output_body_and_success(&req, "call-1"); + assert_ne!( + success, + Some(false), + "exec global scope inspection failed unexpectedly: {output}" + ); + let globals = serde_json::from_str::>(&output)?; + let globals = globals.into_iter().collect::>(); + let expected = [ + "AggregateError", + "Array", + "ArrayBuffer", + "AsyncDisposableStack", + "Atomics", + "BigInt", + "BigInt64Array", + "BigUint64Array", + "Boolean", + "DataView", + "Date", + "DisposableStack", + "Error", + "EvalError", + "FinalizationRegistry", + "Float16Array", + "Float32Array", + "Float64Array", + "Function", + "Infinity", + "Int16Array", + "Int32Array", + "Int8Array", + "Intl", + "Iterator", + "JSON", + "Map", + "Math", + "NaN", + "Number", + "Object", + "Promise", + "Proxy", + "RangeError", + "ReferenceError", + "Reflect", + "RegExp", + "Set", + "SharedArrayBuffer", + "String", + "SuppressedError", + "Symbol", + "SyntaxError", + "TypeError", + "URIError", + "Uint16Array", + "Uint32Array", + "Uint8Array", + "Uint8ClampedArray", + "WeakMap", + "WeakRef", + "WeakSet", + "WebAssembly", + "__codexContentItems", + "add_content", + "console", + "decodeURI", + "decodeURIComponent", + "encodeURI", + "encodeURIComponent", + "escape", + "eval", + "globalThis", + "isFinite", + "isNaN", + "parseFloat", + "parseInt", + "undefined", + "unescape", + ]; + for g in &globals { + assert!( + expected.contains(&g.as_str()), + "unexpected global {g} in {globals:?}" + ); + } + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn code_mode_exports_all_tools_metadata_for_builtin_tools() -> Result<()> { skip_if_no_network!(Ok(())); diff --git a/codex-rs/rmcp-client/src/bin/test_stdio_server.rs b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs index d7708bf5e..cd0830776 100644 --- a/codex-rs/rmcp-client/src/bin/test_stdio_server.rs +++ b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs @@ -45,6 +45,7 @@ impl TestToolServer { fn new() -> Self { let tools = vec![ Self::echo_tool(), + Self::echo_dash_tool(), Self::image_tool(), Self::image_scenario_tool(), ]; @@ -58,6 +59,20 @@ impl TestToolServer { } fn echo_tool() -> Tool { + Self::build_echo_tool( + "echo", + "Echo back the provided message and include environment data.", + ) + } + + fn echo_dash_tool() -> Tool { + Self::build_echo_tool( + "echo-tool", + "Echo back the provided message via a tool name that is not a legal JS identifier.", + ) + } + + fn build_echo_tool(name: &'static str, description: &'static str) -> Tool { #[expect(clippy::expect_used)] let schema: JsonObject = serde_json::from_value(json!({ "type": "object", @@ -71,8 +86,8 @@ impl TestToolServer { .expect("echo tool schema should deserialize"); Tool::new( - Cow::Borrowed("echo"), - Cow::Borrowed("Echo back the provided message and include environment data."), + Cow::Borrowed(name), + Cow::Borrowed(description), Arc::new(schema), ) } @@ -296,7 +311,7 @@ impl ServerHandler for TestToolServer { _context: rmcp::service::RequestContext, ) -> Result { match request.name.as_ref() { - "echo" => { + "echo" | "echo-tool" => { let args: EchoArgs = match request.arguments { Some(arguments) => serde_json::from_value(serde_json::Value::Object( arguments.into_iter().collect(), @@ -304,7 +319,7 @@ impl ServerHandler for TestToolServer { .map_err(|err| McpError::invalid_params(err.to_string(), None))?, None => { return Err(McpError::invalid_params( - "missing arguments for echo tool", + format!("missing arguments for {} tool", request.name), None, )); }