Fix MCP tool calling (#14491)

Properly escape MCP tool names and make tools available only via
imports.
This commit is contained in:
pakrym-oai 2026-03-12 13:38:52 -07:00 committed by GitHub
parent a5a4899d0c
commit dadffd27d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 317 additions and 82 deletions

View file

@ -1,43 +1,8 @@
// Descriptors of the tools enabled for this run. The placeholder token is
// presumably substituted with an array literal before evaluation — confirm
// against the code that renders this template.
const __codexEnabledTools = __CODE_MODE_ENABLED_TOOLS_PLACEHOLDER__;
// The `tool_name` of every enabled tool, in descriptor order.
const __codexEnabledToolNames = __codexEnabledTools.map((tool) => tool.tool_name);
// Reuse an array already present on globalThis when there is one; otherwise
// start from an empty list.
const __codexContentItems = Array.isArray(globalThis.__codexContentItems)
? globalThis.__codexContentItems
: [];
/**
 * Validates a single content item and returns a fresh copy that carries only
 * the recognised fields.
 *
 * Supported shapes:
 *   - { type: 'input_text', text: string }
 *   - { type: 'input_image', image_url: string }
 *
 * @param {object} item candidate content item
 * @returns {{type: string, text: string} | {type: string, image_url: string}}
 * @throws {TypeError} when `item` is not an object, has an unsupported type,
 *   or is missing the string payload its type requires
 */
function __codexCloneContentItem(item) {
  const isObject = Boolean(item) && typeof item === 'object';
  if (!isObject) {
    throw new TypeError('content item must be an object');
  }

  const kind = item.type;

  if (kind === 'input_text') {
    if (typeof item.text === 'string') {
      return { type: 'input_text', text: item.text };
    }
    throw new TypeError('content item "input_text" requires a string text field');
  }

  if (kind === 'input_image') {
    if (typeof item.image_url === 'string') {
      return { type: 'input_image', image_url: item.image_url };
    }
    throw new TypeError('content item "input_image" requires a string image_url field');
  }

  // Anything else is rejected; the offending type is echoed in the message.
  throw new TypeError(`unsupported content item type "${item.type}"`);
}
/**
 * Flattens arbitrarily nested arrays of content items into one flat array of
 * validated copies. A non-array value is treated as a single content item.
 *
 * @param {object | Array} value content item or (nested) array of items
 * @returns {Array<object>} flat list of cloned content items
 */
function __codexNormalizeRawContentItems(value) {
  if (!Array.isArray(value)) {
    return [__codexCloneContentItem(value)];
  }
  return value.flatMap((nested) => __codexNormalizeRawContentItems(nested));
}
/**
 * Converts the argument accepted by `add_content` into a flat array of
 * content items. A bare string becomes a single `input_text` item; anything
 * else goes through the raw-item normaliser.
 *
 * @param {string | object | Array} value
 * @returns {Array<object>}
 */
function __codexNormalizeContentItems(value) {
  return typeof value === 'string'
    ? [{ type: 'input_text', text: value }]
    : __codexNormalizeRawContentItems(value);
}
Object.defineProperty(globalThis, '__codexContentItems', {
value: __codexContentItems,
configurable: true,
@ -45,33 +10,54 @@ Object.defineProperty(globalThis, '__codexContentItems', {
writable: false,
});
// Public `codex` namespace: exposes a frozen copy of the enabled tool names.
globalThis.codex = {
enabledTools: Object.freeze(__codexEnabledToolNames.slice()),
};
// add_content(value): normalizes `value` into content items, appends them to
// the shared __codexContentItems array, and returns the newly added items.
globalThis.add_content = (value) => {
const contentItems = __codexNormalizeContentItems(value);
__codexContentItems.push(...contentItems);
return contentItems;
};
// Replace `console` with a frozen object whose logging methods do nothing.
globalThis.console = Object.freeze({
log() {},
info() {},
warn() {},
error() {},
debug() {},
});
for (const name of __codexEnabledToolNames) {
if (!(name in globalThis)) {
Object.defineProperty(globalThis, name, {
value: async (args) => __codex_tool_call(name, args),
configurable: true,
enumerable: false,
writable: false,
});
(() => {
/**
 * Validates a single content item and returns a fresh copy that carries only
 * the recognised fields.
 *
 * Supported shapes:
 *   - { type: 'input_text', text: string }
 *   - { type: 'input_image', image_url: string }
 *
 * @param {object} item candidate content item
 * @returns {{type: string, text: string} | {type: string, image_url: string}}
 * @throws {TypeError} when `item` is not an object, has an unsupported type,
 *   or is missing the string payload its type requires
 */
function cloneContentItem(item) {
  const isObject = Boolean(item) && typeof item === 'object';
  if (!isObject) {
    throw new TypeError('content item must be an object');
  }

  const kind = item.type;

  if (kind === 'input_text') {
    if (typeof item.text === 'string') {
      return { type: 'input_text', text: item.text };
    }
    throw new TypeError('content item "input_text" requires a string text field');
  }

  if (kind === 'input_image') {
    if (typeof item.image_url === 'string') {
      return { type: 'input_image', image_url: item.image_url };
    }
    throw new TypeError('content item "input_image" requires a string image_url field');
  }

  // Anything else is rejected; the offending type is echoed in the message.
  throw new TypeError(`unsupported content item type "${item.type}"`);
}
}
/**
 * Flattens arbitrarily nested arrays of content items into one flat array of
 * validated copies. A non-array value is treated as a single content item.
 *
 * @param {object | Array} value content item or (nested) array of items
 * @returns {Array<object>} flat list of cloned content items
 */
function normalizeRawContentItems(value) {
  if (!Array.isArray(value)) {
    return [cloneContentItem(value)];
  }
  return value.flatMap((nested) => normalizeRawContentItems(nested));
}
/**
 * Converts the argument accepted by `add_content` into a flat array of
 * content items. A bare string becomes a single `input_text` item; anything
 * else goes through the raw-item normaliser.
 *
 * @param {string | object | Array} value
 * @returns {Array<object>}
 */
function normalizeContentItems(value) {
  return typeof value === 'string'
    ? [{ type: 'input_text', text: value }]
    : normalizeRawContentItems(value);
}
// add_content(value): normalizes `value` into content items, appends them to
// __codexContentItems, and returns just the newly added items.
globalThis.add_content = (value) => {
const contentItems = normalizeContentItems(value);
__codexContentItems.push(...contentItems);
return contentItems;
};
// Install a frozen `console` whose log/info/warn/error/debug methods are no-ops.
globalThis.console = Object.freeze({
log() {},
info() {},
warn() {},
error() {},
debug() {},
});
})();
__CODE_MODE_USER_CODE_PLACEHOLDER__

View file

@ -16,4 +16,3 @@
- `set_max_output_tokens_per_exec_call(value)`: sets the token budget for direct `exec` results. By default the result is truncated to 10000 tokens.
- `set_yield_time(value)`: asks `exec` to yield early after that many milliseconds if the script is still running.
- `yield_control()`: yields the accumulated output to the model immediately while the script keeps running.

View file

@ -17,6 +17,7 @@ use crate::codex::TurnContext;
use crate::tools::ToolRouter;
use crate::tools::code_mode_description::augment_tool_spec_for_code_mode;
use crate::tools::code_mode_description::code_mode_tool_reference;
use crate::tools::code_mode_description::normalize_code_mode_identifier;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::parallel::ToolCallRuntime;
@ -233,10 +234,11 @@ fn enabled_tool_from_spec(spec: ToolSpec) -> Option<protocol::EnabledTool> {
};
Some(protocol::EnabledTool {
global_name: normalize_code_mode_identifier(&tool_name),
tool_name,
module_path: reference.module_path,
namespace: reference.namespace,
name: reference.tool_key,
name: normalize_code_mode_identifier(&reference.tool_key),
description,
kind,
})

View file

@ -17,6 +17,7 @@ pub(super) enum CodeModeToolKind {
#[derive(Clone, Debug, Serialize)]
pub(super) struct EnabledTool {
pub(super) tool_name: String,
pub(super) global_name: String,
#[serde(rename = "module")]
pub(super) module_path: String,
pub(super) namespace: Vec<String>,

View file

@ -134,8 +134,8 @@ function codeModeWorkerMain() {
function createToolsNamespace(callTool, enabledTools) {
const tools = Object.create(null);
for (const { tool_name } of enabledTools) {
Object.defineProperty(tools, tool_name, {
for (const { tool_name, global_name } of enabledTools) {
Object.defineProperty(tools, global_name, {
value: async (args) => callTool(tool_name, args),
configurable: false,
enumerable: true,
@ -163,9 +163,9 @@ function codeModeWorkerMain() {
const allTools = createAllToolsMetadata(enabledTools);
const exportNames = ['ALL_TOOLS'];
for (const { tool_name } of enabledTools) {
if (tool_name !== 'ALL_TOOLS') {
exportNames.push(tool_name);
for (const { global_name } of enabledTools) {
if (global_name !== 'ALL_TOOLS') {
exportNames.push(global_name);
}
}
@ -382,6 +382,24 @@ function codeModeWorkerMain() {
};
}
// Resolves `specifier` via `resolveModule`, then links and evaluates the
// resulting module as its current status requires before returning it.
// Rethrows `module.error` when the module ends up in the 'errored' state.
// NOTE(review): the result of `resolveModule` is used synchronously, so this
// assumes it returns a module object rather than a Promise — confirm.
async function resolveDynamicModule(specifier, resolveModule) {
const module = resolveModule(specifier);
// Link on demand; `resolveModule` is also passed as the linker callback.
if (module.status === 'unlinked') {
await module.link(resolveModule);
}
// Status is re-read here because the link step above may have changed it.
if (module.status === 'linked' || module.status === 'evaluating') {
await module.evaluate();
}
// Surface a failed evaluation to the caller instead of returning the module.
if (module.status === 'errored') {
throw module.error;
}
return module;
}
async function runModule(context, start, state, callTool) {
const resolveModule = createModuleResolver(
context,
@ -392,7 +410,8 @@ function codeModeWorkerMain() {
const mainModule = new SourceTextModule(start.source, {
context,
identifier: 'exec_main.mjs',
importModuleDynamically: async (specifier) => resolveModule(specifier),
importModuleDynamically: async (specifier) =>
resolveDynamicModule(specifier, resolveModule),
});
await mainModule.link(resolveModule);
@ -408,7 +427,6 @@ function codeModeWorkerMain() {
const callTool = createToolCaller();
const context = vm.createContext({
__codexContentItems: createContentItems(),
__codex_tool_call: callTool,
});
try {

View file

@ -75,13 +75,15 @@ fn append_code_mode_sample(
output_type: String,
) -> String {
let reference = code_mode_tool_reference(tool_name);
format!(
"{description}\n\nCode mode declaration:\n```ts\nimport {{ {} }} from \"{}\";\ndeclare function {}({input_name}: {input_type}): Promise<{output_type}>;\n```",
reference.tool_key, reference.module_path, reference.tool_key
)
let local_name = normalize_code_mode_identifier(&reference.tool_key);
let declaration = format!(
"import {{ {local_name} }} from \"{}\";\ndeclare function {local_name}({input_name}: {input_type}): Promise<{output_type}>;",
reference.module_path
);
format!("{description}\n\nCode mode declaration:\n```ts\n{declaration}\n```")
}
fn code_mode_local_name(tool_key: &str) -> String {
pub(crate) fn normalize_code_mode_identifier(tool_key: &str) -> String {
let mut identifier = String::new();
for (index, ch) in tool_key.chars().enumerate() {
@ -98,7 +100,11 @@ fn code_mode_local_name(tool_key: &str) -> String {
}
}
identifier
if identifier.is_empty() {
"_".to_string()
} else {
identifier
}
}
fn render_json_schema_to_typescript(schema: &JsonValue) -> String {
@ -279,7 +285,7 @@ fn render_json_schema_object(map: &serde_json::Map<String, JsonValue>, indent: u
}
fn render_json_schema_property_name(name: &str) -> String {
if code_mode_local_name(name) == name {
if normalize_code_mode_identifier(name) == name {
name.to_string()
} else {
serde_json::to_string(name).unwrap_or_else(|_| format!("\"{}\"", name.replace('"', "\\\"")))

View file

@ -1,3 +1,4 @@
use super::append_code_mode_sample;
use super::render_json_schema_to_typescript;
use pretty_assertions::assert_eq;
use serde_json::json;
@ -73,3 +74,31 @@ fn render_json_schema_to_typescript_sorts_object_properties() {
"{\n _meta?: string;\n content: Array<string>;\n isError?: boolean;\n structuredContent?: string;\n}"
);
}
/// A tool key that is already a valid identifier is used verbatim in both the
/// `import` line and the `declare function` line of the rendered sample.
#[test]
fn append_code_mode_sample_uses_static_import_for_valid_identifiers() {
    let expected = "desc\n\nCode mode declaration:\n```ts\nimport { get_profile } from \"tools/mcp/ologs.js\";\ndeclare function get_profile(args: { foo: string }): Promise<unknown>;\n```";
    let input_type = String::from("{ foo: string }");
    let output_type = String::from("unknown");

    let rendered = append_code_mode_sample(
        "desc",
        "mcp__ologs__get_profile",
        "args",
        input_type,
        output_type,
    );

    assert_eq!(rendered, expected);
}
/// A tool key containing a character that is illegal in an identifier
/// (`echo-tool`) is rendered under its normalized name (`echo_tool`) in both
/// the `import` line and the `declare function` line.
#[test]
fn append_code_mode_sample_normalizes_non_identifier_tool_names() {
    let expected = "desc\n\nCode mode declaration:\n```ts\nimport { echo_tool } from \"tools/mcp/rmcp.js\";\ndeclare function echo_tool(args: { foo: string }): Promise<unknown>;\n```";
    let input_type = String::from("{ foo: string }");
    let output_type = String::from("unknown");

    let rendered = append_code_mode_sample(
        "desc",
        "mcp__rmcp__echo-tool",
        "args",
        input_type,
        output_type,
    );

    assert_eq!(rendered, expected);
}

View file

@ -20,6 +20,7 @@ use core_test_support::test_codex::test_codex;
use pretty_assertions::assert_eq;
use serde_json::Value;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
use std::path::Path;
use std::time::Duration;
@ -1584,6 +1585,184 @@ contentLength=0"
Ok(())
}
/// Runs a script that loads `tools/mcp/rmcp.js` through a dynamic `import()`
/// and calls its `echo` export, then checks the text the script reported via
/// `add_content`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn code_mode_can_dynamically_import_namespaced_mcp_tools() -> Result<()> {
// Presumably returns `Ok(())` early when no network is available — confirm
// against the macro definition.
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
// Script under test: dynamic import, one `echo` call, then a summary of what
// the module namespace and the call result looked like.
let code = r#"
const rmcp = await import("tools/mcp/rmcp.js");
const { content, structuredContent, isError } = await rmcp.echo({
message: "ping",
});
add_content(
`hasEcho=${String(Object.keys(rmcp).includes("echo"))}\n` +
`echoType=${typeof rmcp.echo}\n` +
`echo=${structuredContent?.echo ?? "missing"}\n` +
`isError=${String(isError)}\n` +
`contentLength=${content.length}`
);
"#;
let (_test, second_mock) = run_code_mode_turn_with_rmcp(
&server,
"use exec to dynamically import the rmcp module",
code,
)
.await?;
let req = second_mock.single_request();
let (output, success) = custom_tool_output_body_and_success(&req, "call-1");
// Only an explicit `Some(false)` counts as failure; `None` is accepted.
assert_ne!(
success,
Some(false),
"exec dynamic rmcp import failed unexpectedly: {output}"
);
assert_eq!(
output,
"hasEcho=true
echoType=function
echo=ECHOING: ping
isError=false
contentLength=0"
);
Ok(())
}
/// Runs a script that statically imports `echo_tool` from `tools/mcp/rmcp.js`
/// and calls it, then checks the echoed text the script reported via
/// `add_content`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn code_mode_normalizes_illegal_namespaced_mcp_tool_identifiers() -> Result<()> {
// Presumably returns `Ok(())` early when no network is available — confirm
// against the macro definition.
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
// Script under test: imports the tool under an identifier-safe name.
let code = r#"
import { echo_tool } from "tools/mcp/rmcp.js";
const result = await echo_tool({ message: "ping" });
add_content(`echo=${result.structuredContent.echo}`);
"#;
let (_test, second_mock) = run_code_mode_turn_with_rmcp(
&server,
"use exec to import a normalized rmcp tool name",
code,
)
.await?;
let req = second_mock.single_request();
let (output, success) = custom_tool_output_body_and_success(&req, "call-1");
// Only an explicit `Some(false)` counts as failure; `None` is accepted.
assert_ne!(
success,
Some(false),
"exec normalized rmcp import failed unexpectedly: {output}"
);
assert_eq!(output, "echo=ECHOING: ping");
Ok(())
}
/// Runs a script that reports the own property names of `globalThis` and
/// asserts that every reported name appears in the allow-list below.
///
/// The check is one-directional: a name that is in the allow-list but absent
/// from the script's globals does not fail the test.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn code_mode_lists_global_scope_items() -> Result<()> {
// Presumably returns `Ok(())` early when no network is available — confirm
// against the macro definition.
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
// Script under test: emits the sorted global names as a JSON array.
let code = r#"
add_content(JSON.stringify(Object.getOwnPropertyNames(globalThis).sort()));
"#;
let (_test, second_mock) =
run_code_mode_turn_with_rmcp(&server, "use exec to inspect global scope", code).await?;
let req = second_mock.single_request();
let (output, success) = custom_tool_output_body_and_success(&req, "call-1");
// Only an explicit `Some(false)` counts as failure; `None` is accepted.
assert_ne!(
success,
Some(false),
"exec global scope inspection failed unexpectedly: {output}"
);
// The script output is parsed back into a set of names.
let globals = serde_json::from_str::<Vec<String>>(&output)?;
let globals = globals.into_iter().collect::<HashSet<_>>();
// Allow-list of global names the script is permitted to see.
let expected = [
"AggregateError",
"Array",
"ArrayBuffer",
"AsyncDisposableStack",
"Atomics",
"BigInt",
"BigInt64Array",
"BigUint64Array",
"Boolean",
"DataView",
"Date",
"DisposableStack",
"Error",
"EvalError",
"FinalizationRegistry",
"Float16Array",
"Float32Array",
"Float64Array",
"Function",
"Infinity",
"Int16Array",
"Int32Array",
"Int8Array",
"Intl",
"Iterator",
"JSON",
"Map",
"Math",
"NaN",
"Number",
"Object",
"Promise",
"Proxy",
"RangeError",
"ReferenceError",
"Reflect",
"RegExp",
"Set",
"SharedArrayBuffer",
"String",
"SuppressedError",
"Symbol",
"SyntaxError",
"TypeError",
"URIError",
"Uint16Array",
"Uint32Array",
"Uint8Array",
"Uint8ClampedArray",
"WeakMap",
"WeakRef",
"WeakSet",
"WebAssembly",
"__codexContentItems",
"add_content",
"console",
"decodeURI",
"decodeURIComponent",
"encodeURI",
"encodeURIComponent",
"escape",
"eval",
"globalThis",
"isFinite",
"isNaN",
"parseFloat",
"parseInt",
"undefined",
"unescape",
];
// Fail on the first reported global that is not in the allow-list.
for g in &globals {
assert!(
expected.contains(&g.as_str()),
"unexpected global {g} in {globals:?}"
);
}
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn code_mode_exports_all_tools_metadata_for_builtin_tools() -> Result<()> {
skip_if_no_network!(Ok(()));

View file

@ -45,6 +45,7 @@ impl TestToolServer {
fn new() -> Self {
let tools = vec![
Self::echo_tool(),
Self::echo_dash_tool(),
Self::image_tool(),
Self::image_scenario_tool(),
];
@ -58,6 +59,20 @@ impl TestToolServer {
}
/// Builds the tool definition registered under the name `echo`.
fn echo_tool() -> Tool {
    let name = "echo";
    let description = "Echo back the provided message and include environment data.";
    Self::build_echo_tool(name, description)
}
/// Builds the tool definition registered under the name `echo-tool`, a name
/// that is deliberately not a legal JavaScript identifier.
fn echo_dash_tool() -> Tool {
    let name = "echo-tool";
    let description =
        "Echo back the provided message via a tool name that is not a legal JS identifier.";
    Self::build_echo_tool(name, description)
}
fn build_echo_tool(name: &'static str, description: &'static str) -> Tool {
#[expect(clippy::expect_used)]
let schema: JsonObject = serde_json::from_value(json!({
"type": "object",
@ -71,8 +86,8 @@ impl TestToolServer {
.expect("echo tool schema should deserialize");
Tool::new(
Cow::Borrowed("echo"),
Cow::Borrowed("Echo back the provided message and include environment data."),
Cow::Borrowed(name),
Cow::Borrowed(description),
Arc::new(schema),
)
}
@ -296,7 +311,7 @@ impl ServerHandler for TestToolServer {
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
) -> Result<CallToolResult, McpError> {
match request.name.as_ref() {
"echo" => {
"echo" | "echo-tool" => {
let args: EchoArgs = match request.arguments {
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
arguments.into_iter().collect(),
@ -304,7 +319,7 @@ impl ServerHandler for TestToolServer {
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
None => {
return Err(McpError::invalid_params(
"missing arguments for echo tool",
format!("missing arguments for {} tool", request.name),
None,
));
}