diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index f767c05f4..a7f5b72ea 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -4557,7 +4557,7 @@ async fn fatal_tool_error_stops_turn_and_reports_error() { .expect("tool call present"); let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let err = router - .dispatch_tool_call( + .dispatch_tool_call_with_code_mode_result( Arc::clone(&session), Arc::clone(&turn_context), tracker, @@ -4565,7 +4565,8 @@ async fn fatal_tool_error_stops_turn_and_reports_error() { ToolCallSource::Direct, ) .await - .expect_err("expected fatal error"); + .err() + .expect("expected fatal error"); match err { FunctionCallError::Fatal(message) => { diff --git a/codex-rs/core/src/tools/code_mode/mod.rs b/codex-rs/core/src/tools/code_mode/mod.rs index a7a7c40ea..c8e1e0c16 100644 --- a/codex-rs/core/src/tools/code_mode/mod.rs +++ b/codex-rs/core/src/tools/code_mode/mod.rs @@ -14,6 +14,7 @@ use serde_json::Value as JsonValue; use crate::client_common::tools::ToolSpec; use crate::codex::Session; use crate::codex::TurnContext; +use crate::function_tool::FunctionCallError; use crate::tools::ToolRouter; use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; use crate::tools::code_mode_description::code_mode_tool_reference; @@ -303,9 +304,11 @@ async fn call_nested_tool( tool_name: String, input: Option, cancellation_token: tokio_util::sync::CancellationToken, -) -> JsonValue { +) -> Result { if tool_name == PUBLIC_TOOL_NAME { - return JsonValue::String(format!("{PUBLIC_TOOL_NAME} cannot invoke itself")); + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} cannot invoke itself" + ))); } let payload = @@ -316,12 +319,12 @@ async fn call_nested_tool( tool, raw_arguments, }, - Err(error) => return JsonValue::String(error), + Err(error) => return Err(FunctionCallError::RespondToModel(error)), } } else { match build_nested_tool_payload(tool_runtime.find_spec(&tool_name), &tool_name, input) { Ok(payload) => payload, - Err(error) => return JsonValue::String(error), + Err(error) => return Err(FunctionCallError::RespondToModel(error)), } }; @@ -333,12 +336,8 @@ async fn call_nested_tool( }; let result = tool_runtime .handle_tool_call_with_source(call, ToolCallSource::CodeMode, cancellation_token) - .await; - - match result { - Ok(result) => result.code_mode_result(), - Err(error) => JsonValue::String(error.to_string()), - } + .await?; + Ok(result.code_mode_result()) } fn tool_kind_for_spec(spec: &ToolSpec) -> protocol::CodeModeToolKind { diff --git a/codex-rs/core/src/tools/code_mode/protocol.rs b/codex-rs/core/src/tools/code_mode/protocol.rs index 8116d95b4..2e72e1229 100644 --- a/codex-rs/core/src/tools/code_mode/protocol.rs +++ b/codex-rs/core/src/tools/code_mode/protocol.rs @@ -70,6 +70,8 @@ pub(super) enum HostToNodeMessage { request_id: String, id: String, code_mode_result: JsonValue, + #[serde(default)] + error_text: Option, }, } diff --git a/codex-rs/core/src/tools/code_mode/runner.cjs b/codex-rs/core/src/tools/code_mode/runner.cjs index 2fcfddeaf..408725555 100644 --- a/codex-rs/core/src/tools/code_mode/runner.cjs +++ b/codex-rs/core/src/tools/code_mode/runner.cjs @@ -595,6 +595,10 @@ function createProtocol() { return; } pending.delete(message.request_id + ':' + message.id); + if (typeof message.error_text === 'string') { + entry.reject(new Error(message.error_text)); + return; + } entry.resolve(message.code_mode_result ?? ''); return; } diff --git a/codex-rs/core/src/tools/code_mode/worker.rs b/codex-rs/core/src/tools/code_mode/worker.rs index 7456f9c6f..5853f3abe 100644 --- a/codex-rs/core/src/tools/code_mode/worker.rs +++ b/codex-rs/core/src/tools/code_mode/worker.rs @@ -14,6 +14,7 @@ use super::process::write_message; use super::protocol::HostToNodeMessage; use super::protocol::NodeToHostMessage; use crate::tools::parallel::ToolCallRuntime; + pub(crate) struct CodeModeWorker { shutdown_tx: Option>, } @@ -53,17 +54,23 @@ impl CodeModeProcess { let tool_runtime = tool_runtime.clone(); let stdin = stdin.clone(); tokio::spawn(async move { + let result = call_nested_tool( + exec, + tool_runtime, + tool_call.name, + tool_call.input, + CancellationToken::new(), + ) + .await; + let (code_mode_result, error_text) = match result { + Ok(code_mode_result) => (code_mode_result, None), + Err(error) => (serde_json::Value::Null, Some(error.to_string())), + }; let response = HostToNodeMessage::Response { request_id: tool_call.request_id, id: tool_call.id, - code_mode_result: call_nested_tool( - exec, - tool_runtime, - tool_call.name, - tool_call.input, - CancellationToken::new(), - ) - .await, + code_mode_result, + error_text, }; if let Err(err) = write_message(&stdin, &response).await { warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}"); diff --git a/codex-rs/core/src/tools/js_repl/mod.rs b/codex-rs/core/src/tools/js_repl/mod.rs index 392f311ce..fcdc0f8ec 100644 --- a/codex-rs/core/src/tools/js_repl/mod.rs +++ b/codex-rs/core/src/tools/js_repl/mod.rs @@ -1607,8 +1607,8 @@ impl JsReplManager { let tracker = Arc::clone(&exec.tracker); match router - .dispatch_tool_call( - session.clone(), + .dispatch_tool_call_with_code_mode_result( + session, turn, tracker, call, @@ -1616,7 +1616,8 @@ impl JsReplManager { ) .await { - Ok(response) => { + Ok(result) => { + let response = result.into_response(); let summary = Self::summarize_tool_call_response(&response); match serde_json::to_value(response) { Ok(value) => { diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index be7a28ed7..0cc0989fb 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -16,6 +16,7 @@ use crate::error::CodexErr; use crate::function_tool::FunctionCallError; use crate::tools::context::AbortedToolOutput; use crate::tools::context::SharedTurnDiffTracker; +use crate::tools::context::ToolPayload; use crate::tools::registry::AnyToolResult; use crate::tools::router::ToolCall; use crate::tools::router::ToolCallSource; @@ -57,9 +58,17 @@ impl ToolCallRuntime { call: ToolCall, cancellation_token: CancellationToken, ) -> impl std::future::Future> { + let error_call = call.clone(); let future = self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token); - async move { future.await.map(AnyToolResult::into_response) }.in_current_span() + async move { + match future.await { + Ok(response) => Ok(response.into_response()), + Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), + Err(other) => Ok(Self::failure_response(error_call, other)), + } + } + .in_current_span() } #[instrument(level = "trace", skip_all)] @@ -68,7 +77,7 @@ impl ToolCallRuntime { call: ToolCall, source: ToolCallSource, cancellation_token: CancellationToken, - ) -> impl std::future::Future> { + ) -> impl std::future::Future> { let supports_parallel = self.router.tool_supports_parallel(&call.tool_name); let router = Arc::clone(&self.router); let session = Arc::clone(&self.session); @@ -78,7 +87,7 @@ impl ToolCallRuntime { let started = Instant::now(); let dispatch_span = trace_span!( - "dispatch_tool_call", + "dispatch_tool_call_with_code_mode_result", otel.name = call.tool_name.as_str(), tool_name = call.tool_name.as_str(), call_id = call.call_id.as_str(), @@ -115,20 +124,42 @@ impl ToolCallRuntime { })); async move { - match handle.await { - Ok(Ok(response)) => Ok(response), - Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)), - Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())), - Err(err) => Err(CodexErr::Fatal(format!( - "tool task failed to receive: {err:?}" - ))), - } + handle.await.map_err(|err| { + FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}")) + })? } .in_current_span() } } impl ToolCallRuntime { + fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem { + let message = err.to_string(); + match call.payload { + ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput { + call_id: call.call_id, + status: "completed".to_string(), + execution: "client".to_string(), + tools: Vec::new(), + }, + ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput { + call_id: call.call_id, + name: None, + output: codex_protocol::models::FunctionCallOutputPayload { + body: codex_protocol::models::FunctionCallOutputBody::Text(message), + success: Some(false), + }, + }, + _ => ResponseInputItem::FunctionCallOutput { + call_id: call.call_id, + output: codex_protocol::models::FunctionCallOutputPayload { + body: codex_protocol::models::FunctionCallOutputBody::Text(message), + success: Some(false), + }, + }, + } + } + fn aborted_response(call: &ToolCall, secs: f32) -> AnyToolResult { AnyToolResult { call_id: call.call_id.clone(), diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index b41c59ef9..8544eb404 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -5,11 +5,9 @@ use crate::function_tool::FunctionCallError; use crate::mcp_connection_manager::ToolInfo; use crate::sandboxing::SandboxPermissions; use crate::tools::code_mode::is_code_mode_nested_tool; -use crate::tools::context::FunctionToolOutput; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; -use crate::tools::context::ToolSearchOutput; use crate::tools::discoverable::DiscoverableTool; use crate::tools::registry::AnyToolResult; use crate::tools::registry::ConfiguredToolSpec; @@ -18,7 +16,6 @@ use crate::tools::spec::ToolsConfig; use crate::tools::spec::build_specs_with_discoverable_tools; use codex_protocol::dynamic_tools::DynamicToolSpec; use codex_protocol::models::LocalShellAction; -use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; use codex_protocol::models::SearchToolCallParams; use codex_protocol::models::ShellToolCallParams; @@ -214,21 +211,6 @@ impl ToolRouter { } } - #[instrument(level = "trace", skip_all, err)] - pub async fn dispatch_tool_call( - &self, - session: Arc, - turn: Arc, - tracker: SharedTurnDiffTracker, - call: ToolCall, - source: ToolCallSource, - ) -> Result { - Ok(self - .dispatch_tool_call_with_code_mode_result(session, turn, tracker, call, source) - .await? - .into_response()) - } - #[instrument(level = "trace", skip_all, err)] pub async fn dispatch_tool_call_with_code_mode_result( &self, @@ -244,23 +226,14 @@ impl ToolRouter { call_id, payload, } = call; - let payload_outputs_custom = matches!(payload, ToolPayload::Custom { .. }); - let payload_outputs_tool_search = matches!(payload, ToolPayload::ToolSearch { .. }); - let failure_call_id = call_id.clone(); if source == ToolCallSource::Direct && turn.tools_config.js_repl_tools_only && !matches!(tool_name.as_str(), "js_repl" | "js_repl_reset") { - let err = FunctionCallError::RespondToModel( + return Err(FunctionCallError::RespondToModel( "direct tool calls are disabled; use js_repl and codex.tool(...) instead" .to_string(), - ); - return Ok(Self::failure_result( - failure_call_id, - payload_outputs_custom, - payload_outputs_tool_search, - err, )); } @@ -274,53 +247,7 @@ impl ToolRouter { payload, }; - match self.registry.dispatch_any(invocation).await { - Ok(response) => Ok(response), - Err(FunctionCallError::Fatal(message)) => Err(FunctionCallError::Fatal(message)), - Err(err) => Ok(Self::failure_result( - failure_call_id, - payload_outputs_custom, - payload_outputs_tool_search, - err, - )), - } - } - - fn failure_result( - call_id: String, - payload_outputs_custom: bool, - payload_outputs_tool_search: bool, - err: FunctionCallError, - ) -> AnyToolResult { - let message = err.to_string(); - if payload_outputs_tool_search { - AnyToolResult { - call_id, - payload: ToolPayload::ToolSearch { - arguments: SearchToolCallParams { - query: String::new(), - limit: None, - }, - }, - result: Box::new(ToolSearchOutput { tools: Vec::new() }), - } - } else if payload_outputs_custom { - AnyToolResult { - call_id, - payload: ToolPayload::Custom { - input: String::new(), - }, - result: Box::new(FunctionToolOutput::from_text(message, Some(false))), - } - } else { - AnyToolResult { - call_id, - payload: ToolPayload::Function { - arguments: "{}".to_string(), - }, - result: Box::new(FunctionToolOutput::from_text(message, Some(false))), - } - } + self.registry.dispatch_any(invocation).await } } #[cfg(test)] diff --git a/codex-rs/core/src/tools/router_tests.rs b/codex-rs/core/src/tools/router_tests.rs index 6350323d1..641adb56d 100644 --- a/codex-rs/core/src/tools/router_tests.rs +++ b/codex-rs/core/src/tools/router_tests.rs @@ -1,9 +1,9 @@ use std::sync::Arc; use crate::codex::make_session_and_context; +use crate::function_tool::FunctionCallError; use crate::tools::context::ToolPayload; use crate::turn_diff_tracker::TurnDiffTracker; -use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; use super::ToolCall; @@ -50,20 +50,21 @@ async fn js_repl_tools_only_blocks_direct_tool_calls() -> anyhow::Result<()> { }, }; let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - let response = router - .dispatch_tool_call(session, turn, tracker, call, ToolCallSource::Direct) - .await?; - - match response { - ResponseInputItem::FunctionCallOutput { output, .. } => { - let content = output.text_content().unwrap_or_default(); - assert!( - content.contains("direct tool calls are disabled"), - "unexpected tool call message: {content}", - ); - } - other => panic!("expected function call output, got {other:?}"), - } + let err = router + .dispatch_tool_call_with_code_mode_result( + session, + turn, + tracker, + call, + ToolCallSource::Direct, + ) + .await + .err() + .expect("direct tool calls should be blocked"); + let FunctionCallError::RespondToModel(message) = err else { + panic!("expected RespondToModel, got {err:?}"); + }; + assert!(message.contains("direct tool calls are disabled")); Ok(()) } @@ -107,20 +108,22 @@ async fn js_repl_tools_only_allows_js_repl_source_calls() -> anyhow::Result<()> }, }; let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - let response = router - .dispatch_tool_call(session, turn, tracker, call, ToolCallSource::JsRepl) - .await?; - - match response { - ResponseInputItem::FunctionCallOutput { output, .. } => { - let content = output.text_content().unwrap_or_default(); - assert!( - !content.contains("direct tool calls are disabled"), - "js_repl source should bypass direct-call policy gate" - ); - } - other => panic!("expected function call output, got {other:?}"), - } + let err = router + .dispatch_tool_call_with_code_mode_result( + session, + turn, + tracker, + call, + ToolCallSource::JsRepl, + ) + .await + .err() + .expect("shell call with empty args should fail"); + let message = err.to_string(); + assert!( + !message.contains("direct tool calls are disabled"), + "js_repl source should bypass direct-call policy gate" + ); Ok(()) } diff --git a/codex-rs/core/tests/suite/code_mode.rs b/codex-rs/core/tests/suite/code_mode.rs index 53e3d9e8c..2a7652691 100644 --- a/codex-rs/core/tests/suite/code_mode.rs +++ b/codex-rs/core/tests/suite/code_mode.rs @@ -537,6 +537,46 @@ Error:\ boom\n Ok(()) } +#[cfg_attr(windows, ignore = "no exec_command on Windows")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_exec_surfaces_handler_errors_as_exceptions() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let (_test, second_mock) = run_code_mode_turn( + &server, + "surface nested tool handler failures as script exceptions", + r#" +try { + await tools.exec_command({}); + text("no-exception"); +} catch (error) { + text(`caught:${error?.message ?? String(error)}`); +} +"#, + false, + ) + .await?; + + let request = second_mock.single_request(); + let (output, success) = custom_tool_output_body_and_success(&request, "call-1"); + assert_ne!( + success, + Some(false), + "script should catch the nested tool error: {output}" + ); + assert!( + output.contains("caught:"), + "expected caught exception text in output: {output}" + ); + assert!( + !output.contains("no-exception"), + "nested tool error should not allow success path: {output}" + ); + + Ok(()) +} + #[cfg_attr(windows, ignore = "no exec_command on Windows")] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn code_mode_can_yield_and_resume_with_wait() -> Result<()> {