feat: search_tool (#10657)
**Why We Did This** - The goal is to reduce MCP tool context pollution by not exposing the full MCP tool list up front - It forces an explicit discovery step (`search_tool_bm25`) so the model narrows tool scope before making MCP calls, which helps relevance and lowers prompt/tool clutter. **What It Changed** - Added a new experimental feature flag `search_tool` in `core/src/features.rs:90` and `core/src/features.rs:430`. - Added config/schema support for that flag in `core/config.schema.json:214` and `core/config.schema.json:1235`. - Added BM25 dependency (`bm25`) in `Cargo.toml:129` and `core/Cargo.toml:23`. - Added new tool handler `search_tool_bm25` in `core/src/tools/handlers/search_tool_bm25.rs:18`. - Registered the handler and tool spec in `core/src/tools/handlers/mod.rs:11` and `core/src/tools/spec.rs:780` and `core/src/tools/spec.rs:1344`. - Extended `ToolsConfig` to carry `search_tool` enablement in `core/src/tools/spec.rs:32` and `core/src/tools/spec.rs:56`. - Injected dedicated developer instructions for tool-discovery workflow in `core/src/codex.rs:483` and `core/src/codex.rs:1976`, using `core/templates/search_tool/developer_instructions.md:1`. - Added session state to store one-shot selected MCP tools in `core/src/state/session.rs:27` and `core/src/state/session.rs:131`. - Added filtering so when feature is enabled, only selected MCP tools are exposed on the next request (then consumed) in `core/src/codex.rs:3800` and `core/src/codex.rs:3843`. - Added E2E suite coverage for enablement/instructions/hide-until-search/one-turn-selection in `core/tests/suite/search_tool.rs:72`, `core/tests/suite/search_tool.rs:109`, `core/tests/suite/search_tool.rs:147`, and `core/tests/suite/search_tool.rs:218`. - Refactored test helper utilities to support config-driven tool collection in `core/tests/suite/tools.rs:281`. **Net Behavioral Effect** - With `search_tool` **off**: existing MCP behavior (tools exposed normally). - With `search_tool` **on**: MCP tools start hidden, model must call `search_tool_bm25`, and only returned `selected_tools` are available for the next model call.
This commit is contained in:
parent
9450cd9ce5
commit
becc3a0424
15 changed files with 1238 additions and 1 deletions
108
codex-rs/Cargo.lock
generated
108
codex-rs/Cargo.lock
generated
|
|
@ -922,6 +922,20 @@ dependencies = [
|
|||
"piper",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bm25"
|
||||
version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1cbd8ffdfb7b4c2ff038726178a780a94f90525ed0ad264c0afaa75dd8c18a64"
|
||||
dependencies = [
|
||||
"cached",
|
||||
"deunicode",
|
||||
"fxhash",
|
||||
"rust-stemmers",
|
||||
"stop-words",
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "borsh"
|
||||
version = "1.6.0"
|
||||
|
|
@ -1000,6 +1014,39 @@ dependencies = [
|
|||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached"
|
||||
version = "0.56.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "801927ee168e17809ab8901d9f01f700cd7d8d6a6527997fee44e4b0327a253c"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"cached_proc_macro",
|
||||
"cached_proc_macro_types",
|
||||
"hashbrown 0.15.5",
|
||||
"once_cell",
|
||||
"thiserror 2.0.18",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached_proc_macro"
|
||||
version = "0.25.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9225bdcf4e4a9a4c08bf16607908eb2fbf746828d5e0b5e019726dbf6571f201"
|
||||
dependencies = [
|
||||
"darling 0.20.11",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached_proc_macro_types"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ade8366b8bd5ba243f0a58f036cc0ca8a2f069cff1a2351ef1cac6b083e16fc0"
|
||||
|
||||
[[package]]
|
||||
name = "cassowary"
|
||||
version = "0.3.0"
|
||||
|
|
@ -1555,6 +1602,7 @@ dependencies = [
|
|||
"async-channel",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"bm25",
|
||||
"chardetng",
|
||||
"chrono",
|
||||
"clap",
|
||||
|
|
@ -2704,6 +2752,16 @@ dependencies = [
|
|||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
|
||||
dependencies = [
|
||||
"darling_core 0.20.11",
|
||||
"darling_macro 0.20.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.21.3"
|
||||
|
|
@ -2724,6 +2782,20 @@ dependencies = [
|
|||
"darling_macro 0.23.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim 0.11.1",
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.21.3"
|
||||
|
|
@ -2751,6 +2823,17 @@ dependencies = [
|
|||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
|
||||
dependencies = [
|
||||
"darling_core 0.20.11",
|
||||
"quote",
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.21.3"
|
||||
|
|
@ -2955,6 +3038,12 @@ dependencies = [
|
|||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deunicode"
|
||||
version = "1.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "abd57806937c9cc163efc8ea3910e00a62e2aeb0b8119f1793a978088f8f6b04"
|
||||
|
||||
[[package]]
|
||||
name = "diff"
|
||||
version = "0.1.13"
|
||||
|
|
@ -7223,6 +7312,16 @@ dependencies = [
|
|||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-stemmers"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.27"
|
||||
|
|
@ -8456,6 +8555,15 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "stop-words"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "645a3d441ccf4bf47f2e4b7681461986681a6eeea9937d4c3bc9febd61d17c71"
|
||||
dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "streaming-iterator"
|
||||
version = "0.1.9"
|
||||
|
|
|
|||
|
|
@ -130,6 +130,7 @@ async-stream = "0.3.6"
|
|||
async-trait = "0.1.89"
|
||||
axum = { version = "0.8", default-features = false }
|
||||
base64 = "0.22.1"
|
||||
bm25 = "2.3.2"
|
||||
bytes = "1.10.1"
|
||||
chardetng = "0.1.17"
|
||||
chrono = "0.4.43"
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ arc-swap = "1.8.0"
|
|||
async-channel = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bm25 = { workspace = true }
|
||||
chardetng = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
|
|
|
|||
|
|
@ -245,6 +245,9 @@
|
|||
"runtime_metrics": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"search_tool": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"shell_snapshot": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
|
@ -1288,6 +1291,9 @@
|
|||
"runtime_metrics": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"search_tool": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"shell_snapshot": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -139,6 +139,8 @@ use crate::mcp::effective_mcp_servers;
|
|||
use crate::mcp::maybe_prompt_and_install_mcp_dependencies;
|
||||
use crate::mcp::with_codex_apps_mcp;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::filter_codex_apps_mcp_tools_only;
|
||||
use crate::mcp_connection_manager::filter_mcp_tools_by_name;
|
||||
use crate::mentions::build_connector_slug_counts;
|
||||
use crate::mentions::build_skill_name_counts;
|
||||
use crate::mentions::collect_explicit_app_paths;
|
||||
|
|
@ -504,6 +506,9 @@ pub(crate) struct Session {
|
|||
next_internal_sub_id: AtomicU64,
|
||||
}
|
||||
|
||||
const SEARCH_TOOL_DEVELOPER_INSTRUCTIONS: &str =
|
||||
include_str!("../templates/search_tool/developer_instructions.md");
|
||||
|
||||
/// The context needed for a single turn of the thread.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TurnContext {
|
||||
|
|
@ -1257,6 +1262,21 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn merge_mcp_tool_selection(&self, tool_names: Vec<String>) -> Vec<String> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.merge_mcp_tool_selection(tool_names)
|
||||
}
|
||||
|
||||
pub(crate) async fn get_mcp_tool_selection(&self) -> Option<Vec<String>> {
|
||||
let state = self.state.lock().await;
|
||||
state.get_mcp_tool_selection()
|
||||
}
|
||||
|
||||
pub(crate) async fn clear_mcp_tool_selection(&self) {
|
||||
let mut state = self.state.lock().await;
|
||||
state.clear_mcp_tool_selection();
|
||||
}
|
||||
|
||||
async fn record_initial_history(&self, conversation_history: InitialHistory) {
|
||||
let turn_context = self.new_default_turn().await;
|
||||
match conversation_history {
|
||||
|
|
@ -2182,6 +2202,11 @@ impl Session {
|
|||
if let Some(developer_instructions) = turn_context.developer_instructions.as_deref() {
|
||||
items.push(DeveloperInstructions::new(developer_instructions.to_string()).into());
|
||||
}
|
||||
if turn_context.tools_config.search_tool {
|
||||
items.push(
|
||||
DeveloperInstructions::new(SEARCH_TOOL_DEVELOPER_INSTRUCTIONS.to_string()).into(),
|
||||
);
|
||||
}
|
||||
// Add developer instructions from collaboration_mode if they exist and are non-empty
|
||||
let (collaboration_mode, base_instructions) = {
|
||||
let state = self.state.lock().await;
|
||||
|
|
@ -4119,6 +4144,7 @@ async fn run_sampling_request(
|
|||
.list_all_tools()
|
||||
.or_cancel(&cancellation_token)
|
||||
.await?;
|
||||
|
||||
let connectors_for_tools = if turn_context.config.features.enabled(Feature::Apps) {
|
||||
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
|
||||
Some(filter_connectors_for_input(
|
||||
|
|
@ -4130,9 +4156,25 @@ async fn run_sampling_request(
|
|||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(connectors) = connectors_for_tools.as_ref() {
|
||||
|
||||
if turn_context.config.features.enabled(Feature::SearchTool) {
|
||||
let mut selected_mcp_tools =
|
||||
if let Some(selected_tools) = sess.get_mcp_tool_selection().await {
|
||||
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tools)
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
if let Some(connectors) = connectors_for_tools.as_ref() {
|
||||
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, connectors);
|
||||
selected_mcp_tools.extend(apps_mcp_tools);
|
||||
}
|
||||
|
||||
mcp_tools = selected_mcp_tools;
|
||||
} else if let Some(connectors) = connectors_for_tools.as_ref() {
|
||||
mcp_tools = filter_codex_apps_mcp_tools(mcp_tools, connectors);
|
||||
}
|
||||
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(
|
||||
|
|
@ -4958,6 +5000,8 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
|
|||
pub(crate) use tests::make_session_and_context;
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context_with_rx;
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_configuration_for_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
|
@ -4967,6 +5011,7 @@ mod tests {
|
|||
use crate::config::test_config;
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::mcp_connection_manager::ToolInfo;
|
||||
use crate::shell::default_user_shell;
|
||||
use crate::tools::format_exec_output_str;
|
||||
|
||||
|
|
@ -5006,6 +5051,8 @@ mod tests {
|
|||
|
||||
use codex_protocol::mcp::CallToolResult as McpCallToolResult;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::JsonObject;
|
||||
use rmcp::model::Tool;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
|
|
@ -5042,6 +5089,30 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
fn make_mcp_tool(
|
||||
server_name: &str,
|
||||
tool_name: &str,
|
||||
connector_id: Option<&str>,
|
||||
connector_name: Option<&str>,
|
||||
) -> ToolInfo {
|
||||
ToolInfo {
|
||||
server_name: server_name.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
tool: Tool {
|
||||
name: tool_name.to_string().into(),
|
||||
title: None,
|
||||
description: Some(format!("Test tool: {tool_name}").into()),
|
||||
input_schema: Arc::new(JsonObject::default()),
|
||||
output_schema: None,
|
||||
annotations: None,
|
||||
icons: None,
|
||||
meta: None,
|
||||
},
|
||||
connector_id: connector_id.map(str::to_string),
|
||||
connector_name: connector_name.map(str::to_string),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_base_instructions_no_user_content() {
|
||||
let prompt_with_apply_patch_instructions =
|
||||
|
|
@ -5145,6 +5216,93 @@ mod tests {
|
|||
assert_eq!(selected, Vec::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn search_tool_selection_keeps_codex_apps_tools_without_mentions() {
|
||||
let selected_tool_names = vec![
|
||||
"mcp__codex_apps__calendar_create_event".to_string(),
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
];
|
||||
let mcp_tools = HashMap::from([
|
||||
(
|
||||
"mcp__codex_apps__calendar_create_event".to_string(),
|
||||
make_mcp_tool(
|
||||
CODEX_APPS_MCP_SERVER_NAME,
|
||||
"calendar_create_event",
|
||||
Some("calendar"),
|
||||
Some("Calendar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
make_mcp_tool("rmcp", "echo", None, None),
|
||||
),
|
||||
]);
|
||||
|
||||
let mut selected_mcp_tools =
|
||||
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
|
||||
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
|
||||
let connectors = filter_connectors_for_input(
|
||||
connectors,
|
||||
&[user_message("run the selected tools")],
|
||||
&[],
|
||||
&HashMap::new(),
|
||||
);
|
||||
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);
|
||||
selected_mcp_tools.extend(apps_mcp_tools);
|
||||
|
||||
let mut tool_names: Vec<String> = selected_mcp_tools.into_keys().collect();
|
||||
tool_names.sort();
|
||||
assert_eq!(
|
||||
tool_names,
|
||||
vec![
|
||||
"mcp__codex_apps__calendar_create_event".to_string(),
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apps_mentions_add_codex_apps_tools_to_search_selected_set() {
|
||||
let selected_tool_names = vec!["mcp__rmcp__echo".to_string()];
|
||||
let mcp_tools = HashMap::from([
|
||||
(
|
||||
"mcp__codex_apps__calendar_create_event".to_string(),
|
||||
make_mcp_tool(
|
||||
CODEX_APPS_MCP_SERVER_NAME,
|
||||
"calendar_create_event",
|
||||
Some("calendar"),
|
||||
Some("Calendar"),
|
||||
),
|
||||
),
|
||||
(
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
make_mcp_tool("rmcp", "echo", None, None),
|
||||
),
|
||||
]);
|
||||
|
||||
let mut selected_mcp_tools =
|
||||
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
|
||||
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
|
||||
let connectors = filter_connectors_for_input(
|
||||
connectors,
|
||||
&[user_message("use $calendar and then echo the response")],
|
||||
&[],
|
||||
&HashMap::new(),
|
||||
);
|
||||
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);
|
||||
selected_mcp_tools.extend(apps_mcp_tools);
|
||||
|
||||
let mut tool_names: Vec<String> = selected_mcp_tools.into_keys().collect();
|
||||
tool_names.sort();
|
||||
assert_eq!(
|
||||
tool_names,
|
||||
vec![
|
||||
"mcp__codex_apps__calendar_create_event".to_string(),
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reconstruct_history_matches_live_compactions() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
|
|
@ -5849,6 +6007,46 @@ mod tests {
|
|||
)
|
||||
}
|
||||
|
||||
pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration {
|
||||
let codex_home = tempfile::tempdir().expect("create temp dir");
|
||||
let config = build_test_config(codex_home.path()).await;
|
||||
let config = Arc::new(config);
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
|
||||
let reasoning_effort = config.model_reasoning_effort;
|
||||
let collaboration_mode = CollaborationMode {
|
||||
mode: ModeKind::Default,
|
||||
settings: Settings {
|
||||
model,
|
||||
reasoning_effort,
|
||||
developer_instructions: None,
|
||||
},
|
||||
};
|
||||
|
||||
SessionConfiguration {
|
||||
provider: config.model_provider.clone(),
|
||||
collaboration_mode,
|
||||
model_reasoning_summary: config.model_reasoning_summary,
|
||||
developer_instructions: config.developer_instructions.clone(),
|
||||
user_instructions: config.user_instructions.clone(),
|
||||
personality: config.personality,
|
||||
base_instructions: config
|
||||
.base_instructions
|
||||
.clone()
|
||||
.unwrap_or_else(|| model_info.get_model_instructions(config.personality)),
|
||||
compact_prompt: config.compact_prompt.clone(),
|
||||
approval_policy: config.approval_policy.clone(),
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
windows_sandbox_level: WindowsSandboxLevel::from_config(&config),
|
||||
cwd: config.cwd.clone(),
|
||||
codex_home: config.codex_home.clone(),
|
||||
thread_name: None,
|
||||
original_config_do_not_use: Arc::clone(&config),
|
||||
session_source: SessionSource::Exec,
|
||||
dynamic_tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
|
||||
let (tx_event, _rx_event) = async_channel::unbounded();
|
||||
let codex_home = tempfile::tempdir().expect("create temp dir");
|
||||
|
|
|
|||
|
|
@ -87,6 +87,8 @@ pub enum Feature {
|
|||
/// Allow the model to request web searches that fetch cached content.
|
||||
/// Takes precedence over `WebSearchRequest`.
|
||||
WebSearchCached,
|
||||
/// Allow the model to search MCP tools via BM25 before exposing them.
|
||||
SearchTool,
|
||||
/// Use the bubblewrap-based Linux sandbox pipeline.
|
||||
UseLinuxSandboxBwrap,
|
||||
/// Allow the model to request approval and propose exec rules.
|
||||
|
|
@ -432,6 +434,12 @@ pub const FEATURES: &[FeatureSpec] = &[
|
|||
stage: Stage::Deprecated,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::SearchTool,
|
||||
key: "search_tool",
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
// Experimental program. Rendered in the `/experimental` menu for users.
|
||||
FeatureSpec {
|
||||
id: Feature::RuntimeMetrics,
|
||||
|
|
|
|||
|
|
@ -843,6 +843,37 @@ fn filter_tools(tools: Vec<ToolInfo>, filter: ToolFilter) -> Vec<ToolInfo> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn filter_codex_apps_mcp_tools_only(
|
||||
mut mcp_tools: HashMap<String, ToolInfo>,
|
||||
connectors: &[crate::connectors::AppInfo],
|
||||
) -> HashMap<String, ToolInfo> {
|
||||
let allowed: HashSet<&str> = connectors
|
||||
.iter()
|
||||
.map(|connector| connector.id.as_str())
|
||||
.collect();
|
||||
|
||||
mcp_tools.retain(|_, tool| {
|
||||
if tool.server_name != CODEX_APPS_MCP_SERVER_NAME {
|
||||
return false;
|
||||
}
|
||||
let Some(connector_id) = tool.connector_id.as_deref() else {
|
||||
return false;
|
||||
};
|
||||
allowed.contains(connector_id)
|
||||
});
|
||||
|
||||
mcp_tools
|
||||
}
|
||||
|
||||
pub(crate) fn filter_mcp_tools_by_name(
|
||||
mut mcp_tools: HashMap<String, ToolInfo>,
|
||||
selected_tools: &[String],
|
||||
) -> HashMap<String, ToolInfo> {
|
||||
let allowed: HashSet<&str> = selected_tools.iter().map(String::as_str).collect();
|
||||
mcp_tools.retain(|name, _| allowed.contains(name.as_str()));
|
||||
mcp_tools
|
||||
}
|
||||
|
||||
fn normalize_codex_apps_tool_title(
|
||||
server_name: &str,
|
||||
connector_name: Option<&str>,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ pub(crate) struct SessionState {
|
|||
pub(crate) pending_resume_previous_model: Option<String>,
|
||||
/// Startup regular task pre-created during session initialization.
|
||||
pub(crate) startup_regular_task: Option<RegularTask>,
|
||||
pub(crate) active_mcp_tool_selection: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
|
|
@ -45,6 +46,7 @@ impl SessionState {
|
|||
initial_context_seeded: false,
|
||||
pending_resume_previous_model: None,
|
||||
startup_regular_task: None,
|
||||
active_mcp_tool_selection: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -140,6 +142,32 @@ impl SessionState {
|
|||
pub(crate) fn take_startup_regular_task(&mut self) -> Option<RegularTask> {
|
||||
self.startup_regular_task.take()
|
||||
}
|
||||
|
||||
pub(crate) fn merge_mcp_tool_selection(&mut self, tool_names: Vec<String>) -> Vec<String> {
|
||||
if tool_names.is_empty() {
|
||||
return self.active_mcp_tool_selection.clone().unwrap_or_default();
|
||||
}
|
||||
|
||||
let mut merged = self.active_mcp_tool_selection.take().unwrap_or_default();
|
||||
let mut seen: HashSet<String> = merged.iter().cloned().collect();
|
||||
|
||||
for tool_name in tool_names {
|
||||
if seen.insert(tool_name.clone()) {
|
||||
merged.push(tool_name);
|
||||
}
|
||||
}
|
||||
|
||||
self.active_mcp_tool_selection = Some(merged.clone());
|
||||
merged
|
||||
}
|
||||
|
||||
pub(crate) fn get_mcp_tool_selection(&self) -> Option<Vec<String>> {
|
||||
self.active_mcp_tool_selection.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn clear_mcp_tool_selection(&mut self) {
|
||||
self.active_mcp_tool_selection = None;
|
||||
}
|
||||
}
|
||||
|
||||
// Sometimes new snapshots don't include credits or plan information.
|
||||
|
|
@ -155,3 +183,79 @@ fn merge_rate_limit_fields(
|
|||
}
|
||||
snapshot
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codex::make_session_configuration_for_tests;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn merge_mcp_tool_selection_deduplicates_and_preserves_order() {
|
||||
let session_configuration = make_session_configuration_for_tests().await;
|
||||
let mut state = SessionState::new(session_configuration);
|
||||
|
||||
let merged = state.merge_mcp_tool_selection(vec![
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
"mcp__rmcp__image".to_string(),
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
]);
|
||||
assert_eq!(
|
||||
merged,
|
||||
vec![
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
"mcp__rmcp__image".to_string(),
|
||||
]
|
||||
);
|
||||
|
||||
let merged = state.merge_mcp_tool_selection(vec![
|
||||
"mcp__rmcp__image".to_string(),
|
||||
"mcp__rmcp__search".to_string(),
|
||||
]);
|
||||
assert_eq!(
|
||||
merged,
|
||||
vec![
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
"mcp__rmcp__image".to_string(),
|
||||
"mcp__rmcp__search".to_string(),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn merge_mcp_tool_selection_empty_input_is_noop() {
|
||||
let session_configuration = make_session_configuration_for_tests().await;
|
||||
let mut state = SessionState::new(session_configuration);
|
||||
state.merge_mcp_tool_selection(vec![
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
"mcp__rmcp__image".to_string(),
|
||||
]);
|
||||
|
||||
let merged = state.merge_mcp_tool_selection(Vec::new());
|
||||
assert_eq!(
|
||||
merged,
|
||||
vec![
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
"mcp__rmcp__image".to_string(),
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
state.get_mcp_tool_selection(),
|
||||
Some(vec![
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
"mcp__rmcp__image".to_string(),
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_mcp_tool_selection_removes_selection() {
|
||||
let session_configuration = make_session_configuration_for_tests().await;
|
||||
let mut state = SessionState::new(session_configuration);
|
||||
state.merge_mcp_tool_selection(vec!["mcp__rmcp__echo".to_string()]);
|
||||
|
||||
state.clear_mcp_tool_selection();
|
||||
|
||||
assert_eq!(state.get_mcp_tool_selection(), None);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ impl Session {
|
|||
task: T,
|
||||
) {
|
||||
self.abort_all_tasks(TurnAbortReason::Replaced).await;
|
||||
self.clear_mcp_tool_selection().await;
|
||||
self.seed_initial_context_if_needed(turn_context.as_ref())
|
||||
.await;
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ mod mcp_resource;
|
|||
mod plan;
|
||||
mod read_file;
|
||||
mod request_user_input;
|
||||
mod search_tool_bm25;
|
||||
mod shell;
|
||||
mod test_sync;
|
||||
mod unified_exec;
|
||||
|
|
@ -28,6 +29,8 @@ pub use plan::PlanHandler;
|
|||
pub use read_file::ReadFileHandler;
|
||||
pub use request_user_input::RequestUserInputHandler;
|
||||
pub(crate) use request_user_input::request_user_input_tool_description;
|
||||
pub(crate) use search_tool_bm25::DEFAULT_LIMIT as SEARCH_TOOL_BM25_DEFAULT_LIMIT;
|
||||
pub use search_tool_bm25::SearchToolBm25Handler;
|
||||
pub use shell::ShellCommandHandler;
|
||||
pub use shell::ShellHandler;
|
||||
pub use test_sync::TestSyncHandler;
|
||||
|
|
|
|||
217
codex-rs/core/src/tools/handlers/search_tool_bm25.rs
Normal file
217
codex-rs/core/src/tools/handlers/search_tool_bm25.rs
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
use async_trait::async_trait;
|
||||
use bm25::Document;
|
||||
use bm25::Language;
|
||||
use bm25::SearchEngineBuilder;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::mcp_connection_manager::ToolInfo;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::handlers::parse_arguments;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct SearchToolBm25Handler;
|
||||
|
||||
pub(crate) const DEFAULT_LIMIT: usize = 8;
|
||||
|
||||
fn default_limit() -> usize {
|
||||
DEFAULT_LIMIT
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SearchToolBm25Args {
|
||||
query: String,
|
||||
#[serde(default = "default_limit")]
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ToolEntry {
|
||||
name: String,
|
||||
server_name: String,
|
||||
title: Option<String>,
|
||||
description: Option<String>,
|
||||
connector_id: Option<String>,
|
||||
connector_name: Option<String>,
|
||||
input_keys: Vec<String>,
|
||||
search_text: String,
|
||||
}
|
||||
|
||||
impl ToolEntry {
|
||||
fn new(name: String, info: ToolInfo) -> Self {
|
||||
let input_keys = info
|
||||
.tool
|
||||
.input_schema
|
||||
.get("properties")
|
||||
.and_then(serde_json::Value::as_object)
|
||||
.map(|map| map.keys().cloned().collect::<Vec<_>>())
|
||||
.unwrap_or_default();
|
||||
let search_text = build_search_text(&name, &info, &input_keys);
|
||||
Self {
|
||||
name,
|
||||
server_name: info.server_name,
|
||||
title: info.tool.title,
|
||||
description: info
|
||||
.tool
|
||||
.description
|
||||
.map(|description| description.to_string()),
|
||||
connector_id: info.connector_id,
|
||||
connector_name: info.connector_name,
|
||||
input_keys,
|
||||
search_text,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for SearchToolBm25Handler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
payload, session, ..
|
||||
} = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::Fatal(
|
||||
"search_tool_bm25 handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: SearchToolBm25Args = parse_arguments(&arguments)?;
|
||||
let query = args.query.trim();
|
||||
if query.is_empty() {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"query must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if args.limit == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"limit must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let limit = args.limit;
|
||||
|
||||
let mcp_tools = session
|
||||
.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.list_all_tools()
|
||||
.await;
|
||||
|
||||
let mut entries: Vec<ToolEntry> = mcp_tools
|
||||
.into_iter()
|
||||
.map(|(name, info)| ToolEntry::new(name, info))
|
||||
.collect();
|
||||
entries.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
if entries.is_empty() {
|
||||
let active_selected_tools = session.get_mcp_tool_selection().await.unwrap_or_default();
|
||||
let content = json!({
|
||||
"query": query,
|
||||
"total_tools": 0,
|
||||
"active_selected_tools": active_selected_tools,
|
||||
"tools": [],
|
||||
})
|
||||
.to_string();
|
||||
return Ok(ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::Text(content),
|
||||
success: Some(true),
|
||||
});
|
||||
}
|
||||
|
||||
let documents: Vec<Document<usize>> = entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, entry)| Document::new(idx, entry.search_text.clone()))
|
||||
.collect();
|
||||
let search_engine =
|
||||
SearchEngineBuilder::<usize>::with_documents(Language::English, documents).build();
|
||||
let results = search_engine.search(query, limit);
|
||||
|
||||
let mut selected_tools = Vec::new();
|
||||
let mut result_payloads = Vec::new();
|
||||
for result in results {
|
||||
let Some(entry) = entries.get(result.document.id) else {
|
||||
continue;
|
||||
};
|
||||
selected_tools.push(entry.name.clone());
|
||||
result_payloads.push(json!({
|
||||
"name": entry.name.clone(),
|
||||
"server": entry.server_name.clone(),
|
||||
"title": entry.title.clone(),
|
||||
"description": entry.description.clone(),
|
||||
"connector_id": entry.connector_id.clone(),
|
||||
"connector_name": entry.connector_name.clone(),
|
||||
"input_keys": entry.input_keys.clone(),
|
||||
"score": result.score,
|
||||
}));
|
||||
}
|
||||
|
||||
let active_selected_tools = session.merge_mcp_tool_selection(selected_tools).await;
|
||||
|
||||
let content = json!({
|
||||
"query": query,
|
||||
"total_tools": entries.len(),
|
||||
"active_selected_tools": active_selected_tools,
|
||||
"tools": result_payloads,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::Text(content),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn build_search_text(name: &str, info: &ToolInfo, input_keys: &[String]) -> String {
|
||||
let mut parts = vec![
|
||||
name.to_string(),
|
||||
info.tool_name.clone(),
|
||||
info.server_name.clone(),
|
||||
];
|
||||
|
||||
if let Some(title) = info.tool.title.as_deref()
|
||||
&& !title.trim().is_empty()
|
||||
{
|
||||
parts.push(title.to_string());
|
||||
}
|
||||
|
||||
if let Some(description) = info.tool.description.as_deref()
|
||||
&& !description.trim().is_empty()
|
||||
{
|
||||
parts.push(description.to_string());
|
||||
}
|
||||
|
||||
if let Some(connector_name) = info.connector_name.as_deref()
|
||||
&& !connector_name.trim().is_empty()
|
||||
{
|
||||
parts.push(connector_name.to_string());
|
||||
}
|
||||
|
||||
if let Some(connector_id) = info.connector_id.as_deref()
|
||||
&& !connector_id.trim().is_empty()
|
||||
{
|
||||
parts.push(connector_id.to_string());
|
||||
}
|
||||
|
||||
if !input_keys.is_empty() {
|
||||
parts.extend(input_keys.iter().cloned());
|
||||
}
|
||||
|
||||
parts.join(" ")
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ use crate::client_common::tools::ToolSpec;
|
|||
use crate::features::Feature;
|
||||
use crate::features::Features;
|
||||
use crate::tools::handlers::PLAN_TOOL;
|
||||
use crate::tools::handlers::SEARCH_TOOL_BM25_DEFAULT_LIMIT;
|
||||
use crate::tools::handlers::apply_patch::create_apply_patch_freeform_tool;
|
||||
use crate::tools::handlers::apply_patch::create_apply_patch_json_tool;
|
||||
use crate::tools::handlers::collab::DEFAULT_WAIT_TIMEOUT_MS;
|
||||
|
|
@ -31,6 +32,7 @@ pub(crate) struct ToolsConfig {
|
|||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||
pub web_search_mode: Option<WebSearchMode>,
|
||||
pub supports_image_input: bool,
|
||||
pub search_tool: bool,
|
||||
pub collab_tools: bool,
|
||||
pub collaboration_modes_tools: bool,
|
||||
pub request_rule_enabled: bool,
|
||||
|
|
@ -54,6 +56,7 @@ impl ToolsConfig {
|
|||
let include_collab_tools = features.enabled(Feature::Collab);
|
||||
let include_collaboration_modes_tools = features.enabled(Feature::CollaborationModes);
|
||||
let request_rule_enabled = features.enabled(Feature::RequestRule);
|
||||
let include_search_tool = features.enabled(Feature::SearchTool);
|
||||
|
||||
let shell_type = if !features.enabled(Feature::ShellTool) {
|
||||
ConfigShellToolType::Disabled
|
||||
|
|
@ -85,6 +88,7 @@ impl ToolsConfig {
|
|||
apply_patch_tool_type,
|
||||
web_search_mode: *web_search_mode,
|
||||
supports_image_input: model_info.input_modalities.contains(&InputModality::Image),
|
||||
search_tool: include_search_tool,
|
||||
collab_tools: include_collab_tools,
|
||||
collaboration_modes_tools: include_collaboration_modes_tools,
|
||||
request_rule_enabled,
|
||||
|
|
@ -800,6 +804,36 @@ fn create_grep_files_tool() -> ToolSpec {
|
|||
})
|
||||
}
|
||||
|
||||
fn create_search_tool_bm25_tool() -> ToolSpec {
|
||||
let properties = BTreeMap::from([
|
||||
(
|
||||
"query".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Search query for MCP tools.".to_string()),
|
||||
},
|
||||
),
|
||||
(
|
||||
"limit".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(format!(
|
||||
"Maximum number of tools to return (defaults to {SEARCH_TOOL_BM25_DEFAULT_LIMIT})."
|
||||
)),
|
||||
},
|
||||
),
|
||||
]);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "search_tool_bm25".to_string(),
|
||||
description: "Searches MCP tool metadata with BM25 and exposes matching tools for the next model call.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["query".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_read_file_tool() -> ToolSpec {
|
||||
let indentation_properties = BTreeMap::from([
|
||||
(
|
||||
|
|
@ -1261,6 +1295,7 @@ pub(crate) fn build_specs(
|
|||
use crate::tools::handlers::PlanHandler;
|
||||
use crate::tools::handlers::ReadFileHandler;
|
||||
use crate::tools::handlers::RequestUserInputHandler;
|
||||
use crate::tools::handlers::SearchToolBm25Handler;
|
||||
use crate::tools::handlers::ShellCommandHandler;
|
||||
use crate::tools::handlers::ShellHandler;
|
||||
use crate::tools::handlers::TestSyncHandler;
|
||||
|
|
@ -1280,6 +1315,7 @@ pub(crate) fn build_specs(
|
|||
let mcp_resource_handler = Arc::new(McpResourceHandler);
|
||||
let shell_command_handler = Arc::new(ShellCommandHandler);
|
||||
let request_user_input_handler = Arc::new(RequestUserInputHandler);
|
||||
let search_tool_handler = Arc::new(SearchToolBm25Handler);
|
||||
|
||||
match &config.shell_type {
|
||||
ConfigShellToolType::Default => {
|
||||
|
|
@ -1334,6 +1370,11 @@ pub(crate) fn build_specs(
|
|||
builder.register_handler("request_user_input", request_user_input_handler);
|
||||
}
|
||||
|
||||
if config.search_tool {
|
||||
builder.push_spec_with_parallel_support(create_search_tool_bm25_tool(), true);
|
||||
builder.register_handler("search_tool_bm25", search_tool_handler);
|
||||
}
|
||||
|
||||
if let Some(apply_patch_tool_type) = &config.apply_patch_tool_type {
|
||||
match apply_patch_tool_type {
|
||||
ApplyPatchToolType::Freeform => {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
# MCP tool discovery
|
||||
|
||||
When `search_tool_bm25` is available, MCP tools (`mcp__...`) are hidden until you search for them.
|
||||
|
||||
Follow this workflow:
|
||||
|
||||
1. Call `search_tool_bm25` with:
|
||||
- `query` (required): focused terms that describe the capability you need.
|
||||
- `limit` (optional): maximum number of tools to return (default `8`).
|
||||
2. Use the returned `tools` list to decide which MCP tools are relevant.
|
||||
3. Matching tools are added to `active_selected_tools`. Only tools in `active_selected_tools` are available for the remainder of the current turn.
|
||||
4. Repeated searches in the same turn are additive: new matches are unioned into `active_selected_tools`.
|
||||
5. `active_selected_tools` resets at the start of the next turn.
|
||||
|
||||
Notes:
|
||||
- Core tools remain available without searching.
|
||||
- If you are unsure, start with `limit` between 5 and 10 to see a broader set of tools.
|
||||
- `query` is matched against MCP tool metadata fields:
|
||||
- `name`
|
||||
- `tool_name`
|
||||
- `server_name`
|
||||
- `title`
|
||||
- `description`
|
||||
- `connector_name`
|
||||
- `connector_id`
|
||||
- input schema property keys (`input_keys`)
|
||||
- When the user asks to search/lookup/query any external system (logs, tickets, metrics, Slack, etc.), you must call `search_tool_bm25` first before running any shell command or repo search.
|
||||
- Only use shell commands if (a) MCP tools for that system are not available or not sufficient, and (b) the user explicitly wants a local file/CLI search.
|
||||
- If unsure which system/tool applies, ask a clarifying question after checking MCP tools.
|
||||
|
|
@ -104,6 +104,7 @@ mod resume_warning;
|
|||
mod review;
|
||||
mod rmcp_client;
|
||||
mod rollout_list_find;
|
||||
mod search_tool;
|
||||
mod seatbelt;
|
||||
mod shell_command;
|
||||
mod shell_serialization;
|
||||
|
|
|
|||
488
codex-rs/core/tests/suite/search_tool.rs
Normal file
488
codex-rs/core/tests/suite/search_tool.rs
Normal file
|
|
@ -0,0 +1,488 @@
|
|||
#![cfg(not(target_os = "windows"))]
|
||||
#![allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::config::types::McpServerConfig;
|
||||
use codex_core::config::types::McpServerTransportConfig;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use core_test_support::responses::ResponsesRequest;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::stdio_server_bin;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
|
||||
const SEARCH_TOOL_INSTRUCTION_SNIPPETS: [&str; 2] = [
|
||||
"MCP tools (`mcp__...`) are hidden until you search for them.",
|
||||
"Matching tools are added to `active_selected_tools`.",
|
||||
];
|
||||
|
||||
fn tool_names(body: &Value) -> Vec<String> {
|
||||
body.get("tools")
|
||||
.and_then(Value::as_array)
|
||||
.map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.filter_map(|tool| {
|
||||
tool.get("name")
|
||||
.or_else(|| tool.get("type"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn developer_messages(body: &Value) -> Vec<String> {
|
||||
body.get("input")
|
||||
.and_then(Value::as_array)
|
||||
.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
if item.get("role").and_then(Value::as_str) != Some("developer") {
|
||||
return None;
|
||||
}
|
||||
let content = item.get("content").and_then(Value::as_array)?;
|
||||
let texts: Vec<&str> = content
|
||||
.iter()
|
||||
.filter_map(|entry| entry.get("text").and_then(Value::as_str))
|
||||
.collect();
|
||||
if texts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(texts.join("\n"))
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn search_tool_output_payload(request: &ResponsesRequest, call_id: &str) -> Value {
|
||||
let (content, _success) = request
|
||||
.function_call_output_content_and_success(call_id)
|
||||
.expect("search_tool_bm25 function_call_output should be present");
|
||||
let content = content.expect("search_tool_bm25 output should include content");
|
||||
serde_json::from_str(&content).expect("search_tool_bm25 content should be valid JSON")
|
||||
}
|
||||
|
||||
fn active_selected_tools(payload: &Value) -> Vec<String> {
|
||||
payload
|
||||
.get("active_selected_tools")
|
||||
.and_then(Value::as_array)
|
||||
.expect("active_selected_tools should be an array")
|
||||
.iter()
|
||||
.map(|value| {
|
||||
value
|
||||
.as_str()
|
||||
.expect("active_selected_tools entries should be strings")
|
||||
.to_string()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn search_tool_flag_adds_tool() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mock = mount_sse_sequence(
|
||||
&server,
|
||||
vec![sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-1"),
|
||||
])],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::SearchTool);
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn_with_policies(
|
||||
"list tools",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let body = mock.single_request().body_json();
|
||||
let tools = tool_names(&body);
|
||||
assert!(
|
||||
tools.iter().any(|name| name == "search_tool_bm25"),
|
||||
"tools list should include search_tool_bm25 when enabled: {tools:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn search_tool_adds_developer_instructions() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mock = mount_sse_sequence(
|
||||
&server,
|
||||
vec![sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-1"),
|
||||
])],
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::SearchTool);
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn_with_policies(
|
||||
"list tools",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let body = mock.single_request().body_json();
|
||||
let developer_texts = developer_messages(&body);
|
||||
assert!(
|
||||
developer_texts.iter().any(|text| {
|
||||
SEARCH_TOOL_INSTRUCTION_SNIPPETS
|
||||
.iter()
|
||||
.all(|snippet| text.contains(snippet))
|
||||
}),
|
||||
"developer instructions should include search tool workflow: {developer_texts:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn search_tool_hides_mcp_tools_without_search() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mock = mount_sse_sequence(
|
||||
&server,
|
||||
vec![sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-1"),
|
||||
])],
|
||||
)
|
||||
.await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.features.enable(Feature::SearchTool);
|
||||
let mut servers = config.mcp_servers.get().clone();
|
||||
servers.insert(
|
||||
"rmcp".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: rmcp_test_server_bin,
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
},
|
||||
);
|
||||
config
|
||||
.mcp_servers
|
||||
.set(servers)
|
||||
.expect("test mcp servers should accept any configuration");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn_with_policies(
|
||||
"hello tools",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let body = mock.single_request().body_json();
|
||||
let tools = tool_names(&body);
|
||||
assert!(
|
||||
tools.iter().any(|name| name == "search_tool_bm25"),
|
||||
"tools list should include search_tool_bm25 when enabled: {tools:?}"
|
||||
);
|
||||
assert!(
|
||||
!tools.iter().any(|name| name == "mcp__rmcp__echo"),
|
||||
"tools list should not include MCP tools before search: {tools:?}"
|
||||
);
|
||||
assert!(
|
||||
!tools.iter().any(|name| name == "mcp__rmcp__image"),
|
||||
"tools list should not include MCP tools before search: {tools:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn search_tool_selection_persists_within_turn_and_resets_next_turn() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let call_id = "tool-search";
|
||||
let args = json!({
|
||||
"query": "echo",
|
||||
"limit": 1,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "search_tool_bm25", &serde_json::to_string(&args)?),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-2", "done again"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
];
|
||||
let mock = mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.features.enable(Feature::SearchTool);
|
||||
let mut servers = config.mcp_servers.get().clone();
|
||||
servers.insert(
|
||||
"rmcp".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: rmcp_test_server_bin,
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
},
|
||||
);
|
||||
config
|
||||
.mcp_servers
|
||||
.set(servers)
|
||||
.expect("test mcp servers should accept any configuration");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn_with_policies(
|
||||
"find the echo tool",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
test.submit_turn_with_policies(
|
||||
"hello again",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = mock.requests();
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
3,
|
||||
"expected 3 requests, got {}",
|
||||
requests.len()
|
||||
);
|
||||
|
||||
let first_tools = tool_names(&requests[0].body_json());
|
||||
assert!(
|
||||
!first_tools.iter().any(|name| name == "mcp__rmcp__echo"),
|
||||
"first request should not include MCP tools before search: {first_tools:?}"
|
||||
);
|
||||
|
||||
let second_tools = tool_names(&requests[1].body_json());
|
||||
assert!(
|
||||
second_tools.iter().any(|name| name == "mcp__rmcp__echo"),
|
||||
"second request should include selected MCP tool: {second_tools:?}"
|
||||
);
|
||||
assert!(
|
||||
!second_tools.iter().any(|name| name == "mcp__rmcp__image"),
|
||||
"second request should only include selected MCP tool: {second_tools:?}"
|
||||
);
|
||||
|
||||
let search_output_payload = search_tool_output_payload(&requests[1], call_id);
|
||||
assert!(
|
||||
search_output_payload.get("selected_tools").is_none(),
|
||||
"selected_tools should not be returned: {search_output_payload:?}"
|
||||
);
|
||||
assert_eq!(
|
||||
active_selected_tools(&search_output_payload),
|
||||
vec!["mcp__rmcp__echo".to_string()],
|
||||
);
|
||||
|
||||
let third_tools = tool_names(&requests[2].body_json());
|
||||
assert!(
|
||||
!third_tools.iter().any(|name| name == "mcp__rmcp__echo"),
|
||||
"third request should not include MCP tools after turn reset: {third_tools:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn search_tool_selection_unions_results_within_turn() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let first_call_id = "tool-search-echo";
|
||||
let second_call_id = "tool-search-image";
|
||||
let first_args = json!({
|
||||
"query": "echo",
|
||||
"limit": 1,
|
||||
});
|
||||
let second_args = json!({
|
||||
"query": "image",
|
||||
"limit": 1,
|
||||
});
|
||||
let responses = vec![
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(
|
||||
first_call_id,
|
||||
"search_tool_bm25",
|
||||
&serde_json::to_string(&first_args)?,
|
||||
),
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_response_created("resp-2"),
|
||||
ev_function_call(
|
||||
second_call_id,
|
||||
"search_tool_bm25",
|
||||
&serde_json::to_string(&second_args)?,
|
||||
),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-3"),
|
||||
]),
|
||||
];
|
||||
let mock = mount_sse_sequence(&server, responses).await;
|
||||
|
||||
let rmcp_test_server_bin = stdio_server_bin()?;
|
||||
let mut builder = test_codex().with_config(move |config| {
|
||||
config.features.enable(Feature::SearchTool);
|
||||
let mut servers = config.mcp_servers.get().clone();
|
||||
servers.insert(
|
||||
"rmcp".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: rmcp_test_server_bin,
|
||||
args: Vec::new(),
|
||||
env: None,
|
||||
env_vars: Vec::new(),
|
||||
cwd: None,
|
||||
},
|
||||
enabled: true,
|
||||
required: false,
|
||||
disabled_reason: None,
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
enabled_tools: None,
|
||||
disabled_tools: None,
|
||||
scopes: None,
|
||||
},
|
||||
);
|
||||
config
|
||||
.mcp_servers
|
||||
.set(servers)
|
||||
.expect("test mcp servers should accept any configuration");
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn_with_policies(
|
||||
"find echo and image tools",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let requests = mock.requests();
|
||||
assert_eq!(
|
||||
requests.len(),
|
||||
3,
|
||||
"expected 3 requests, got {}",
|
||||
requests.len()
|
||||
);
|
||||
|
||||
let first_tools = tool_names(&requests[0].body_json());
|
||||
assert!(
|
||||
!first_tools.iter().any(|name| name == "mcp__rmcp__echo"),
|
||||
"first request should not include MCP tools before search: {first_tools:?}"
|
||||
);
|
||||
|
||||
let second_tools = tool_names(&requests[1].body_json());
|
||||
assert!(
|
||||
second_tools.iter().any(|name| name == "mcp__rmcp__echo"),
|
||||
"second request should include echo after first search: {second_tools:?}"
|
||||
);
|
||||
assert!(
|
||||
!second_tools.iter().any(|name| name == "mcp__rmcp__image"),
|
||||
"second request should not include image before second search runs: {second_tools:?}"
|
||||
);
|
||||
|
||||
let third_tools = tool_names(&requests[2].body_json());
|
||||
assert!(
|
||||
third_tools.iter().any(|name| name == "mcp__rmcp__echo"),
|
||||
"third request should still include echo: {third_tools:?}"
|
||||
);
|
||||
assert!(
|
||||
third_tools.iter().any(|name| name == "mcp__rmcp__image"),
|
||||
"third request should include image after second search: {third_tools:?}"
|
||||
);
|
||||
|
||||
let second_search_payload = search_tool_output_payload(&requests[2], second_call_id);
|
||||
assert!(
|
||||
second_search_payload.get("selected_tools").is_none(),
|
||||
"selected_tools should not be returned: {second_search_payload:?}"
|
||||
);
|
||||
assert_eq!(
|
||||
active_selected_tools(&second_search_payload),
|
||||
vec![
|
||||
"mcp__rmcp__echo".to_string(),
|
||||
"mcp__rmcp__image".to_string(),
|
||||
],
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue