Fix MCP tool calling (#14491)
Properly escape MCP tool names and make tools available only via imports.
This commit is contained in:
parent
a5a4899d0c
commit
dadffd27d4
9 changed files with 317 additions and 82 deletions
|
|
@ -1,43 +1,8 @@
|
|||
const __codexEnabledTools = __CODE_MODE_ENABLED_TOOLS_PLACEHOLDER__;
|
||||
const __codexEnabledToolNames = __codexEnabledTools.map((tool) => tool.tool_name);
|
||||
const __codexContentItems = Array.isArray(globalThis.__codexContentItems)
|
||||
? globalThis.__codexContentItems
|
||||
: [];
|
||||
|
||||
/**
 * Validate a single content item and return a fresh copy that carries only
 * the fields recognised for its type.
 *
 * Supported shapes:
 *   - { type: 'input_text', text: string }
 *   - { type: 'input_image', image_url: string }
 *
 * @param {unknown} item Candidate content item.
 * @returns {{type: 'input_text', text: string} | {type: 'input_image', image_url: string}}
 * @throws {TypeError} If `item` is not an object, has an unsupported `type`,
 *   or is missing the string field required for its type.
 */
function __codexCloneContentItem(item) {
  const isObject = Boolean(item) && typeof item === 'object';
  if (!isObject) {
    throw new TypeError('content item must be an object');
  }

  if (item.type === 'input_text') {
    if (typeof item.text === 'string') {
      // Copy only the known fields so extra properties are dropped.
      return { type: 'input_text', text: item.text };
    }
    throw new TypeError('content item "input_text" requires a string text field');
  }

  if (item.type === 'input_image') {
    if (typeof item.image_url === 'string') {
      return { type: 'input_image', image_url: item.image_url };
    }
    throw new TypeError('content item "input_image" requires a string image_url field');
  }

  throw new TypeError(`unsupported content item type "${item.type}"`);
}
|
||||
|
||||
/**
 * Flatten an arbitrarily nested array of content items into a single flat
 * array of validated copies.
 *
 * @param {unknown} value A content item, or a (possibly nested) array of them.
 * @returns {Array<object>} Flat list of cloned content items.
 * @throws {TypeError} Propagated from `__codexCloneContentItem` for invalid items.
 */
function __codexNormalizeRawContentItems(value) {
  if (!Array.isArray(value)) {
    // Leaf: validate and copy the single item.
    return [__codexCloneContentItem(value)];
  }

  // Each entry normalises to an array; flatMap splices those results together.
  return value.flatMap((entry) => __codexNormalizeRawContentItems(entry));
}
|
||||
|
||||
/**
 * Normalise a value into a flat array of content items.
 *
 * A bare string is wrapped as a single `input_text` item; anything else is
 * handed to `__codexNormalizeRawContentItems`.
 *
 * @param {unknown} value String, content item, or (nested) array of content items.
 * @returns {Array<object>} Flat list of content items.
 */
function __codexNormalizeContentItems(value) {
  return typeof value === 'string'
    ? [{ type: 'input_text', text: value }]
    : __codexNormalizeRawContentItems(value);
}
|
||||
|
||||
Object.defineProperty(globalThis, '__codexContentItems', {
|
||||
value: __codexContentItems,
|
||||
configurable: true,
|
||||
|
|
@ -45,33 +10,54 @@ Object.defineProperty(globalThis, '__codexContentItems', {
|
|||
writable: false,
|
||||
});
|
||||
|
||||
globalThis.codex = {
|
||||
enabledTools: Object.freeze(__codexEnabledToolNames.slice()),
|
||||
};
|
||||
|
||||
globalThis.add_content = (value) => {
|
||||
const contentItems = __codexNormalizeContentItems(value);
|
||||
__codexContentItems.push(...contentItems);
|
||||
return contentItems;
|
||||
};
|
||||
|
||||
globalThis.console = Object.freeze({
|
||||
log() {},
|
||||
info() {},
|
||||
warn() {},
|
||||
error() {},
|
||||
debug() {},
|
||||
});
|
||||
|
||||
for (const name of __codexEnabledToolNames) {
|
||||
if (!(name in globalThis)) {
|
||||
Object.defineProperty(globalThis, name, {
|
||||
value: async (args) => __codex_tool_call(name, args),
|
||||
configurable: true,
|
||||
enumerable: false,
|
||||
writable: false,
|
||||
});
|
||||
(() => {
|
||||
function cloneContentItem(item) {
|
||||
if (!item || typeof item !== 'object') {
|
||||
throw new TypeError('content item must be an object');
|
||||
}
|
||||
switch (item.type) {
|
||||
case 'input_text':
|
||||
if (typeof item.text !== 'string') {
|
||||
throw new TypeError('content item "input_text" requires a string text field');
|
||||
}
|
||||
return { type: 'input_text', text: item.text };
|
||||
case 'input_image':
|
||||
if (typeof item.image_url !== 'string') {
|
||||
throw new TypeError('content item "input_image" requires a string image_url field');
|
||||
}
|
||||
return { type: 'input_image', image_url: item.image_url };
|
||||
default:
|
||||
throw new TypeError(`unsupported content item type "${item.type}"`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeRawContentItems(value) {
|
||||
if (Array.isArray(value)) {
|
||||
return value.flatMap((entry) => normalizeRawContentItems(entry));
|
||||
}
|
||||
return [cloneContentItem(value)];
|
||||
}
|
||||
|
||||
function normalizeContentItems(value) {
|
||||
if (typeof value === 'string') {
|
||||
return [{ type: 'input_text', text: value }];
|
||||
}
|
||||
return normalizeRawContentItems(value);
|
||||
}
|
||||
|
||||
globalThis.add_content = (value) => {
|
||||
const contentItems = normalizeContentItems(value);
|
||||
__codexContentItems.push(...contentItems);
|
||||
return contentItems;
|
||||
};
|
||||
|
||||
globalThis.console = Object.freeze({
|
||||
log() {},
|
||||
info() {},
|
||||
warn() {},
|
||||
error() {},
|
||||
debug() {},
|
||||
});
|
||||
})();
|
||||
|
||||
__CODE_MODE_USER_CODE_PLACEHOLDER__
|
||||
|
|
|
|||
|
|
@ -16,4 +16,3 @@
|
|||
- `set_max_output_tokens_per_exec_call(value)`: sets the token budget for direct `exec` results. By default the result is truncated to 10000 tokens.
|
||||
- `set_yield_time(value)`: asks `exec` to yield early after that many milliseconds if the script is still running.
|
||||
- `yield_control()`: yields the accumulated output to the model immediately while the script keeps running.
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ use crate::codex::TurnContext;
|
|||
use crate::tools::ToolRouter;
|
||||
use crate::tools::code_mode_description::augment_tool_spec_for_code_mode;
|
||||
use crate::tools::code_mode_description::code_mode_tool_reference;
|
||||
use crate::tools::code_mode_description::normalize_code_mode_identifier;
|
||||
use crate::tools::context::FunctionToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
|
|
@ -233,10 +234,11 @@ fn enabled_tool_from_spec(spec: ToolSpec) -> Option<protocol::EnabledTool> {
|
|||
};
|
||||
|
||||
Some(protocol::EnabledTool {
|
||||
global_name: normalize_code_mode_identifier(&tool_name),
|
||||
tool_name,
|
||||
module_path: reference.module_path,
|
||||
namespace: reference.namespace,
|
||||
name: reference.tool_key,
|
||||
name: normalize_code_mode_identifier(&reference.tool_key),
|
||||
description,
|
||||
kind,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ pub(super) enum CodeModeToolKind {
|
|||
#[derive(Clone, Debug, Serialize)]
|
||||
pub(super) struct EnabledTool {
|
||||
pub(super) tool_name: String,
|
||||
pub(super) global_name: String,
|
||||
#[serde(rename = "module")]
|
||||
pub(super) module_path: String,
|
||||
pub(super) namespace: Vec<String>,
|
||||
|
|
|
|||
|
|
@ -134,8 +134,8 @@ function codeModeWorkerMain() {
|
|||
function createToolsNamespace(callTool, enabledTools) {
|
||||
const tools = Object.create(null);
|
||||
|
||||
for (const { tool_name } of enabledTools) {
|
||||
Object.defineProperty(tools, tool_name, {
|
||||
for (const { tool_name, global_name } of enabledTools) {
|
||||
Object.defineProperty(tools, global_name, {
|
||||
value: async (args) => callTool(tool_name, args),
|
||||
configurable: false,
|
||||
enumerable: true,
|
||||
|
|
@ -163,9 +163,9 @@ function codeModeWorkerMain() {
|
|||
const allTools = createAllToolsMetadata(enabledTools);
|
||||
const exportNames = ['ALL_TOOLS'];
|
||||
|
||||
for (const { tool_name } of enabledTools) {
|
||||
if (tool_name !== 'ALL_TOOLS') {
|
||||
exportNames.push(tool_name);
|
||||
for (const { global_name } of enabledTools) {
|
||||
if (global_name !== 'ALL_TOOLS') {
|
||||
exportNames.push(global_name);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -382,6 +382,24 @@ function codeModeWorkerMain() {
|
|||
};
|
||||
}
|
||||
|
||||
/**
 * Resolve a specifier for a dynamic `import()` and bring the resulting module
 * forward until it can be returned to the importer.
 *
 * The module's `status` is re-read before each step because the preceding
 * awaited call may have changed it.
 *
 * @param {string} specifier Module specifier passed to `import()`.
 * @param {(specifier: string) => object} resolveModule Resolver that returns
 *   the module object; it is also passed to `link()` as the linker.
 * @returns {Promise<object>} The resolved module.
 * @throws Whatever `module.error` holds when the module ends up 'errored'.
 */
async function resolveDynamicModule(specifier, resolveModule) {
  const target = resolveModule(specifier);
  const statusIs = (...candidates) => candidates.includes(target.status);

  if (statusIs('unlinked')) {
    await target.link(resolveModule);
  }

  // NOTE(review): Node's vm docs appear to say evaluate() cannot be called
  // while status is 'evaluating' - confirm this branch behaves as intended.
  if (statusIs('linked', 'evaluating')) {
    await target.evaluate();
  }

  if (statusIs('errored')) {
    throw target.error;
  }

  return target;
}
|
||||
|
||||
async function runModule(context, start, state, callTool) {
|
||||
const resolveModule = createModuleResolver(
|
||||
context,
|
||||
|
|
@ -392,7 +410,8 @@ function codeModeWorkerMain() {
|
|||
const mainModule = new SourceTextModule(start.source, {
|
||||
context,
|
||||
identifier: 'exec_main.mjs',
|
||||
importModuleDynamically: async (specifier) => resolveModule(specifier),
|
||||
importModuleDynamically: async (specifier) =>
|
||||
resolveDynamicModule(specifier, resolveModule),
|
||||
});
|
||||
|
||||
await mainModule.link(resolveModule);
|
||||
|
|
@ -408,7 +427,6 @@ function codeModeWorkerMain() {
|
|||
const callTool = createToolCaller();
|
||||
const context = vm.createContext({
|
||||
__codexContentItems: createContentItems(),
|
||||
__codex_tool_call: callTool,
|
||||
});
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -75,13 +75,15 @@ fn append_code_mode_sample(
|
|||
output_type: String,
|
||||
) -> String {
|
||||
let reference = code_mode_tool_reference(tool_name);
|
||||
format!(
|
||||
"{description}\n\nCode mode declaration:\n```ts\nimport {{ {} }} from \"{}\";\ndeclare function {}({input_name}: {input_type}): Promise<{output_type}>;\n```",
|
||||
reference.tool_key, reference.module_path, reference.tool_key
|
||||
)
|
||||
let local_name = normalize_code_mode_identifier(&reference.tool_key);
|
||||
let declaration = format!(
|
||||
"import {{ {local_name} }} from \"{}\";\ndeclare function {local_name}({input_name}: {input_type}): Promise<{output_type}>;",
|
||||
reference.module_path
|
||||
);
|
||||
format!("{description}\n\nCode mode declaration:\n```ts\n{declaration}\n```")
|
||||
}
|
||||
|
||||
fn code_mode_local_name(tool_key: &str) -> String {
|
||||
pub(crate) fn normalize_code_mode_identifier(tool_key: &str) -> String {
|
||||
let mut identifier = String::new();
|
||||
|
||||
for (index, ch) in tool_key.chars().enumerate() {
|
||||
|
|
@ -98,7 +100,11 @@ fn code_mode_local_name(tool_key: &str) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
identifier
|
||||
if identifier.is_empty() {
|
||||
"_".to_string()
|
||||
} else {
|
||||
identifier
|
||||
}
|
||||
}
|
||||
|
||||
fn render_json_schema_to_typescript(schema: &JsonValue) -> String {
|
||||
|
|
@ -279,7 +285,7 @@ fn render_json_schema_object(map: &serde_json::Map<String, JsonValue>, indent: u
|
|||
}
|
||||
|
||||
fn render_json_schema_property_name(name: &str) -> String {
|
||||
if code_mode_local_name(name) == name {
|
||||
if normalize_code_mode_identifier(name) == name {
|
||||
name.to_string()
|
||||
} else {
|
||||
serde_json::to_string(name).unwrap_or_else(|_| format!("\"{}\"", name.replace('"', "\\\"")))
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use super::append_code_mode_sample;
|
||||
use super::render_json_schema_to_typescript;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
|
@ -73,3 +74,31 @@ fn render_json_schema_to_typescript_sorts_object_properties() {
|
|||
"{\n _meta?: string;\n content: Array<string>;\n isError?: boolean;\n structuredContent?: string;\n}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn append_code_mode_sample_uses_static_import_for_valid_identifiers() {
    // The rendered sample should import and declare `get_profile` from the
    // `tools/mcp/ologs.js` module, appended after the original description.
    let expected = "desc\n\nCode mode declaration:\n```ts\nimport { get_profile } from \"tools/mcp/ologs.js\";\ndeclare function get_profile(args: { foo: string }): Promise<unknown>;\n```";

    let rendered = append_code_mode_sample(
        "desc",
        "mcp__ologs__get_profile",
        "args",
        "{ foo: string }".to_string(),
        "unknown".to_string(),
    );

    assert_eq!(rendered, expected);
}
|
||||
|
||||
#[test]
fn append_code_mode_sample_normalizes_non_identifier_tool_names() {
    // `echo-tool` contains a dash, so the sample is expected to use the
    // normalized name `echo_tool` in both the import and the declaration.
    let expected = "desc\n\nCode mode declaration:\n```ts\nimport { echo_tool } from \"tools/mcp/rmcp.js\";\ndeclare function echo_tool(args: { foo: string }): Promise<unknown>;\n```";

    let rendered = append_code_mode_sample(
        "desc",
        "mcp__rmcp__echo-tool",
        "args",
        "{ foo: string }".to_string(),
        "unknown".to_string(),
    );

    assert_eq!(rendered, expected);
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ use core_test_support::test_codex::test_codex;
|
|||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
|
@ -1584,6 +1585,184 @@ contentLength=0"
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn code_mode_can_dynamically_import_namespaced_mcp_tools() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let code = r#"
|
||||
const rmcp = await import("tools/mcp/rmcp.js");
|
||||
const { content, structuredContent, isError } = await rmcp.echo({
|
||||
message: "ping",
|
||||
});
|
||||
add_content(
|
||||
`hasEcho=${String(Object.keys(rmcp).includes("echo"))}\n` +
|
||||
`echoType=${typeof rmcp.echo}\n` +
|
||||
`echo=${structuredContent?.echo ?? "missing"}\n` +
|
||||
`isError=${String(isError)}\n` +
|
||||
`contentLength=${content.length}`
|
||||
);
|
||||
"#;
|
||||
|
||||
let (_test, second_mock) = run_code_mode_turn_with_rmcp(
|
||||
&server,
|
||||
"use exec to dynamically import the rmcp module",
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let req = second_mock.single_request();
|
||||
let (output, success) = custom_tool_output_body_and_success(&req, "call-1");
|
||||
assert_ne!(
|
||||
success,
|
||||
Some(false),
|
||||
"exec dynamic rmcp import failed unexpectedly: {output}"
|
||||
);
|
||||
assert_eq!(
|
||||
output,
|
||||
"hasEcho=true
|
||||
echoType=function
|
||||
echo=ECHOING: ping
|
||||
isError=false
|
||||
contentLength=0"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn code_mode_normalizes_illegal_namespaced_mcp_tool_identifiers() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let code = r#"
|
||||
import { echo_tool } from "tools/mcp/rmcp.js";
|
||||
|
||||
const result = await echo_tool({ message: "ping" });
|
||||
add_content(`echo=${result.structuredContent.echo}`);
|
||||
"#;
|
||||
|
||||
let (_test, second_mock) = run_code_mode_turn_with_rmcp(
|
||||
&server,
|
||||
"use exec to import a normalized rmcp tool name",
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let req = second_mock.single_request();
|
||||
let (output, success) = custom_tool_output_body_and_success(&req, "call-1");
|
||||
assert_ne!(
|
||||
success,
|
||||
Some(false),
|
||||
"exec normalized rmcp import failed unexpectedly: {output}"
|
||||
);
|
||||
assert_eq!(output, "echo=ECHOING: ping");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn code_mode_lists_global_scope_items() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
let code = r#"
|
||||
add_content(JSON.stringify(Object.getOwnPropertyNames(globalThis).sort()));
|
||||
"#;
|
||||
|
||||
let (_test, second_mock) =
|
||||
run_code_mode_turn_with_rmcp(&server, "use exec to inspect global scope", code).await?;
|
||||
|
||||
let req = second_mock.single_request();
|
||||
let (output, success) = custom_tool_output_body_and_success(&req, "call-1");
|
||||
assert_ne!(
|
||||
success,
|
||||
Some(false),
|
||||
"exec global scope inspection failed unexpectedly: {output}"
|
||||
);
|
||||
let globals = serde_json::from_str::<Vec<String>>(&output)?;
|
||||
let globals = globals.into_iter().collect::<HashSet<_>>();
|
||||
let expected = [
|
||||
"AggregateError",
|
||||
"Array",
|
||||
"ArrayBuffer",
|
||||
"AsyncDisposableStack",
|
||||
"Atomics",
|
||||
"BigInt",
|
||||
"BigInt64Array",
|
||||
"BigUint64Array",
|
||||
"Boolean",
|
||||
"DataView",
|
||||
"Date",
|
||||
"DisposableStack",
|
||||
"Error",
|
||||
"EvalError",
|
||||
"FinalizationRegistry",
|
||||
"Float16Array",
|
||||
"Float32Array",
|
||||
"Float64Array",
|
||||
"Function",
|
||||
"Infinity",
|
||||
"Int16Array",
|
||||
"Int32Array",
|
||||
"Int8Array",
|
||||
"Intl",
|
||||
"Iterator",
|
||||
"JSON",
|
||||
"Map",
|
||||
"Math",
|
||||
"NaN",
|
||||
"Number",
|
||||
"Object",
|
||||
"Promise",
|
||||
"Proxy",
|
||||
"RangeError",
|
||||
"ReferenceError",
|
||||
"Reflect",
|
||||
"RegExp",
|
||||
"Set",
|
||||
"SharedArrayBuffer",
|
||||
"String",
|
||||
"SuppressedError",
|
||||
"Symbol",
|
||||
"SyntaxError",
|
||||
"TypeError",
|
||||
"URIError",
|
||||
"Uint16Array",
|
||||
"Uint32Array",
|
||||
"Uint8Array",
|
||||
"Uint8ClampedArray",
|
||||
"WeakMap",
|
||||
"WeakRef",
|
||||
"WeakSet",
|
||||
"WebAssembly",
|
||||
"__codexContentItems",
|
||||
"add_content",
|
||||
"console",
|
||||
"decodeURI",
|
||||
"decodeURIComponent",
|
||||
"encodeURI",
|
||||
"encodeURIComponent",
|
||||
"escape",
|
||||
"eval",
|
||||
"globalThis",
|
||||
"isFinite",
|
||||
"isNaN",
|
||||
"parseFloat",
|
||||
"parseInt",
|
||||
"undefined",
|
||||
"unescape",
|
||||
];
|
||||
for g in &globals {
|
||||
assert!(
|
||||
expected.contains(&g.as_str()),
|
||||
"unexpected global {g} in {globals:?}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn code_mode_exports_all_tools_metadata_for_builtin_tools() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ impl TestToolServer {
|
|||
fn new() -> Self {
|
||||
let tools = vec![
|
||||
Self::echo_tool(),
|
||||
Self::echo_dash_tool(),
|
||||
Self::image_tool(),
|
||||
Self::image_scenario_tool(),
|
||||
];
|
||||
|
|
@ -58,6 +59,20 @@ impl TestToolServer {
|
|||
}
|
||||
|
||||
fn echo_tool() -> Tool {
|
||||
Self::build_echo_tool(
|
||||
"echo",
|
||||
"Echo back the provided message and include environment data.",
|
||||
)
|
||||
}
|
||||
|
||||
fn echo_dash_tool() -> Tool {
|
||||
Self::build_echo_tool(
|
||||
"echo-tool",
|
||||
"Echo back the provided message via a tool name that is not a legal JS identifier.",
|
||||
)
|
||||
}
|
||||
|
||||
fn build_echo_tool(name: &'static str, description: &'static str) -> Tool {
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema: JsonObject = serde_json::from_value(json!({
|
||||
"type": "object",
|
||||
|
|
@ -71,8 +86,8 @@ impl TestToolServer {
|
|||
.expect("echo tool schema should deserialize");
|
||||
|
||||
Tool::new(
|
||||
Cow::Borrowed("echo"),
|
||||
Cow::Borrowed("Echo back the provided message and include environment data."),
|
||||
Cow::Borrowed(name),
|
||||
Cow::Borrowed(description),
|
||||
Arc::new(schema),
|
||||
)
|
||||
}
|
||||
|
|
@ -296,7 +311,7 @@ impl ServerHandler for TestToolServer {
|
|||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match request.name.as_ref() {
|
||||
"echo" => {
|
||||
"echo" | "echo-tool" => {
|
||||
let args: EchoArgs = match request.arguments {
|
||||
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
|
||||
arguments.into_iter().collect(),
|
||||
|
|
@ -304,7 +319,7 @@ impl ServerHandler for TestToolServer {
|
|||
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
|
||||
None => {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing arguments for echo tool",
|
||||
format!("missing arguments for {} tool", request.name),
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue