diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs deleted file mode 100644 index 0f8ac1a8e..000000000 --- a/codex-rs/core/src/tools/code_mode.rs +++ /dev/null @@ -1,920 +0,0 @@ -use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::Arc; -use std::time::Duration; - -use crate::client_common::tools::ToolSpec; -use crate::codex::Session; -use crate::codex::TurnContext; -use crate::config::Config; -use crate::features::Feature; -use crate::function_tool::FunctionCallError; -use crate::tools::ToolRouter; -use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; -use crate::tools::code_mode_description::code_mode_tool_reference; -use crate::tools::context::FunctionToolOutput; -use crate::tools::context::SharedTurnDiffTracker; -use crate::tools::context::ToolPayload; -use crate::tools::js_repl::resolve_compatible_node; -use crate::tools::router::ToolCall; -use crate::tools::router::ToolCallSource; -use crate::tools::router::ToolRouterParams; -use crate::truncate::TruncationPolicy; -use crate::truncate::formatted_truncate_text_content_items_with_policy; -use crate::truncate::truncate_function_output_items_with_policy; -use crate::unified_exec::resolve_max_tokens; -use codex_protocol::models::FunctionCallOutputContentItem; -use serde::Deserialize; -use serde::Serialize; -use serde_json::Value as JsonValue; -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::BufReader; -use tokio::sync::Mutex; -use tokio::sync::mpsc; -use tokio::sync::oneshot; -use tokio::task::JoinHandle; -use tracing::warn; - -const CODE_MODE_RUNNER_SOURCE: &str = include_str!("code_mode_runner.cjs"); -const CODE_MODE_BRIDGE_SOURCE: &str = include_str!("code_mode_bridge.js"); -pub(crate) const PUBLIC_TOOL_NAME: &str = "exec"; -pub(crate) const WAIT_TOOL_NAME: &str = "exec_wait"; -pub(crate) const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000; - -#[derive(Clone)] -struct ExecContext { - session: Arc, - turn: Arc, - tracker: SharedTurnDiffTracker, -} - -pub(crate) struct CodeModeProcess { - child: tokio::process::Child, - stdin: Arc>, - stdout_task: JoinHandle<()>, - // A set of current requests waiting for a response from code mode host - response_waiters: Arc>>>, - // When there is an active worker it listens for tool calls from code mode and processes them - tool_call_rx: Arc>>, -} - -pub(crate) struct CodeModeWorker { - shutdown_tx: Option>, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "snake_case")] -struct CodeModeToolCall { - request_id: String, - id: String, - name: String, - #[serde(default)] - input: Option, -} - -impl Drop for CodeModeWorker { - fn drop(&mut self) { - if let Some(shutdown_tx) = self.shutdown_tx.take() { - let _ = shutdown_tx.send(()); - } - } -} - -impl CodeModeProcess { - fn worker(&self, exec: ExecContext) -> CodeModeWorker { - let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - let stdin = Arc::clone(&self.stdin); - let tool_call_rx = Arc::clone(&self.tool_call_rx); - tokio::spawn(async move { - loop { - let tool_call = tokio::select! { - _ = &mut shutdown_rx => break, - tool_call = async { - let mut tool_call_rx = tool_call_rx.lock().await; - tool_call_rx.recv().await - } => tool_call, - }; - let Some(tool_call) = tool_call else { - break; - }; - let exec = exec.clone(); - let stdin = Arc::clone(&stdin); - tokio::spawn(async move { - let response = HostToNodeMessage::Response { - request_id: tool_call.request_id, - id: tool_call.id, - code_mode_result: call_nested_tool(exec, tool_call.name, tool_call.input) - .await, - }; - if let Err(err) = write_message(&stdin, &response).await { - warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}"); - } - }); - } - }); - - CodeModeWorker { - shutdown_tx: Some(shutdown_tx), - } - } - - async fn send( - &mut self, - request_id: &str, - message: &HostToNodeMessage, - ) -> Result { - if self.stdout_task.is_finished() { - return Err(std::io::Error::other(format!( - "{PUBLIC_TOOL_NAME} runner is not available" - ))); - } - - let (tx, rx) = oneshot::channel(); - self.response_waiters - .lock() - .await - .insert(request_id.to_string(), tx); - if let Err(err) = write_message(&self.stdin, message).await { - self.response_waiters.lock().await.remove(request_id); - return Err(err); - } - - match rx.await { - Ok(message) => Ok(message), - Err(_) => Err(std::io::Error::other(format!( - "{PUBLIC_TOOL_NAME} runner is not available" - ))), - } - } - - fn has_exited(&mut self) -> Result { - self.child - .try_wait() - .map(|status| status.is_some()) - .map_err(std::io::Error::other) - } -} - -pub(crate) struct CodeModeService { - js_repl_node_path: Option, - stored_values: Mutex>, - process: Arc>>, - next_session_id: Mutex, -} - -impl CodeModeService { - pub(crate) fn new(js_repl_node_path: Option) -> Self { - Self { - js_repl_node_path, - stored_values: Mutex::new(HashMap::new()), - process: Arc::new(Mutex::new(None)), - next_session_id: Mutex::new(1), - } - } - - pub(crate) async fn stored_values(&self) -> HashMap { - self.stored_values.lock().await.clone() - } - - pub(crate) async fn replace_stored_values(&self, values: HashMap) { - *self.stored_values.lock().await = values; - } - - async fn ensure_started( - &self, - ) -> Result>, std::io::Error> { - let mut process_slot = self.process.lock().await; - let needs_spawn = match process_slot.as_mut() { - Some(process) => !matches!(process.has_exited(), Ok(false)), - None => true, - }; - if needs_spawn { - let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()) - .await - .map_err(std::io::Error::other)?; - *process_slot = Some(spawn_code_mode_process(&node_path).await?); - } - drop(process_slot); - Ok(self.process.clone().lock_owned().await) - } - - pub(crate) async fn start_turn_worker( - &self, - session: &Arc, - turn: &Arc, - tracker: &SharedTurnDiffTracker, - ) -> Option { - if !turn.features.enabled(Feature::CodeMode) { - return None; - } - let exec = ExecContext { - session: Arc::clone(session), - turn: Arc::clone(turn), - tracker: Arc::clone(tracker), - }; - let mut process_slot = match self.ensure_started().await { - Ok(process_slot) => process_slot, - Err(err) => { - warn!("failed to start {PUBLIC_TOOL_NAME} worker for turn: {err}"); - return None; - } - }; - let Some(process) = process_slot.as_mut() else { - warn!( - "failed to start {PUBLIC_TOOL_NAME} worker for turn: {PUBLIC_TOOL_NAME} runner failed to start" - ); - return None; - }; - Some(process.worker(exec)) - } - - pub(crate) async fn allocate_session_id(&self) -> i32 { - let mut next_session_id = self.next_session_id.lock().await; - let session_id = *next_session_id; - *next_session_id = next_session_id.saturating_add(1); - session_id - } - - pub(crate) async fn allocate_request_id(&self) -> String { - uuid::Uuid::new_v4().to_string() - } -} - -#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] -#[serde(rename_all = "snake_case")] -enum CodeModeToolKind { - Function, - Freeform, -} - -#[derive(Clone, Debug, Serialize)] -struct EnabledTool { - tool_name: String, - #[serde(rename = "module")] - module_path: String, - namespace: Vec, - name: String, - description: String, - kind: CodeModeToolKind, -} - -#[derive(Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum HostToNodeMessage { - Start { - request_id: String, - session_id: i32, - enabled_tools: Vec, - stored_values: HashMap, - source: String, - }, - Poll { - request_id: String, - session_id: i32, - yield_time_ms: u64, - }, - Terminate { - request_id: String, - session_id: i32, - }, - Response { - request_id: String, - id: String, - code_mode_result: JsonValue, - }, -} - -#[derive(Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum NodeToHostMessage { - ToolCall { - #[serde(flatten)] - tool_call: CodeModeToolCall, - }, - Yielded { - request_id: String, - content_items: Vec, - }, - Terminated { - request_id: String, - content_items: Vec, - }, - Result { - request_id: String, - content_items: Vec, - stored_values: HashMap, - #[serde(default)] - error_text: Option, - #[serde(default)] - max_output_tokens_per_exec_call: Option, - }, -} - -enum CodeModeSessionProgress { - Finished(FunctionToolOutput), - Yielded { output: FunctionToolOutput }, -} - -enum CodeModeExecutionStatus { - Completed, - Failed, - Running(i32), - Terminated, -} - -pub(crate) fn instructions(config: &Config) -> Option { - if !config.features.enabled(Feature::CodeMode) { - return None; - } - - let mut section = String::from("## Exec\n"); - section.push_str(&format!( - "- Use `{PUBLIC_TOOL_NAME}` for JavaScript execution in a Node-backed `node:vm` context.\n", - )); - section.push_str(&format!( - "- `{PUBLIC_TOOL_NAME}` is a freeform/custom tool. Direct `{PUBLIC_TOOL_NAME}` calls must send raw JavaScript tool input. Do not wrap code in JSON, quotes, or markdown code fences.\n", - )); - section.push_str(&format!( - "- Direct tool calls remain available while `{PUBLIC_TOOL_NAME}` is enabled.\n", - )); - section.push_str(&format!( - "- `{PUBLIC_TOOL_NAME}` uses the same Node runtime resolution as `js_repl`. If needed, point `js_repl_node_path` at the Node binary you want Codex to use.\n", - )); - section.push_str("- Import nested tools from `tools.js`, for example `import { exec_command } from \"tools.js\"` or `import { ALL_TOOLS } from \"tools.js\"` to inspect the available `{ module, name, description }` entries. Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import { append_notebook_logs_chart } from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values.\n"); - section.push_str(&format!( - "- Import `{{ background, output_text, output_image, set_max_output_tokens_per_exec_call, set_yield_time, store, load }}` from `@openai/code_mode` (or `\"openai/code_mode\"`). `output_text(value)` surfaces text back to the model and stringifies non-string objects with `JSON.stringify(...)` when possible. `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs. `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, and `load(key)` returns a cloned stored value or `undefined`. `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate direct `{PUBLIC_TOOL_NAME}` returns; `{WAIT_TOOL_NAME}` uses its own `max_tokens` argument instead and defaults to `10000`. `set_yield_time(value)` asks `{PUBLIC_TOOL_NAME}` to return early if the script is still running after that many milliseconds so `{WAIT_TOOL_NAME}` can resume it later. `background()` returns a yielded `{PUBLIC_TOOL_NAME}` response immediately while the script keeps running in the background. The returned content starts with a separate `Script completed`, `Script failed`, or `Script running with session ID …` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker.\n", - )); - section.push_str(&format!( - "- If `{PUBLIC_TOOL_NAME}` returns `Script running with session ID …`, call `{WAIT_TOOL_NAME}` with that `session_id` to keep waiting for more output, completion, or termination.\n", - )); - section.push_str( - "- Function tools require JSON object arguments. Freeform tools require raw strings.\n", - ); - section.push_str("- `add_content(value)` remains available for compatibility. It is synchronous and accepts a content item, an array of content items, or a string. Structured nested-tool results should be converted to text first, for example with `JSON.stringify(...)`.\n"); - section - .push_str("- Only content passed to `output_text(...)`, `output_image(...)`, or `add_content(value)` is surfaced back to the model."); - Some(section) -} - -pub(crate) async fn execute( - session: Arc, - turn: Arc, - tracker: SharedTurnDiffTracker, - code: String, -) -> Result { - let exec = ExecContext { - session, - turn, - tracker, - }; - let enabled_tools = build_enabled_tools(&exec).await; - let service = &exec.session.services.code_mode_service; - let stored_values = service.stored_values().await; - let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; - let session_id = service.allocate_session_id().await; - let request_id = service.allocate_request_id().await; - let process_slot = service - .ensure_started() - .await - .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; - let started_at = std::time::Instant::now(); - let message = HostToNodeMessage::Start { - request_id: request_id.clone(), - session_id, - enabled_tools, - stored_values, - source, - }; - let result = { - let mut process_slot = process_slot; - let Some(process) = process_slot.as_mut() else { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner failed to start" - ))); - }; - let message = process - .send(&request_id, &message) - .await - .map_err(|err| err.to_string()); - let message = match message { - Ok(message) => message, - Err(error) => return Err(FunctionCallError::RespondToModel(error)), - }; - handle_node_message(&exec, session_id, message, None, started_at).await - }; - match result { - Ok(CodeModeSessionProgress::Finished(output)) - | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), - Err(error) => Err(FunctionCallError::RespondToModel(error)), - } -} - -pub(crate) async fn wait( - session: Arc, - turn: Arc, - tracker: SharedTurnDiffTracker, - session_id: i32, - yield_time_ms: u64, - max_output_tokens: Option, - terminate: bool, -) -> Result { - let exec = ExecContext { - session, - turn, - tracker, - }; - let request_id = exec - .session - .services - .code_mode_service - .allocate_request_id() - .await; - let started_at = std::time::Instant::now(); - let message = if terminate { - HostToNodeMessage::Terminate { - request_id: request_id.clone(), - session_id, - } - } else { - HostToNodeMessage::Poll { - request_id: request_id.clone(), - session_id, - yield_time_ms, - } - }; - let process_slot = exec - .session - .services - .code_mode_service - .ensure_started() - .await - .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; - let result = { - let mut process_slot = process_slot; - let Some(process) = process_slot.as_mut() else { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner failed to start" - ))); - }; - if !matches!(process.has_exited(), Ok(false)) { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner failed to start" - ))); - } - let message = process - .send(&request_id, &message) - .await - .map_err(|err| err.to_string()); - let message = match message { - Ok(message) => message, - Err(error) => return Err(FunctionCallError::RespondToModel(error)), - }; - handle_node_message( - &exec, - session_id, - message, - Some(max_output_tokens), - started_at, - ) - .await - }; - match result { - Ok(CodeModeSessionProgress::Finished(output)) - | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), - Err(error) => Err(FunctionCallError::RespondToModel(error)), - } -} - -async fn handle_node_message( - exec: &ExecContext, - session_id: i32, - message: NodeToHostMessage, - poll_max_output_tokens: Option>, - started_at: std::time::Instant, -) -> Result { - match message { - NodeToHostMessage::ToolCall { .. } => Err(format!( - "{PUBLIC_TOOL_NAME} received an unexpected tool call response" - )), - NodeToHostMessage::Yielded { content_items, .. } => { - let mut delta_items = output_content_items_from_json_values(content_items)?; - delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); - prepend_script_status( - &mut delta_items, - CodeModeExecutionStatus::Running(session_id), - started_at.elapsed(), - ); - Ok(CodeModeSessionProgress::Yielded { - output: FunctionToolOutput::from_content(delta_items, Some(true)), - }) - } - NodeToHostMessage::Terminated { content_items, .. } => { - let mut delta_items = output_content_items_from_json_values(content_items)?; - delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); - prepend_script_status( - &mut delta_items, - CodeModeExecutionStatus::Terminated, - started_at.elapsed(), - ); - Ok(CodeModeSessionProgress::Finished( - FunctionToolOutput::from_content(delta_items, Some(true)), - )) - } - NodeToHostMessage::Result { - content_items, - stored_values, - error_text, - max_output_tokens_per_exec_call, - .. - } => { - exec.session - .services - .code_mode_service - .replace_stored_values(stored_values) - .await; - let mut delta_items = output_content_items_from_json_values(content_items)?; - let success = error_text.is_none(); - if let Some(error_text) = error_text { - delta_items.push(FunctionCallOutputContentItem::InputText { - text: format!("Script error:\n{error_text}"), - }); - } - - let mut delta_items = truncate_code_mode_result( - delta_items, - poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), - ); - prepend_script_status( - &mut delta_items, - if success { - CodeModeExecutionStatus::Completed - } else { - CodeModeExecutionStatus::Failed - }, - started_at.elapsed(), - ); - Ok(CodeModeSessionProgress::Finished( - FunctionToolOutput::from_content(delta_items, Some(success)), - )) - } - } -} - -async fn spawn_code_mode_process( - node_path: &std::path::Path, -) -> Result { - let mut cmd = tokio::process::Command::new(node_path); - cmd.arg("--experimental-vm-modules"); - cmd.arg("--eval"); - cmd.arg(CODE_MODE_RUNNER_SOURCE); - cmd.stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - - let mut child = cmd.spawn().map_err(std::io::Error::other)?; - let stdout = child.stdout.take().ok_or_else(|| { - std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdout")) - })?; - let stderr = child.stderr.take().ok_or_else(|| { - std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stderr")) - })?; - let stdin = child - .stdin - .take() - .ok_or_else(|| std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdin")))?; - let stdin = Arc::new(Mutex::new(stdin)); - let response_waiters = Arc::new(Mutex::new(HashMap::< - String, - oneshot::Sender, - >::new())); - let (tool_call_tx, tool_call_rx) = mpsc::unbounded_channel(); - - tokio::spawn(async move { - let mut reader = BufReader::new(stderr); - let mut buf = Vec::new(); - match reader.read_to_end(&mut buf).await { - Ok(_) => { - let stderr = String::from_utf8_lossy(&buf).trim().to_string(); - if !stderr.is_empty() { - warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}"); - } - } - Err(err) => { - warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}"); - } - } - }); - let stdout_task = tokio::spawn({ - let response_waiters = Arc::clone(&response_waiters); - async move { - let mut stdout_lines = BufReader::new(stdout).lines(); - loop { - let line = match stdout_lines.next_line().await { - Ok(line) => line, - Err(err) => { - warn!("failed to read {PUBLIC_TOOL_NAME} stdout: {err}"); - break; - } - }; - let Some(line) = line else { - break; - }; - if line.trim().is_empty() { - continue; - } - let message: NodeToHostMessage = match serde_json::from_str(&line) { - Ok(message) => message, - Err(err) => { - warn!("failed to parse {PUBLIC_TOOL_NAME} stdout message: {err}"); - break; - } - }; - match message { - NodeToHostMessage::ToolCall { tool_call } => { - let _ = tool_call_tx.send(tool_call); - } - message => { - let request_id = message_request_id(&message).to_string(); - if let Some(waiter) = response_waiters.lock().await.remove(&request_id) { - let _ = waiter.send(message); - } - } - } - } - response_waiters.lock().await.clear(); - } - }); - - Ok(CodeModeProcess { - child, - stdin, - stdout_task, - response_waiters, - tool_call_rx: Arc::new(Mutex::new(tool_call_rx)), - }) -} - -async fn write_message( - stdin: &Arc>, - message: &HostToNodeMessage, -) -> Result<(), std::io::Error> { - let line = serde_json::to_string(message).map_err(std::io::Error::other)?; - let mut stdin = stdin.lock().await; - stdin.write_all(line.as_bytes()).await?; - stdin.write_all(b"\n").await?; - stdin.flush().await?; - Ok(()) -} - -fn message_request_id(message: &NodeToHostMessage) -> &str { - match message { - NodeToHostMessage::ToolCall { tool_call } => &tool_call.request_id, - NodeToHostMessage::Yielded { request_id, .. } - | NodeToHostMessage::Terminated { request_id, .. } - | NodeToHostMessage::Result { request_id, .. } => request_id, - } -} - -fn prepend_script_status( - content_items: &mut Vec, - status: CodeModeExecutionStatus, - wall_time: Duration, -) { - let wall_time_seconds = ((wall_time.as_secs_f32()) * 10.0).round() / 10.0; - let header = format!( - "{}\nWall time {wall_time_seconds:.1} seconds\nOutput:\n", - match status { - CodeModeExecutionStatus::Completed => "Script completed".to_string(), - CodeModeExecutionStatus::Failed => "Script failed".to_string(), - CodeModeExecutionStatus::Running(session_id) => { - format!("Script running with session ID {session_id}") - } - CodeModeExecutionStatus::Terminated => "Script terminated".to_string(), - } - ); - content_items.insert(0, FunctionCallOutputContentItem::InputText { text: header }); -} - -fn build_source(user_code: &str, enabled_tools: &[EnabledTool]) -> Result { - let enabled_tools_json = serde_json::to_string(enabled_tools) - .map_err(|err| format!("failed to serialize enabled tools: {err}"))?; - Ok(CODE_MODE_BRIDGE_SOURCE - .replace( - "__CODE_MODE_ENABLED_TOOLS_PLACEHOLDER__", - &enabled_tools_json, - ) - .replace("__CODE_MODE_USER_CODE_PLACEHOLDER__", user_code)) -} - -fn truncate_code_mode_result( - items: Vec, - max_output_tokens_per_exec_call: Option, -) -> Vec { - let max_output_tokens = resolve_max_tokens(max_output_tokens_per_exec_call); - let policy = TruncationPolicy::Tokens(max_output_tokens); - if items - .iter() - .all(|item| matches!(item, FunctionCallOutputContentItem::InputText { .. })) - { - let (truncated_items, _) = - formatted_truncate_text_content_items_with_policy(&items, policy); - return truncated_items; - } - - truncate_function_output_items_with_policy(&items, policy) -} - -async fn build_enabled_tools(exec: &ExecContext) -> Vec { - let router = build_nested_router(exec).await; - let mut out = router - .specs() - .into_iter() - .map(|spec| augment_tool_spec_for_code_mode(spec, true)) - .filter_map(enabled_tool_from_spec) - .collect::>(); - out.sort_by(|left, right| left.tool_name.cmp(&right.tool_name)); - out.dedup_by(|left, right| left.tool_name == right.tool_name); - out -} - -fn enabled_tool_from_spec(spec: ToolSpec) -> Option { - let tool_name = spec.name().to_string(); - if tool_name == PUBLIC_TOOL_NAME || tool_name == WAIT_TOOL_NAME { - return None; - } - - let reference = code_mode_tool_reference(&tool_name); - - let (description, kind) = match spec { - ToolSpec::Function(tool) => (tool.description, CodeModeToolKind::Function), - ToolSpec::Freeform(tool) => (tool.description, CodeModeToolKind::Freeform), - ToolSpec::LocalShell {} - | ToolSpec::ImageGeneration { .. } - | ToolSpec::ToolSearch { .. } - | ToolSpec::WebSearch { .. } => { - return None; - } - }; - - Some(EnabledTool { - tool_name, - module_path: reference.module_path, - namespace: reference.namespace, - name: reference.tool_key, - description, - kind, - }) -} - -async fn build_nested_router(exec: &ExecContext) -> ToolRouter { - let nested_tools_config = exec.turn.tools_config.for_code_mode_nested_tools(); - let mcp_tools = exec - .session - .services - .mcp_connection_manager - .read() - .await - .list_all_tools() - .await - .into_iter() - .map(|(name, tool_info)| (name, tool_info.tool)) - .collect(); - - ToolRouter::from_config( - &nested_tools_config, - ToolRouterParams { - mcp_tools: Some(mcp_tools), - app_tools: None, - discoverable_tools: None, - dynamic_tools: exec.turn.dynamic_tools.as_slice(), - }, - ) -} - -async fn call_nested_tool( - exec: ExecContext, - tool_name: String, - input: Option, -) -> JsonValue { - if tool_name == PUBLIC_TOOL_NAME { - return JsonValue::String(format!("{PUBLIC_TOOL_NAME} cannot invoke itself")); - } - - let router = build_nested_router(&exec).await; - - let specs = router.specs(); - let payload = - if let Some((server, tool)) = exec.session.parse_mcp_tool_name(&tool_name, &None).await { - match serialize_function_tool_arguments(&tool_name, input) { - Ok(raw_arguments) => ToolPayload::Mcp { - server, - tool, - raw_arguments, - }, - Err(error) => return JsonValue::String(error), - } - } else { - match build_nested_tool_payload(&specs, &tool_name, input) { - Ok(payload) => payload, - Err(error) => return JsonValue::String(error), - } - }; - - let call = ToolCall { - tool_name: tool_name.clone(), - call_id: format!("{PUBLIC_TOOL_NAME}-{}", uuid::Uuid::new_v4()), - tool_namespace: None, - payload, - }; - let result = router - .dispatch_tool_call_with_code_mode_result( - Arc::clone(&exec.session), - Arc::clone(&exec.turn), - Arc::clone(&exec.tracker), - call, - ToolCallSource::CodeMode, - ) - .await; - - match result { - Ok(result) => result.code_mode_result(), - Err(error) => JsonValue::String(error.to_string()), - } -} - -fn tool_kind_for_spec(spec: &ToolSpec) -> CodeModeToolKind { - if matches!(spec, ToolSpec::Freeform(_)) { - CodeModeToolKind::Freeform - } else { - CodeModeToolKind::Function - } -} - -fn tool_kind_for_name(specs: &[ToolSpec], tool_name: &str) -> Result { - specs - .iter() - .find(|spec| spec.name() == tool_name) - .map(tool_kind_for_spec) - .ok_or_else(|| format!("tool `{tool_name}` is not enabled in {PUBLIC_TOOL_NAME}")) -} - -fn build_nested_tool_payload( - specs: &[ToolSpec], - tool_name: &str, - input: Option, -) -> Result { - let actual_kind = tool_kind_for_name(specs, tool_name)?; - match actual_kind { - CodeModeToolKind::Function => build_function_tool_payload(tool_name, input), - CodeModeToolKind::Freeform => build_freeform_tool_payload(tool_name, input), - } -} - -fn build_function_tool_payload( - tool_name: &str, - input: Option, -) -> Result { - let arguments = serialize_function_tool_arguments(tool_name, input)?; - Ok(ToolPayload::Function { arguments }) -} - -fn serialize_function_tool_arguments( - tool_name: &str, - input: Option, -) -> Result { - match input { - None => Ok("{}".to_string()), - Some(JsonValue::Object(map)) => serde_json::to_string(&JsonValue::Object(map)) - .map_err(|err| format!("failed to serialize tool `{tool_name}` arguments: {err}")), - Some(_) => Err(format!( - "tool `{tool_name}` expects a JSON object for arguments" - )), - } -} - -fn build_freeform_tool_payload( - tool_name: &str, - input: Option, -) -> Result { - match input { - Some(JsonValue::String(input)) => Ok(ToolPayload::Custom { input }), - _ => Err(format!("tool `{tool_name}` expects a string input")), - } -} - -fn output_content_items_from_json_values( - content_items: Vec, -) -> Result, String> { - content_items - .into_iter() - .enumerate() - .map(|(index, item)| { - serde_json::from_value(item).map_err(|err| { - format!("invalid {PUBLIC_TOOL_NAME} content item at index {index}: {err}") - }) - }) - .collect() -} diff --git a/codex-rs/core/src/tools/code_mode_bridge.js b/codex-rs/core/src/tools/code_mode/bridge.js similarity index 100% rename from codex-rs/core/src/tools/code_mode_bridge.js rename to codex-rs/core/src/tools/code_mode/bridge.js diff --git a/codex-rs/core/src/tools/code_mode/execute_handler.rs b/codex-rs/core/src/tools/code_mode/execute_handler.rs new file mode 100644 index 000000000..493c638da --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/execute_handler.rs @@ -0,0 +1,111 @@ +use async_trait::async_trait; + +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::SharedTurnDiffTracker; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +use super::CodeModeSessionProgress; +use super::ExecContext; +use super::PUBLIC_TOOL_NAME; +use super::build_enabled_tools; +use super::handle_node_message; +use super::protocol::HostToNodeMessage; +use super::protocol::build_source; + +pub struct CodeModeExecuteHandler; + +impl CodeModeExecuteHandler { + async fn execute( + &self, + session: std::sync::Arc, + turn: std::sync::Arc, + tracker: SharedTurnDiffTracker, + code: String, + ) -> Result { + let exec = ExecContext { + session, + turn, + tracker, + }; + let enabled_tools = build_enabled_tools(&exec).await; + let service = &exec.session.services.code_mode_service; + let stored_values = service.stored_values().await; + let source = + build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; + let session_id = service.allocate_session_id().await; + let request_id = service.allocate_request_id().await; + let process_slot = service + .ensure_started() + .await + .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; + let started_at = std::time::Instant::now(); + let message = HostToNodeMessage::Start { + request_id: request_id.clone(), + session_id, + enabled_tools, + stored_values, + source, + }; + let result = { + let mut process_slot = process_slot; + let Some(process) = process_slot.as_mut() else { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); + }; + let message = process + .send(&request_id, &message) + .await + .map_err(|err| err.to_string()); + let message = match message { + Ok(message) => message, + Err(error) => return Err(FunctionCallError::RespondToModel(error)), + }; + handle_node_message(&exec, session_id, message, None, started_at).await + }; + match result { + Ok(CodeModeSessionProgress::Finished(output)) + | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), + Err(error) => Err(FunctionCallError::RespondToModel(error)), + } + } +} + +#[async_trait] +impl ToolHandler for CodeModeExecuteHandler { + type Output = FunctionToolOutput; + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Custom { .. }) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + tool_name, + payload, + .. + } = invocation; + + match payload { + ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => { + self.execute(session, turn, tracker, input).await + } + _ => Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" + ))), + } + } +} diff --git a/codex-rs/core/src/tools/code_mode/mod.rs b/codex-rs/core/src/tools/code_mode/mod.rs new file mode 100644 index 000000000..1b51cfc2f --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/mod.rs @@ -0,0 +1,399 @@ +mod execute_handler; +mod process; +mod protocol; +mod service; +mod wait_handler; +mod worker; + +use std::sync::Arc; +use std::time::Duration; + +use codex_protocol::models::FunctionCallOutputContentItem; +use serde_json::Value as JsonValue; + +use crate::client_common::tools::ToolSpec; +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::config::Config; +use crate::features::Feature; +use crate::tools::ToolRouter; +use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; +use crate::tools::code_mode_description::code_mode_tool_reference; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::SharedTurnDiffTracker; +use crate::tools::context::ToolPayload; +use crate::tools::router::ToolCall; +use crate::tools::router::ToolCallSource; +use crate::tools::router::ToolRouterParams; +use crate::truncate::TruncationPolicy; +use crate::truncate::formatted_truncate_text_content_items_with_policy; +use crate::truncate::truncate_function_output_items_with_policy; +use crate::unified_exec::resolve_max_tokens; + +const CODE_MODE_RUNNER_SOURCE: &str = include_str!("runner.cjs"); +const CODE_MODE_BRIDGE_SOURCE: &str = include_str!("bridge.js"); + +pub(crate) const PUBLIC_TOOL_NAME: &str = "exec"; +pub(crate) const WAIT_TOOL_NAME: &str = "exec_wait"; +pub(crate) const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000; + +#[derive(Clone)] +pub(super) struct ExecContext { + pub(super) session: Arc, + pub(super) turn: Arc, + pub(super) tracker: SharedTurnDiffTracker, +} + +pub(crate) use execute_handler::CodeModeExecuteHandler; +pub(crate) use service::CodeModeService; +pub(crate) use wait_handler::CodeModeWaitHandler; + +enum CodeModeSessionProgress { + Finished(FunctionToolOutput), + Yielded { output: FunctionToolOutput }, +} + +enum CodeModeExecutionStatus { + Completed, + Failed, + Running(i32), + Terminated, +} + +pub(crate) fn instructions(config: &Config) -> Option { + if !config.features.enabled(Feature::CodeMode) { + return None; + } + + let mut section = String::from("## Exec\n"); + section.push_str(&format!( + "- Use `{PUBLIC_TOOL_NAME}` for JavaScript execution in a Node-backed `node:vm` context.\n", + )); + section.push_str(&format!( + "- `{PUBLIC_TOOL_NAME}` is a freeform/custom tool. Direct `{PUBLIC_TOOL_NAME}` calls must send raw JavaScript tool input. Do not wrap code in JSON, quotes, or markdown code fences.\n", + )); + section.push_str(&format!( + "- Direct tool calls remain available while `{PUBLIC_TOOL_NAME}` is enabled.\n", + )); + section.push_str(&format!( + "- `{PUBLIC_TOOL_NAME}` uses the same Node runtime resolution as `js_repl`. If needed, point `js_repl_node_path` at the Node binary you want Codex to use.\n", + )); + section.push_str("- Import nested tools from `tools.js`, for example `import { exec_command } from \"tools.js\"` or `import { ALL_TOOLS } from \"tools.js\"` to inspect the available `{ module, name, description }` entries. Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import { append_notebook_logs_chart } from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values.\n"); + section.push_str(&format!( + "- Import `{{ background, output_text, output_image, set_max_output_tokens_per_exec_call, set_yield_time, store, load }}` from `@openai/code_mode` (or `\"openai/code_mode\"`). `output_text(value)` surfaces text back to the model and stringifies non-string objects with `JSON.stringify(...)` when possible. `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs. `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, and `load(key)` returns a cloned stored value or `undefined`. `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate direct `{PUBLIC_TOOL_NAME}` returns; `{WAIT_TOOL_NAME}` uses its own `max_tokens` argument instead and defaults to `10000`. `set_yield_time(value)` asks `{PUBLIC_TOOL_NAME}` to return early if the script is still running after that many milliseconds so `{WAIT_TOOL_NAME}` can resume it later. `background()` returns a yielded `{PUBLIC_TOOL_NAME}` response immediately while the script keeps running in the background. The returned content starts with a separate `Script completed`, `Script failed`, or `Script running with session ID …` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker.\n", + )); + section.push_str(&format!( + "- If `{PUBLIC_TOOL_NAME}` returns `Script running with session ID …`, call `{WAIT_TOOL_NAME}` with that `session_id` to keep waiting for more output, completion, or termination.\n", + )); + section.push_str( + "- Function tools require JSON object arguments. Freeform tools require raw strings.\n", + ); + section.push_str("- `add_content(value)` remains available for compatibility. It is synchronous and accepts a content item, an array of content items, or a string. Structured nested-tool results should be converted to text first, for example with `JSON.stringify(...)`.\n"); + section + .push_str("- Only content passed to `output_text(...)`, `output_image(...)`, or `add_content(value)` is surfaced back to the model."); + Some(section) +} + +async fn handle_node_message( + exec: &ExecContext, + session_id: i32, + message: protocol::NodeToHostMessage, + poll_max_output_tokens: Option>, + started_at: std::time::Instant, +) -> Result { + match message { + protocol::NodeToHostMessage::ToolCall { .. } => Err(protocol::unexpected_tool_call_error()), + protocol::NodeToHostMessage::Yielded { content_items, .. } => { + let mut delta_items = output_content_items_from_json_values(content_items)?; + delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); + prepend_script_status( + &mut delta_items, + CodeModeExecutionStatus::Running(session_id), + started_at.elapsed(), + ); + Ok(CodeModeSessionProgress::Yielded { + output: FunctionToolOutput::from_content(delta_items, Some(true)), + }) + } + protocol::NodeToHostMessage::Terminated { content_items, .. } => { + let mut delta_items = output_content_items_from_json_values(content_items)?; + delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); + prepend_script_status( + &mut delta_items, + CodeModeExecutionStatus::Terminated, + started_at.elapsed(), + ); + Ok(CodeModeSessionProgress::Finished( + FunctionToolOutput::from_content(delta_items, Some(true)), + )) + } + protocol::NodeToHostMessage::Result { + content_items, + stored_values, + error_text, + max_output_tokens_per_exec_call, + .. + } => { + exec.session + .services + .code_mode_service + .replace_stored_values(stored_values) + .await; + let mut delta_items = output_content_items_from_json_values(content_items)?; + let success = error_text.is_none(); + if let Some(error_text) = error_text { + delta_items.push(FunctionCallOutputContentItem::InputText { + text: format!("Script error:\n{error_text}"), + }); + } + + let mut delta_items = truncate_code_mode_result( + delta_items, + poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), + ); + prepend_script_status( + &mut delta_items, + if success { + CodeModeExecutionStatus::Completed + } else { + CodeModeExecutionStatus::Failed + }, + started_at.elapsed(), + ); + Ok(CodeModeSessionProgress::Finished( + FunctionToolOutput::from_content(delta_items, Some(success)), + )) + } + } +} + +fn prepend_script_status( + content_items: &mut Vec, + status: CodeModeExecutionStatus, + wall_time: Duration, +) { + let wall_time_seconds = ((wall_time.as_secs_f32()) * 10.0).round() / 10.0; + let header = format!( + "{}\nWall time {wall_time_seconds:.1} seconds\nOutput:\n", + match status { + CodeModeExecutionStatus::Completed => "Script completed".to_string(), + CodeModeExecutionStatus::Failed => "Script failed".to_string(), + CodeModeExecutionStatus::Running(session_id) => { + format!("Script running with session ID {session_id}") + } + CodeModeExecutionStatus::Terminated => "Script terminated".to_string(), + } + ); + content_items.insert(0, FunctionCallOutputContentItem::InputText { text: header }); +} + +fn truncate_code_mode_result( + items: Vec, + max_output_tokens_per_exec_call: Option, +) -> Vec { + let max_output_tokens = resolve_max_tokens(max_output_tokens_per_exec_call); + let policy = TruncationPolicy::Tokens(max_output_tokens); + if items + .iter() + .all(|item| matches!(item, FunctionCallOutputContentItem::InputText { .. })) + { + let (truncated_items, _) = + formatted_truncate_text_content_items_with_policy(&items, policy); + return truncated_items; + } + + truncate_function_output_items_with_policy(&items, policy) +} + +fn output_content_items_from_json_values( + content_items: Vec, +) -> Result, String> { + content_items + .into_iter() + .enumerate() + .map(|(index, item)| { + serde_json::from_value(item).map_err(|err| { + format!("invalid {PUBLIC_TOOL_NAME} content item at index {index}: {err}") + }) + }) + .collect() +} + +async fn build_enabled_tools(exec: &ExecContext) -> Vec { + let router = build_nested_router(exec).await; + let mut out = router + .specs() + .into_iter() + .map(|spec| augment_tool_spec_for_code_mode(spec, true)) + .filter_map(enabled_tool_from_spec) + .collect::>(); + out.sort_by(|left, right| left.tool_name.cmp(&right.tool_name)); + out.dedup_by(|left, right| left.tool_name == right.tool_name); + out +} + +fn enabled_tool_from_spec(spec: ToolSpec) -> Option { + let tool_name = spec.name().to_string(); + if tool_name == PUBLIC_TOOL_NAME || tool_name == WAIT_TOOL_NAME { + return None; + } + + let reference = code_mode_tool_reference(&tool_name); + let (description, kind) = match spec { + ToolSpec::Function(tool) => (tool.description, protocol::CodeModeToolKind::Function), + ToolSpec::Freeform(tool) => (tool.description, protocol::CodeModeToolKind::Freeform), + ToolSpec::LocalShell {} + | ToolSpec::ImageGeneration { .. } + | ToolSpec::ToolSearch { .. } + | ToolSpec::WebSearch { .. } => { + return None; + } + }; + + Some(protocol::EnabledTool { + tool_name, + module_path: reference.module_path, + namespace: reference.namespace, + name: reference.tool_key, + description, + kind, + }) +} + +async fn build_nested_router(exec: &ExecContext) -> ToolRouter { + let nested_tools_config = exec.turn.tools_config.for_code_mode_nested_tools(); + let mcp_tools = exec + .session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await + .into_iter() + .map(|(name, tool_info)| (name, tool_info.tool)) + .collect(); + + ToolRouter::from_config( + &nested_tools_config, + ToolRouterParams { + mcp_tools: Some(mcp_tools), + app_tools: None, + discoverable_tools: None, + dynamic_tools: exec.turn.dynamic_tools.as_slice(), + }, + ) +} + +async fn call_nested_tool( + exec: ExecContext, + tool_name: String, + input: Option, +) -> JsonValue { + if tool_name == PUBLIC_TOOL_NAME { + return JsonValue::String(format!("{PUBLIC_TOOL_NAME} cannot invoke itself")); + } + + let router = build_nested_router(&exec).await; + let specs = router.specs(); + let payload = + if let Some((server, tool)) = exec.session.parse_mcp_tool_name(&tool_name, &None).await { + match serialize_function_tool_arguments(&tool_name, input) { + Ok(raw_arguments) => ToolPayload::Mcp { + server, + tool, + raw_arguments, + }, + Err(error) => return JsonValue::String(error), + } + } else { + match build_nested_tool_payload(&specs, &tool_name, input) { + Ok(payload) => payload, + Err(error) => return JsonValue::String(error), + } + }; + + let call = ToolCall { + tool_name: tool_name.clone(), + call_id: format!("{PUBLIC_TOOL_NAME}-{}", uuid::Uuid::new_v4()), + tool_namespace: None, + payload, + }; + let result = router + .dispatch_tool_call_with_code_mode_result( + exec.session.clone(), + exec.turn.clone(), + exec.tracker.clone(), + call, + ToolCallSource::CodeMode, + ) + .await; + + match result { + Ok(result) => result.code_mode_result(), + Err(error) => JsonValue::String(error.to_string()), + } +} + +fn tool_kind_for_spec(spec: &ToolSpec) -> protocol::CodeModeToolKind { + if matches!(spec, ToolSpec::Freeform(_)) { + protocol::CodeModeToolKind::Freeform + } else { + protocol::CodeModeToolKind::Function + } +} + +fn tool_kind_for_name( + specs: &[ToolSpec], + tool_name: &str, +) -> Result { + specs + .iter() + .find(|spec| spec.name() == tool_name) + .map(tool_kind_for_spec) + .ok_or_else(|| format!("tool `{tool_name}` is not enabled in {PUBLIC_TOOL_NAME}")) +} + +fn build_nested_tool_payload( + specs: &[ToolSpec], + tool_name: &str, + input: Option, +) -> Result { + let actual_kind = tool_kind_for_name(specs, tool_name)?; + match actual_kind { + protocol::CodeModeToolKind::Function => build_function_tool_payload(tool_name, input), + protocol::CodeModeToolKind::Freeform => build_freeform_tool_payload(tool_name, input), + } +} + +fn build_function_tool_payload( + tool_name: &str, + input: Option, +) -> Result { + let arguments = serialize_function_tool_arguments(tool_name, input)?; + Ok(ToolPayload::Function { arguments }) +} + +fn serialize_function_tool_arguments( + tool_name: &str, + input: Option, +) -> Result { + match input { + None => Ok("{}".to_string()), + Some(JsonValue::Object(map)) => serde_json::to_string(&JsonValue::Object(map)) + .map_err(|err| format!("failed to serialize tool `{tool_name}` arguments: {err}")), + Some(_) => Err(format!( + "tool `{tool_name}` expects a JSON object for arguments" + )), + } +} + +fn build_freeform_tool_payload( + tool_name: &str, + input: Option, +) -> Result { + match input { + Some(JsonValue::String(input)) => Ok(ToolPayload::Custom { input }), + _ => Err(format!("tool `{tool_name}` expects a string input")), + } +} diff --git a/codex-rs/core/src/tools/code_mode/process.rs b/codex-rs/core/src/tools/code_mode/process.rs new file mode 100644 index 000000000..d27296fca --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/process.rs @@ -0,0 +1,172 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use tracing::warn; + +use super::CODE_MODE_RUNNER_SOURCE; +use super::PUBLIC_TOOL_NAME; +use super::protocol::CodeModeToolCall; +use super::protocol::HostToNodeMessage; +use super::protocol::NodeToHostMessage; +use super::protocol::message_request_id; + +pub(super) struct CodeModeProcess { + pub(super) child: tokio::process::Child, + pub(super) stdin: Arc>, + pub(super) stdout_task: JoinHandle<()>, + pub(super) response_waiters: Arc>>>, + pub(super) tool_call_rx: Arc>>, +} + +impl CodeModeProcess { + pub(super) async fn send( + &mut self, + request_id: &str, + message: &HostToNodeMessage, + ) -> Result { + if self.stdout_task.is_finished() { + return Err(std::io::Error::other(format!( + "{PUBLIC_TOOL_NAME} runner is not available" + ))); + } + + let (tx, rx) = oneshot::channel(); + self.response_waiters + .lock() + .await + .insert(request_id.to_string(), tx); + if let Err(err) = write_message(&self.stdin, message).await { + self.response_waiters.lock().await.remove(request_id); + return Err(err); + } + + match rx.await { + Ok(message) => Ok(message), + Err(_) => Err(std::io::Error::other(format!( + "{PUBLIC_TOOL_NAME} runner is not available" + ))), + } + } + + pub(super) fn has_exited(&mut self) -> Result { + self.child + .try_wait() + .map(|status| status.is_some()) + .map_err(std::io::Error::other) + } +} + +pub(super) async fn spawn_code_mode_process( + node_path: &std::path::Path, +) -> Result { + let mut cmd = tokio::process::Command::new(node_path); + cmd.arg("--experimental-vm-modules"); + cmd.arg("--eval"); + cmd.arg(CODE_MODE_RUNNER_SOURCE); + cmd.stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true); + + let mut child = cmd.spawn().map_err(std::io::Error::other)?; + let stdout = child.stdout.take().ok_or_else(|| { + std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdout")) + })?; + let stderr = child.stderr.take().ok_or_else(|| { + std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stderr")) + })?; + let stdin = child + .stdin + .take() + .ok_or_else(|| std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdin")))?; + let stdin = Arc::new(Mutex::new(stdin)); + let response_waiters = Arc::new(Mutex::new(HashMap::< + String, + oneshot::Sender, + >::new())); + let (tool_call_tx, tool_call_rx) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + let mut reader = BufReader::new(stderr); + let mut buf = Vec::new(); + match reader.read_to_end(&mut buf).await { + Ok(_) => { + let stderr = String::from_utf8_lossy(&buf).trim().to_string(); + if !stderr.is_empty() { + warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}"); + } + } + Err(err) => { + warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}"); + } + } + }); + let stdout_task = tokio::spawn({ + let response_waiters = Arc::clone(&response_waiters); + async move { + let mut stdout_lines = BufReader::new(stdout).lines(); + loop { + let line = match stdout_lines.next_line().await { + Ok(line) => line, + Err(err) => { + warn!("failed to read {PUBLIC_TOOL_NAME} stdout: {err}"); + break; + } + }; + let Some(line) = line else { + break; + }; + if line.trim().is_empty() { + continue; + } + let message: NodeToHostMessage = match serde_json::from_str(&line) { + Ok(message) => message, + Err(err) => { + warn!("failed to parse {PUBLIC_TOOL_NAME} stdout message: {err}"); + break; + } + }; + match message { + NodeToHostMessage::ToolCall { tool_call } => { + let _ = tool_call_tx.send(tool_call); + } + message => { + let request_id = message_request_id(&message).to_string(); + if let Some(waiter) = response_waiters.lock().await.remove(&request_id) { + let _ = waiter.send(message); + } + } + } + } + response_waiters.lock().await.clear(); + } + }); + + Ok(CodeModeProcess { + child, + stdin, + stdout_task, + response_waiters, + tool_call_rx: Arc::new(Mutex::new(tool_call_rx)), + }) +} + +pub(super) async fn write_message( + stdin: &Arc>, + message: &HostToNodeMessage, +) -> Result<(), std::io::Error> { + let line = serde_json::to_string(message).map_err(std::io::Error::other)?; + let mut stdin = stdin.lock().await; + stdin.write_all(line.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + Ok(()) +} diff --git a/codex-rs/core/src/tools/code_mode/protocol.rs b/codex-rs/core/src/tools/code_mode/protocol.rs new file mode 100644 index 000000000..fe0ab861f --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/protocol.rs @@ -0,0 +1,115 @@ +use std::collections::HashMap; + +use serde::Deserialize; +use serde::Serialize; +use serde_json::Value as JsonValue; + +use super::CODE_MODE_BRIDGE_SOURCE; +use super::PUBLIC_TOOL_NAME; + +#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub(super) enum CodeModeToolKind { + Function, + Freeform, +} + +#[derive(Clone, Debug, Serialize)] +pub(super) struct EnabledTool { + pub(super) tool_name: String, + #[serde(rename = "module")] + pub(super) module_path: String, + pub(super) namespace: Vec, + pub(super) name: String, + pub(super) description: String, + pub(super) kind: CodeModeToolKind, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +pub(super) struct CodeModeToolCall { + pub(super) request_id: String, + pub(super) id: String, + pub(super) name: String, + #[serde(default)] + pub(super) input: Option, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(super) enum HostToNodeMessage { + Start { + request_id: String, + session_id: i32, + enabled_tools: Vec, + stored_values: HashMap, + source: String, + }, + Poll { + request_id: String, + session_id: i32, + yield_time_ms: u64, + }, + Terminate { + request_id: String, + session_id: i32, + }, + Response { + request_id: String, + id: String, + code_mode_result: JsonValue, + }, +} + +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(super) enum NodeToHostMessage { + ToolCall { + #[serde(flatten)] + tool_call: CodeModeToolCall, + }, + Yielded { + request_id: String, + content_items: Vec, + }, + Terminated { + request_id: String, + content_items: Vec, + }, + Result { + request_id: String, + content_items: Vec, + stored_values: HashMap, + #[serde(default)] + error_text: Option, + #[serde(default)] + max_output_tokens_per_exec_call: Option, + }, +} + +pub(super) fn build_source( + user_code: &str, + enabled_tools: &[EnabledTool], +) -> Result { + let enabled_tools_json = serde_json::to_string(enabled_tools) + .map_err(|err| format!("failed to serialize enabled tools: {err}"))?; + Ok(CODE_MODE_BRIDGE_SOURCE + .replace( + "__CODE_MODE_ENABLED_TOOLS_PLACEHOLDER__", + &enabled_tools_json, + ) + .replace("__CODE_MODE_USER_CODE_PLACEHOLDER__", user_code)) +} + +pub(super) fn message_request_id(message: &NodeToHostMessage) -> &str { + match message { + NodeToHostMessage::ToolCall { tool_call } => &tool_call.request_id, + NodeToHostMessage::Yielded { request_id, .. } + | NodeToHostMessage::Terminated { request_id, .. } + | NodeToHostMessage::Result { request_id, .. } => request_id, + } +} + +pub(super) fn unexpected_tool_call_error() -> String { + format!("{PUBLIC_TOOL_NAME} received an unexpected tool call response") +} diff --git a/codex-rs/core/src/tools/code_mode_runner.cjs b/codex-rs/core/src/tools/code_mode/runner.cjs similarity index 100% rename from codex-rs/core/src/tools/code_mode_runner.cjs rename to codex-rs/core/src/tools/code_mode/runner.cjs diff --git a/codex-rs/core/src/tools/code_mode/service.rs b/codex-rs/core/src/tools/code_mode/service.rs new file mode 100644 index 000000000..c7ca3c372 --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/service.rs @@ -0,0 +1,104 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use serde_json::Value as JsonValue; +use tokio::sync::Mutex; +use tracing::warn; + +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::features::Feature; +use crate::tools::context::SharedTurnDiffTracker; +use crate::tools::js_repl::resolve_compatible_node; + +use super::ExecContext; +use super::PUBLIC_TOOL_NAME; +use super::process::CodeModeProcess; +use super::process::spawn_code_mode_process; +use super::worker::CodeModeWorker; + +pub(crate) struct CodeModeService { + js_repl_node_path: Option, + stored_values: Mutex>, + process: Arc>>, + next_session_id: Mutex, +} + +impl CodeModeService { + pub(crate) fn new(js_repl_node_path: Option) -> Self { + Self { + js_repl_node_path, + stored_values: Mutex::new(HashMap::new()), + process: Arc::new(Mutex::new(None)), + next_session_id: Mutex::new(1), + } + } + + pub(crate) async fn stored_values(&self) -> HashMap { + self.stored_values.lock().await.clone() + } + + pub(crate) async fn replace_stored_values(&self, values: HashMap) { + *self.stored_values.lock().await = values; + } + + pub(super) async fn ensure_started( + &self, + ) -> Result>, std::io::Error> { + let mut process_slot = self.process.lock().await; + let needs_spawn = match process_slot.as_mut() { + Some(process) => !matches!(process.has_exited(), Ok(false)), + None => true, + }; + if needs_spawn { + let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()) + .await + .map_err(std::io::Error::other)?; + *process_slot = Some(spawn_code_mode_process(&node_path).await?); + } + drop(process_slot); + Ok(self.process.clone().lock_owned().await) + } + + pub(crate) async fn start_turn_worker( + &self, + session: &Arc, + turn: &Arc, + tracker: &SharedTurnDiffTracker, + ) -> Option { + if !turn.features.enabled(Feature::CodeMode) { + return None; + } + let exec = ExecContext { + session: Arc::clone(session), + turn: Arc::clone(turn), + tracker: Arc::clone(tracker), + }; + let mut process_slot = match self.ensure_started().await { + Ok(process_slot) => process_slot, + Err(err) => { + warn!("failed to start {PUBLIC_TOOL_NAME} worker for turn: {err}"); + return None; + } + }; + let Some(process) = process_slot.as_mut() else { + warn!( + "failed to start {PUBLIC_TOOL_NAME} worker for turn: {PUBLIC_TOOL_NAME} runner failed to start" + ); + return None; + }; + Some(process.worker(exec)) + } + + pub(crate) async fn allocate_session_id(&self) -> i32 { + let mut next_session_id = self.next_session_id.lock().await; + let session_id = *next_session_id; + *next_session_id = next_session_id.saturating_add(1); + session_id + } + + pub(crate) async fn allocate_request_id(&self) -> String { + uuid::Uuid::new_v4().to_string() + } +} diff --git a/codex-rs/core/src/tools/code_mode/wait_handler.rs b/codex-rs/core/src/tools/code_mode/wait_handler.rs new file mode 100644 index 000000000..ddfce8eb3 --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/wait_handler.rs @@ -0,0 +1,137 @@ +use async_trait::async_trait; +use serde::Deserialize; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +use super::CodeModeSessionProgress; +use super::DEFAULT_WAIT_YIELD_TIME_MS; +use super::ExecContext; +use super::PUBLIC_TOOL_NAME; +use super::WAIT_TOOL_NAME; +use super::handle_node_message; +use super::protocol::HostToNodeMessage; + +pub struct CodeModeWaitHandler; + +#[derive(Debug, Deserialize)] +struct ExecWaitArgs { + session_id: i32, + #[serde(default = "default_wait_yield_time_ms")] + yield_time_ms: u64, + #[serde(default)] + max_tokens: Option, + #[serde(default)] + terminate: bool, +} + +fn default_wait_yield_time_ms() -> u64 { + DEFAULT_WAIT_YIELD_TIME_MS +} + +fn parse_arguments(arguments: &str) -> Result +where + T: for<'de> Deserialize<'de>, +{ + serde_json::from_str(arguments).map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}")) + }) +} + +#[async_trait] +impl ToolHandler for CodeModeWaitHandler { + type Output = FunctionToolOutput; + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + tool_name, + payload, + .. + } = invocation; + + match payload { + ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => { + let args: ExecWaitArgs = parse_arguments(&arguments)?; + let exec = ExecContext { + session, + turn, + tracker, + }; + let request_id = exec + .session + .services + .code_mode_service + .allocate_request_id() + .await; + let started_at = std::time::Instant::now(); + let message = if args.terminate { + HostToNodeMessage::Terminate { + request_id: request_id.clone(), + session_id: args.session_id, + } + } else { + HostToNodeMessage::Poll { + request_id: request_id.clone(), + session_id: args.session_id, + yield_time_ms: args.yield_time_ms, + } + }; + let process_slot = exec + .session + .services + .code_mode_service + .ensure_started() + .await + .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; + let result = { + let mut process_slot = process_slot; + let Some(process) = process_slot.as_mut() else { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); + }; + if !matches!(process.has_exited(), Ok(false)) { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); + } + let message = process + .send(&request_id, &message) + .await + .map_err(|err| err.to_string()); + let message = match message { + Ok(message) => message, + Err(error) => return Err(FunctionCallError::RespondToModel(error)), + }; + handle_node_message( + &exec, + args.session_id, + message, + Some(args.max_tokens), + started_at, + ) + .await + }; + match result { + Ok(CodeModeSessionProgress::Finished(output)) + | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), + Err(error) => Err(FunctionCallError::RespondToModel(error)), + } + } + _ => Err(FunctionCallError::RespondToModel(format!( + "{WAIT_TOOL_NAME} expects JSON arguments" + ))), + } + } +} diff --git a/codex-rs/core/src/tools/code_mode/worker.rs b/codex-rs/core/src/tools/code_mode/worker.rs new file mode 100644 index 000000000..ce739d637 --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/worker.rs @@ -0,0 +1,59 @@ +use tokio::sync::oneshot; +use tracing::warn; + +use super::ExecContext; +use super::PUBLIC_TOOL_NAME; +use super::call_nested_tool; +use super::process::CodeModeProcess; +use super::process::write_message; +use super::protocol::HostToNodeMessage; +pub(crate) struct CodeModeWorker { + shutdown_tx: Option>, +} + +impl Drop for CodeModeWorker { + fn drop(&mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + } +} + +impl CodeModeProcess { + pub(super) fn worker(&self, exec: ExecContext) -> CodeModeWorker { + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let stdin = self.stdin.clone(); + let tool_call_rx = self.tool_call_rx.clone(); + tokio::spawn(async move { + loop { + let tool_call = tokio::select! { + _ = &mut shutdown_rx => break, + tool_call = async { + let mut tool_call_rx = tool_call_rx.lock().await; + tool_call_rx.recv().await + } => tool_call, + }; + let Some(tool_call) = tool_call else { + break; + }; + let exec = exec.clone(); + let stdin = stdin.clone(); + tokio::spawn(async move { + let response = HostToNodeMessage::Response { + request_id: tool_call.request_id, + id: tool_call.id, + code_mode_result: call_nested_tool(exec, tool_call.name, tool_call.input) + .await, + }; + if let Err(err) = write_message(&stdin, &response).await { + warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}"); + } + }); + } + }); + + CodeModeWorker { + shutdown_tx: Some(shutdown_tx), + } + } +} diff --git a/codex-rs/core/src/tools/handlers/code_mode.rs b/codex-rs/core/src/tools/handlers/code_mode.rs deleted file mode 100644 index fe4a23965..000000000 --- a/codex-rs/core/src/tools/handlers/code_mode.rs +++ /dev/null @@ -1,104 +0,0 @@ -use async_trait::async_trait; -use serde::Deserialize; - -use crate::function_tool::FunctionCallError; -use crate::tools::code_mode; -use crate::tools::code_mode::DEFAULT_WAIT_YIELD_TIME_MS; -use crate::tools::code_mode::PUBLIC_TOOL_NAME; -use crate::tools::code_mode::WAIT_TOOL_NAME; -use crate::tools::context::FunctionToolOutput; -use crate::tools::context::ToolInvocation; -use crate::tools::context::ToolPayload; -use crate::tools::handlers::parse_arguments; -use crate::tools::registry::ToolHandler; -use crate::tools::registry::ToolKind; - -pub struct CodeModeHandler; -pub struct CodeModeWaitHandler; - -#[derive(Debug, Deserialize)] -struct ExecWaitArgs { - session_id: i32, - #[serde(default = "default_wait_yield_time_ms")] - yield_time_ms: u64, - #[serde(default)] - max_tokens: Option, - #[serde(default)] - terminate: bool, -} - -fn default_wait_yield_time_ms() -> u64 { - DEFAULT_WAIT_YIELD_TIME_MS -} - -#[async_trait] -impl ToolHandler for CodeModeHandler { - type Output = FunctionToolOutput; - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Custom { .. }) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - tool_name, - payload, - .. - } = invocation; - - match payload { - ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => { - code_mode::execute(session, turn, tracker, input).await - } - _ => Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" - ))), - } - } -} - -#[async_trait] -impl ToolHandler for CodeModeWaitHandler { - type Output = FunctionToolOutput; - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - tool_name, - payload, - .. - } = invocation; - - match payload { - ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => { - let args: ExecWaitArgs = parse_arguments(&arguments)?; - code_mode::wait( - session, - turn, - tracker, - args.session_id, - args.yield_time_ms, - args.max_tokens, - args.terminate, - ) - .await - } - _ => Err(FunctionCallError::RespondToModel(format!( - "{WAIT_TOOL_NAME} expects JSON arguments" - ))), - } - } -} diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index 068031b5a..5d8aaeba6 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -1,7 +1,6 @@ pub(crate) mod agent_jobs; pub mod apply_patch; mod artifacts; -mod code_mode; mod dynamic; mod grep_files; mod js_repl; @@ -32,10 +31,10 @@ use crate::function_tool::FunctionCallError; use crate::sandboxing::SandboxPermissions; use crate::sandboxing::merge_permission_profiles; use crate::sandboxing::normalize_additional_permissions; +pub(crate) use crate::tools::code_mode::CodeModeExecuteHandler; +pub(crate) use crate::tools::code_mode::CodeModeWaitHandler; pub use apply_patch::ApplyPatchHandler; pub use artifacts::ArtifactsHandler; -pub use code_mode::CodeModeHandler; -pub use code_mode::CodeModeWaitHandler; use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::AskForApproval; pub use dynamic::DynamicToolHandler; diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index ccba578e4..ab41a3b36 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -2295,7 +2295,7 @@ pub(crate) fn build_specs_with_discoverable_tools( ) -> ToolRegistryBuilder { use crate::tools::handlers::ApplyPatchHandler; use crate::tools::handlers::ArtifactsHandler; - use crate::tools::handlers::CodeModeHandler; + use crate::tools::handlers::CodeModeExecuteHandler; use crate::tools::handlers::CodeModeWaitHandler; use crate::tools::handlers::DynamicToolHandler; use crate::tools::handlers::GrepFilesHandler; @@ -2334,7 +2334,7 @@ pub(crate) fn build_specs_with_discoverable_tools( default_mode_request_user_input: config.default_mode_request_user_input, }); let tool_suggest_handler = Arc::new(ToolSuggestHandler); - let code_mode_handler = Arc::new(CodeModeHandler); + let code_mode_handler = Arc::new(CodeModeExecuteHandler); let code_mode_wait_handler = Arc::new(CodeModeWaitHandler); let js_repl_handler = Arc::new(JsReplHandler); let js_repl_reset_handler = Arc::new(JsReplResetHandler);