From f6c6128fc705205b9f1f2bff50cc2710046ce8de Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 11 Mar 2026 23:13:54 -0700 Subject: [PATCH] Support waiting for code_mode sessions (#14295) ## Summary - persist the code mode runner process in the session-scoped code mode store - switch the runner protocol from `init` to `start` with explicit session ids - handle runner-side session processing without the init waiter queue ## Validation - just fmt - cargo check -p codex-core - node --check codex-rs/core/src/tools/code_mode_runner.cjs --- codex-rs/core/src/codex.rs | 4 +- codex-rs/core/src/codex_tests.rs | 8 +- codex-rs/core/src/state/service.rs | 26 +- codex-rs/core/src/tools/code_mode.rs | 569 +++++++--- codex-rs/core/src/tools/code_mode_runner.cjs | 978 ++++++++++++------ codex-rs/core/src/tools/handlers/code_mode.rs | 83 +- codex-rs/core/src/tools/handlers/mod.rs | 1 + codex-rs/core/src/tools/spec.rs | 66 +- codex-rs/core/tests/suite/code_mode.rs | 821 +++++++++++++++ 9 files changed, 2059 insertions(+), 497 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 5fa4cffa3..b5695a0bf 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1737,7 +1737,9 @@ impl Session { config.features.enabled(Feature::RuntimeMetrics), Self::build_model_client_beta_features_header(config.as_ref()), ), - code_mode_store: Default::default(), + code_mode_service: crate::tools::code_mode::CodeModeService::new( + config.js_repl_node_path.clone(), + ), }; let js_repl = Arc::new(JsReplHandle::with_node_path( config.js_repl_node_path.clone(), diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index f1892449b..7e838642a 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -2165,7 +2165,9 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { config.features.enabled(Feature::RuntimeMetrics), 
Session::build_model_client_beta_features_header(config.as_ref()), ), - code_mode_store: Default::default(), + code_mode_service: crate::tools::code_mode::CodeModeService::new( + config.js_repl_node_path.clone(), + ), }; let js_repl = Arc::new(JsReplHandle::with_node_path( config.js_repl_node_path.clone(), @@ -2802,7 +2804,9 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( config.features.enabled(Feature::RuntimeMetrics), Session::build_model_client_beta_features_header(config.as_ref()), ), - code_mode_store: Default::default(), + code_mode_service: crate::tools::code_mode::CodeModeService::new( + config.js_repl_node_path.clone(), + ), }; let js_repl = Arc::new(JsReplHandle::with_node_path( config.js_repl_node_path.clone(), diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 5c0a741a1..851618c00 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -15,6 +15,7 @@ use crate::models_manager::manager::ModelsManager; use crate::plugins::PluginsManager; use crate::skills::SkillsManager; use crate::state_db::StateDbHandle; +use crate::tools::code_mode::CodeModeService; use crate::tools::network_approval::NetworkApprovalService; use crate::tools::runtimes::ExecveSessionApproval; use crate::tools::sandboxing::ApprovalStore; @@ -22,35 +23,12 @@ use crate::unified_exec::UnifiedExecProcessManager; use codex_hooks::Hooks; use codex_otel::SessionTelemetry; use codex_utils_absolute_path::AbsolutePathBuf; -use serde_json::Value as JsonValue; use std::path::PathBuf; use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio::sync::watch; use tokio_util::sync::CancellationToken; -pub(crate) struct CodeModeStoreService { - stored_values: Mutex>, -} - -impl Default for CodeModeStoreService { - fn default() -> Self { - Self { - stored_values: Mutex::new(HashMap::new()), - } - } -} - -impl CodeModeStoreService { - pub(crate) async fn stored_values(&self) -> HashMap { - 
self.stored_values.lock().await.clone() - } - - pub(crate) async fn replace_stored_values(&self, values: HashMap) { - *self.stored_values.lock().await = values; - } -} - pub(crate) struct SessionServices { pub(crate) mcp_connection_manager: Arc>, pub(crate) mcp_startup_cancellation_token: Mutex, @@ -82,5 +60,5 @@ pub(crate) struct SessionServices { pub(crate) state_db: Option, /// Session-scoped model client shared across turns. pub(crate) model_client: ModelClient, - pub(crate) code_mode_store: CodeModeStoreService, + pub(crate) code_mode_service: CodeModeService, } diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs index 110588469..a6e6227be 100644 --- a/codex-rs/core/src/tools/code_mode.rs +++ b/codex-rs/core/src/tools/code_mode.rs @@ -1,4 +1,6 @@ use std::collections::HashMap; +use std::collections::VecDeque; +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -6,7 +8,6 @@ use crate::client_common::tools::ToolSpec; use crate::codex::Session; use crate::codex::TurnContext; use crate::config::Config; -use crate::exec_env::create_env; use crate::features::Feature; use crate::function_tool::FunctionCallError; use crate::tools::ToolRouter; @@ -31,10 +32,15 @@ use tokio::io::AsyncBufReadExt; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use tracing::warn; const CODE_MODE_RUNNER_SOURCE: &str = include_str!("code_mode_runner.cjs"); const CODE_MODE_BRIDGE_SOURCE: &str = include_str!("code_mode_bridge.js"); pub(crate) const PUBLIC_TOOL_NAME: &str = "exec"; +pub(crate) const WAIT_TOOL_NAME: &str = "exec_wait"; +pub(crate) const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000; #[derive(Clone)] struct ExecContext { @@ -43,6 +49,133 @@ struct ExecContext { tracker: SharedTurnDiffTracker, } +pub(crate) struct CodeModeProcess { + child: tokio::process::Child, + stdin: tokio::process::ChildStdin, + stdout_lines: tokio::io::Lines>, + 
stderr_task: Option>, + pending_messages: HashMap>, +} + +impl CodeModeProcess { + async fn write(&mut self, message: &HostToNodeMessage) -> Result<(), std::io::Error> { + let line = serde_json::to_string(message).map_err(std::io::Error::other)?; + self.stdin.write_all(line.as_bytes()).await?; + self.stdin.write_all(b"\n").await?; + self.stdin.flush().await?; + Ok(()) + } + + async fn read(&mut self, session_id: i32) -> Result { + if let Some(message) = self + .pending_messages + .get_mut(&session_id) + .and_then(VecDeque::pop_front) + { + return Ok(message); + } + + loop { + let Some(line) = self.stdout_lines.next_line().await? else { + match self.wait_for_exit().await { + Ok(status) => { + self.join_stderr_task().await; + return Err(std::io::Error::other(format!( + "{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})" + ))); + } + Err(err) => return Err(std::io::Error::other(err)), + } + }; + if line.trim().is_empty() { + continue; + } + let message: NodeToHostMessage = + serde_json::from_str(&line).map_err(std::io::Error::other)?; + let message_session_id = message_session_id(&message); + if message_session_id == session_id { + return Ok(message); + } + self.pending_messages + .entry(message_session_id) + .or_default() + .push_back(message); + } + } + + fn has_exited(&mut self) -> Result { + self.child + .try_wait() + .map(|status| status.is_some()) + .map_err(|err| format!("failed to inspect {PUBLIC_TOOL_NAME} runner: {err}")) + } + + async fn wait_for_exit(&mut self) -> Result { + self.child + .wait() + .await + .map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}")) + } + + async fn join_stderr_task(&mut self) { + let Some(stderr_task) = self.stderr_task.take() else { + return; + }; + if let Err(err) = stderr_task.await { + warn!("failed to join {PUBLIC_TOOL_NAME} stderr task: {err}"); + } + } +} + +pub(crate) struct CodeModeService { + js_repl_node_path: Option, + stored_values: Mutex>, + process: Arc>>, + 
next_session_id: Mutex, +} + +impl CodeModeService { + pub(crate) fn new(js_repl_node_path: Option) -> Self { + Self { + js_repl_node_path, + stored_values: Mutex::new(HashMap::new()), + process: Arc::new(Mutex::new(None)), + next_session_id: Mutex::new(1), + } + } + + pub(crate) async fn stored_values(&self) -> HashMap { + self.stored_values.lock().await.clone() + } + + pub(crate) async fn replace_stored_values(&self, values: HashMap) { + *self.stored_values.lock().await = values; + } + + async fn ensure_started( + &self, + ) -> Result>, String> { + let mut process_slot = self.process.lock().await; + let needs_spawn = match process_slot.as_mut() { + Some(process) => !matches!(process.has_exited(), Ok(false)), + None => true, + }; + if needs_spawn { + let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()).await?; + *process_slot = Some(spawn_code_mode_process(&node_path).await?); + } + drop(process_slot); + Ok(self.process.clone().lock_owned().await) + } + + pub(crate) async fn allocate_session_id(&self) -> i32 { + let mut next_session_id = self.next_session_id.lock().await; + let session_id = *next_session_id; + *next_session_id = next_session_id.saturating_add(1); + session_id + } +} + #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(rename_all = "snake_case")] enum CodeModeToolKind { @@ -64,12 +197,21 @@ struct EnabledTool { #[derive(Serialize)] #[serde(tag = "type", rename_all = "snake_case")] enum HostToNodeMessage { - Init { + Start { + session_id: i32, enabled_tools: Vec, stored_values: HashMap, source: String, }, + Poll { + session_id: i32, + yield_time_ms: u64, + }, + Terminate { + session_id: i32, + }, Response { + session_id: i32, id: String, code_mode_result: JsonValue, }, @@ -79,12 +221,22 @@ enum HostToNodeMessage { #[serde(tag = "type", rename_all = "snake_case")] enum NodeToHostMessage { ToolCall { + session_id: i32, id: String, name: String, #[serde(default)] input: Option, }, + Yielded { + 
session_id: i32, + content_items: Vec, + }, + Terminated { + session_id: i32, + content_items: Vec, + }, Result { + session_id: i32, content_items: Vec, stored_values: HashMap, #[serde(default)] @@ -94,6 +246,18 @@ enum NodeToHostMessage { }, } +enum CodeModeSessionProgress { + Finished(FunctionToolOutput), + Yielded { output: FunctionToolOutput }, +} + +enum CodeModeExecutionStatus { + Completed, + Failed, + Running(i32), + Terminated, +} + pub(crate) fn instructions(config: &Config) -> Option { if !config.features.enabled(Feature::CodeMode) { return None; @@ -114,7 +278,10 @@ pub(crate) fn instructions(config: &Config) -> Option { )); section.push_str("- Import nested tools from `tools.js`, for example `import { exec_command } from \"tools.js\"` or `import { ALL_TOOLS } from \"tools.js\"` to inspect the available `{ module, name, description }` entries. Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import { append_notebook_logs_chart } from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values.\n"); section.push_str(&format!( - "- Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, store, load }}` from `@openai/code_mode` (or `\"openai/code_mode\"`). `output_text(value)` surfaces text back to the model and stringifies non-string objects with `JSON.stringify(...)` when possible. `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs. `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, and `load(key)` returns a cloned stored value or `undefined`. `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate the final Rust-side result of the current `{PUBLIC_TOOL_NAME}` execution; the default is `10000`. This guards the overall `{PUBLIC_TOOL_NAME}` output, not individual nested tool invocations. 
The returned content starts with a separate `Script completed` or `Script failed` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker.\n", + "- Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, set_yield_time, store, load }}` from `@openai/code_mode` (or `\"openai/code_mode\"`). `output_text(value)` surfaces text back to the model and stringifies non-string objects with `JSON.stringify(...)` when possible. `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs. `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, and `load(key)` returns a cloned stored value or `undefined`. `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate direct `{PUBLIC_TOOL_NAME}` returns; `{WAIT_TOOL_NAME}` uses its own `max_tokens` argument instead and defaults to `10000`. `set_yield_time(value)` asks `{PUBLIC_TOOL_NAME}` to return early if the script is still running after that many milliseconds so `{WAIT_TOOL_NAME}` can resume it later. The returned content starts with a separate `Script completed`, `Script failed`, or `Script running with session ID …` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker.\n", + )); + section.push_str(&format!( + "- If `{PUBLIC_TOOL_NAME}` returns `Script running with session ID …`, call `{WAIT_TOOL_NAME}` with that `session_id` to keep waiting for more output, completion, or termination.\n", )); section.push_str( "- Function tools require JSON object arguments. 
Freeform tools require raw strings.\n", @@ -137,30 +304,103 @@ pub(crate) async fn execute( tracker, }; let enabled_tools = build_enabled_tools(&exec).await; - let stored_values = exec.session.services.code_mode_store.stored_values().await; + let service = &exec.session.services.code_mode_service; + let stored_values = service.stored_values().await; let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; - execute_node(exec, source, enabled_tools, stored_values) + let session_id = service.allocate_session_id().await; + let process_slot = service + .ensure_started() .await - .map_err(FunctionCallError::RespondToModel) + .map_err(FunctionCallError::RespondToModel)?; + let result = { + let mut process_slot = process_slot; + let Some(process) = process_slot.as_mut() else { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); + }; + drive_code_mode_session( + &exec, + process, + HostToNodeMessage::Start { + session_id, + enabled_tools, + stored_values, + source, + }, + None, + false, + ) + .await + }; + match result { + Ok(CodeModeSessionProgress::Finished(output)) + | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), + Err(error) => Err(FunctionCallError::RespondToModel(error)), + } } -async fn execute_node( - exec: ExecContext, - source: String, - enabled_tools: Vec, - stored_values: HashMap, -) -> Result { - let node_path = resolve_compatible_node(exec.turn.config.js_repl_node_path.as_deref()).await?; - let started_at = std::time::Instant::now(); +pub(crate) async fn wait( + session: Arc, + turn: Arc, + tracker: SharedTurnDiffTracker, + session_id: i32, + yield_time_ms: u64, + max_output_tokens: Option, + terminate: bool, +) -> Result { + let exec = ExecContext { + session, + turn, + tracker, + }; + let process_slot = exec + .session + .services + .code_mode_service + .ensure_started() + .await + .map_err(FunctionCallError::RespondToModel)?; + let result = { + 
let mut process_slot = process_slot; + let Some(process) = process_slot.as_mut() else { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); + }; + if !matches!(process.has_exited(), Ok(false)) { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); + } + drive_code_mode_session( + &exec, + process, + if terminate { + HostToNodeMessage::Terminate { session_id } + } else { + HostToNodeMessage::Poll { + session_id, + yield_time_ms, + } + }, + Some(max_output_tokens), + terminate, + ) + .await + }; + match result { + Ok(CodeModeSessionProgress::Finished(output)) + | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), + Err(error) => Err(FunctionCallError::RespondToModel(error)), + } +} - let env = create_env(&exec.turn.shell_environment_policy, None); - let mut cmd = tokio::process::Command::new(&node_path); +async fn spawn_code_mode_process(node_path: &std::path::Path) -> Result { + let mut cmd = tokio::process::Command::new(node_path); cmd.arg("--experimental-vm-modules"); cmd.arg("--eval"); cmd.arg(CODE_MODE_RUNNER_SOURCE); - cmd.current_dir(&exec.turn.cwd); - cmd.env_clear(); - cmd.envs(env); cmd.stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) @@ -177,7 +417,7 @@ async fn execute_node( .stderr .take() .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stderr"))?; - let mut stdin = child + let stdin = child .stdin .take() .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdin"))?; @@ -185,138 +425,189 @@ async fn execute_node( let stderr_task = tokio::spawn(async move { let mut reader = BufReader::new(stderr); let mut buf = Vec::new(); - let _ = reader.read_to_end(&mut buf).await; - String::from_utf8_lossy(&buf).trim().to_string() + match reader.read_to_end(&mut buf).await { + Ok(_) => { + let stderr = String::from_utf8_lossy(&buf).trim().to_string(); + if 
!stderr.is_empty() { + warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}"); + } + } + Err(err) => { + warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}"); + } + } }); - write_message( - &mut stdin, - &HostToNodeMessage::Init { - enabled_tools: enabled_tools.clone(), - stored_values, - source, - }, - ) - .await?; - - let mut stdout_lines = BufReader::new(stdout).lines(); - let mut pending_result = None; - while let Some(line) = stdout_lines - .next_line() - .await - .map_err(|err| format!("failed to read {PUBLIC_TOOL_NAME} runner stdout: {err}"))? - { - if line.trim().is_empty() { - continue; - } - let message: NodeToHostMessage = serde_json::from_str(&line).map_err(|err| { - format!("invalid {PUBLIC_TOOL_NAME} runner message: {err}; line={line}") - })?; - match message { - NodeToHostMessage::ToolCall { id, name, input } => { - let response = HostToNodeMessage::Response { - id, - code_mode_result: call_nested_tool(exec.clone(), name, input).await, - }; - write_message(&mut stdin, &response).await?; - } - NodeToHostMessage::Result { - content_items, - stored_values, - error_text, - max_output_tokens_per_exec_call, - } => { - exec.session - .services - .code_mode_store - .replace_stored_values(stored_values) - .await; - pending_result = Some(( - output_content_items_from_json_values(content_items)?, - error_text, - max_output_tokens_per_exec_call, - )); - break; - } - } - } - - drop(stdin); - - let status = child - .wait() - .await - .map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}"))?; - let stderr = stderr_task - .await - .map_err(|err| format!("failed to collect {PUBLIC_TOOL_NAME} stderr: {err}"))?; - let wall_time = started_at.elapsed(); - let success = status.success(); - - let Some((mut content_items, error_text, max_output_tokens_per_exec_call)) = pending_result - else { - let message = if stderr.is_empty() { - format!("{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})") - } else { - stderr - }; - 
return Err(message); - }; - - if !success { - let error_text = error_text.unwrap_or_else(|| { - if stderr.is_empty() { - format!("Process exited with status {status}") - } else { - stderr - } - }); - content_items.push(FunctionCallOutputContentItem::InputText { - text: format!("Script error:\n{error_text}"), - }); - } - - let mut content_items = - truncate_code_mode_result(content_items, max_output_tokens_per_exec_call); - prepend_script_status(&mut content_items, success, wall_time); - Ok(FunctionToolOutput::from_content( - content_items, - Some(success), - )) + Ok(CodeModeProcess { + child, + stdin, + stdout_lines: BufReader::new(stdout).lines(), + stderr_task: Some(stderr_task), + pending_messages: HashMap::new(), + }) } -async fn write_message( - stdin: &mut tokio::process::ChildStdin, - message: &HostToNodeMessage, -) -> Result<(), String> { - let line = serde_json::to_string(message) - .map_err(|err| format!("failed to serialize {PUBLIC_TOOL_NAME} message: {err}"))?; - stdin - .write_all(line.as_bytes()) +async fn drive_code_mode_session( + exec: &ExecContext, + process: &mut CodeModeProcess, + message: HostToNodeMessage, + poll_max_output_tokens: Option>, + is_terminate: bool, +) -> Result { + let started_at = std::time::Instant::now(); + let session_id = match &message { + HostToNodeMessage::Start { session_id, .. } + | HostToNodeMessage::Poll { session_id, .. } + | HostToNodeMessage::Terminate { session_id } + | HostToNodeMessage::Response { session_id, .. 
} => *session_id, + }; + process + .write(&message) .await - .map_err(|err| format!("failed to write {PUBLIC_TOOL_NAME} message: {err}"))?; - stdin - .write_all(b"\n") - .await - .map_err(|err| format!("failed to write {PUBLIC_TOOL_NAME} message newline: {err}"))?; - stdin - .flush() - .await - .map_err(|err| format!("failed to flush {PUBLIC_TOOL_NAME} message: {err}")) + .map_err(|err| err.to_string())?; + + loop { + let message = process + .read(session_id) + .await + .map_err(|err| err.to_string())?; + if let Some(progress) = handle_node_message( + exec, + process, + session_id, + message, + poll_max_output_tokens, + started_at, + is_terminate, + ) + .await? + { + return Ok(progress); + } + } +} + +async fn handle_node_message( + exec: &ExecContext, + process: &mut CodeModeProcess, + session_id: i32, + message: NodeToHostMessage, + poll_max_output_tokens: Option>, + started_at: std::time::Instant, + is_terminate: bool, +) -> Result, String> { + match message { + NodeToHostMessage::ToolCall { + session_id: message_session_id, + id, + name, + input, + } => { + if is_terminate { + return Ok(None); + } + let response = HostToNodeMessage::Response { + session_id: message_session_id, + id, + code_mode_result: call_nested_tool(exec.clone(), name, input).await, + }; + process + .write(&response) + .await + .map_err(|err| err.to_string())?; + Ok(None) + } + NodeToHostMessage::Yielded { content_items, .. } => { + if is_terminate { + return Ok(None); + } + let mut delta_items = output_content_items_from_json_values(content_items)?; + delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); + prepend_script_status( + &mut delta_items, + CodeModeExecutionStatus::Running(session_id), + started_at.elapsed(), + ); + Ok(Some(CodeModeSessionProgress::Yielded { + output: FunctionToolOutput::from_content(delta_items, Some(true)), + })) + } + NodeToHostMessage::Terminated { content_items, .. 
} => { + let mut delta_items = output_content_items_from_json_values(content_items)?; + delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); + prepend_script_status( + &mut delta_items, + CodeModeExecutionStatus::Terminated, + started_at.elapsed(), + ); + Ok(Some(CodeModeSessionProgress::Finished( + FunctionToolOutput::from_content(delta_items, Some(true)), + ))) + } + NodeToHostMessage::Result { + content_items, + stored_values, + error_text, + max_output_tokens_per_exec_call, + .. + } => { + exec.session + .services + .code_mode_service + .replace_stored_values(stored_values) + .await; + let mut delta_items = output_content_items_from_json_values(content_items)?; + let success = error_text.is_none(); + if let Some(error_text) = error_text { + delta_items.push(FunctionCallOutputContentItem::InputText { + text: format!("Script error:\n{error_text}"), + }); + } + + let mut delta_items = truncate_code_mode_result( + delta_items, + poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), + ); + prepend_script_status( + &mut delta_items, + if success { + CodeModeExecutionStatus::Completed + } else { + CodeModeExecutionStatus::Failed + }, + started_at.elapsed(), + ); + Ok(Some(CodeModeSessionProgress::Finished( + FunctionToolOutput::from_content(delta_items, Some(success)), + ))) + } + } +} + +fn message_session_id(message: &NodeToHostMessage) -> i32 { + match message { + NodeToHostMessage::ToolCall { session_id, .. } + | NodeToHostMessage::Yielded { session_id, .. } + | NodeToHostMessage::Terminated { session_id, .. } + | NodeToHostMessage::Result { session_id, .. 
} => *session_id, + } } fn prepend_script_status( content_items: &mut Vec, - success: bool, + status: CodeModeExecutionStatus, wall_time: Duration, ) { let wall_time_seconds = ((wall_time.as_secs_f32()) * 10.0).round() / 10.0; let header = format!( "{}\nWall time {wall_time_seconds:.1} seconds\nOutput:\n", - if success { - "Script completed" - } else { - "Script failed" + match status { + CodeModeExecutionStatus::Completed => "Script completed".to_string(), + CodeModeExecutionStatus::Failed => "Script failed".to_string(), + CodeModeExecutionStatus::Running(session_id) => { + format!("Script running with session ID {session_id}") + } + CodeModeExecutionStatus::Terminated => "Script terminated".to_string(), } ); content_items.insert(0, FunctionCallOutputContentItem::InputText { text: header }); @@ -366,7 +657,7 @@ async fn build_enabled_tools(exec: &ExecContext) -> Vec { fn enabled_tool_from_spec(spec: ToolSpec) -> Option { let tool_name = spec.name().to_string(); - if tool_name == PUBLIC_TOOL_NAME { + if tool_name == PUBLIC_TOOL_NAME || tool_name == WAIT_TOOL_NAME { return None; } diff --git a/codex-rs/core/src/tools/code_mode_runner.cjs b/codex-rs/core/src/tools/code_mode_runner.cjs index f36fa6f92..d64e369f3 100644 --- a/codex-rs/core/src/tools/code_mode_runner.cjs +++ b/codex-rs/core/src/tools/code_mode_runner.cjs @@ -1,9 +1,8 @@ 'use strict'; const readline = require('node:readline'); -const vm = require('node:vm'); +const { Worker } = require('node:worker_threads'); -const { SourceTextModule, SyntheticModule } = vm; const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL = 10000; function normalizeMaxOutputTokensPerExecCall(value) { @@ -13,6 +12,425 @@ function normalizeMaxOutputTokensPerExecCall(value) { return value; } +function normalizeYieldTime(value) { + if (!Number.isSafeInteger(value) || value < 0) { + throw new TypeError('yield_time must be a non-negative safe integer'); + } + return value; +} + +function formatErrorText(error) { + return String(error && 
error.stack ? error.stack : error); +} + +function cloneJsonValue(value) { + return JSON.parse(JSON.stringify(value)); +} + +function clearTimer(timer) { + if (timer !== null) { + clearTimeout(timer); + } + return null; +} + +function takeContentItems(session) { + const clonedContentItems = cloneJsonValue(session.content_items); + session.content_items.splice(0, session.content_items.length); + return Array.isArray(clonedContentItems) ? clonedContentItems : []; +} + +function codeModeWorkerMain() { + 'use strict'; + + const { parentPort, workerData } = require('node:worker_threads'); + const vm = require('node:vm'); + const { SourceTextModule, SyntheticModule } = vm; + + const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL = 10000; + + function normalizeMaxOutputTokensPerExecCall(value) { + if (!Number.isSafeInteger(value) || value < 0) { + throw new TypeError('max_output_tokens_per_exec_call must be a non-negative safe integer'); + } + return value; + } + + function normalizeYieldTime(value) { + if (!Number.isSafeInteger(value) || value < 0) { + throw new TypeError('yield_time must be a non-negative safe integer'); + } + return value; + } + + function formatErrorText(error) { + return String(error && error.stack ? error.stack : error); + } + + function cloneJsonValue(value) { + return JSON.parse(JSON.stringify(value)); + } + + function createToolCaller() { + let nextId = 0; + const pending = new Map(); + + parentPort.on('message', (message) => { + if (message.type === 'tool_response') { + const entry = pending.get(message.id); + if (!entry) { + return; + } + pending.delete(message.id); + entry.resolve(message.result ?? ''); + return; + } + + if (message.type === 'tool_response_error') { + const entry = pending.get(message.id); + if (!entry) { + return; + } + pending.delete(message.id); + entry.reject(new Error(message.error_text ?? 
'tool call failed')); + return; + } + }); + + return (name, input) => { + const id = 'msg-' + ++nextId; + return new Promise((resolve, reject) => { + pending.set(id, { resolve, reject }); + parentPort.postMessage({ + type: 'tool_call', + id, + name: String(name), + input, + }); + }); + }; + } + + function createContentItems() { + const contentItems = []; + const push = contentItems.push.bind(contentItems); + contentItems.push = (...items) => { + for (const item of items) { + parentPort.postMessage({ + type: 'content_item', + item: cloneJsonValue(item), + }); + } + return push(...items); + }; + parentPort.on('message', (message) => { + if (message.type === 'clear_content') { + contentItems.splice(0, contentItems.length); + } + }); + return contentItems; + } + + function createToolsNamespace(callTool, enabledTools) { + const tools = Object.create(null); + + for (const { tool_name } of enabledTools) { + Object.defineProperty(tools, tool_name, { + value: async (args) => callTool(tool_name, args), + configurable: false, + enumerable: true, + writable: false, + }); + } + + return Object.freeze(tools); + } + + function createAllToolsMetadata(enabledTools) { + return Object.freeze( + enabledTools.map(({ module: modulePath, name, description }) => + Object.freeze({ + module: modulePath, + name, + description, + }) + ) + ); + } + + function createToolsModule(context, callTool, enabledTools) { + const tools = createToolsNamespace(callTool, enabledTools); + const allTools = createAllToolsMetadata(enabledTools); + const exportNames = ['ALL_TOOLS']; + + for (const { tool_name } of enabledTools) { + if (tool_name !== 'ALL_TOOLS') { + exportNames.push(tool_name); + } + } + + const uniqueExportNames = [...new Set(exportNames)]; + + return new SyntheticModule( + uniqueExportNames, + function initToolsModule() { + this.setExport('ALL_TOOLS', allTools); + for (const exportName of uniqueExportNames) { + if (exportName !== 'ALL_TOOLS') { + this.setExport(exportName, tools[exportName]); 
+ } + } + }, + { context } + ); + } + + function ensureContentItems(context) { + if (!Array.isArray(context.__codexContentItems)) { + context.__codexContentItems = []; + } + return context.__codexContentItems; + } + + function serializeOutputText(value) { + if (typeof value === 'string') { + return value; + } + if ( + typeof value === 'undefined' || + value === null || + typeof value === 'boolean' || + typeof value === 'number' || + typeof value === 'bigint' + ) { + return String(value); + } + + const serialized = JSON.stringify(value); + if (typeof serialized === 'string') { + return serialized; + } + + return String(value); + } + + function normalizeOutputImageUrl(value) { + if (typeof value !== 'string' || !value) { + throw new TypeError('output_image expects a non-empty image URL string'); + } + if (/^(?:https?:\/\/|data:)/i.test(value)) { + return value; + } + throw new TypeError('output_image expects an http(s) or data URL'); + } + + function createCodeModeModule(context, state) { + const load = (key) => { + if (typeof key !== 'string') { + throw new TypeError('load key must be a string'); + } + if (!Object.prototype.hasOwnProperty.call(state.storedValues, key)) { + return undefined; + } + return cloneJsonValue(state.storedValues[key]); + }; + const store = (key, value) => { + if (typeof key !== 'string') { + throw new TypeError('store key must be a string'); + } + state.storedValues[key] = cloneJsonValue(value); + }; + const outputText = (value) => { + const item = { + type: 'input_text', + text: serializeOutputText(value), + }; + ensureContentItems(context).push(item); + return item; + }; + const outputImage = (value) => { + const item = { + type: 'input_image', + image_url: normalizeOutputImageUrl(value), + }; + ensureContentItems(context).push(item); + return item; + }; + + return new SyntheticModule( + [ + 'load', + 'output_text', + 'output_image', + 'set_max_output_tokens_per_exec_call', + 'set_yield_time', + 'store', + ], + function 
initCodeModeModule() { + this.setExport('load', load); + this.setExport('output_text', outputText); + this.setExport('output_image', outputImage); + this.setExport('set_max_output_tokens_per_exec_call', (value) => { + const normalized = normalizeMaxOutputTokensPerExecCall(value); + state.maxOutputTokensPerExecCall = normalized; + parentPort.postMessage({ + type: 'set_max_output_tokens_per_exec_call', + value: normalized, + }); + return normalized; + }); + this.setExport('set_yield_time', (value) => { + const normalized = normalizeYieldTime(value); + parentPort.postMessage({ + type: 'set_yield_time', + value: normalized, + }); + return normalized; + }); + this.setExport('store', store); + }, + { context } + ); + } + + function namespacesMatch(left, right) { + if (left.length !== right.length) { + return false; + } + return left.every((segment, index) => segment === right[index]); + } + + function createNamespacedToolsNamespace(callTool, enabledTools, namespace) { + const tools = Object.create(null); + + for (const tool of enabledTools) { + const toolNamespace = Array.isArray(tool.namespace) ? 
tool.namespace : []; + if (!namespacesMatch(toolNamespace, namespace)) { + continue; + } + + Object.defineProperty(tools, tool.name, { + value: async (args) => callTool(tool.tool_name, args), + configurable: false, + enumerable: true, + writable: false, + }); + } + + return Object.freeze(tools); + } + + function createNamespacedToolsModule(context, callTool, enabledTools, namespace) { + const tools = createNamespacedToolsNamespace(callTool, enabledTools, namespace); + const exportNames = []; + + for (const exportName of Object.keys(tools)) { + if (exportName !== 'ALL_TOOLS') { + exportNames.push(exportName); + } + } + + const uniqueExportNames = [...new Set(exportNames)]; + + return new SyntheticModule( + uniqueExportNames, + function initNamespacedToolsModule() { + for (const exportName of uniqueExportNames) { + this.setExport(exportName, tools[exportName]); + } + }, + { context } + ); + } + + function createModuleResolver(context, callTool, enabledTools, state) { + const toolsModule = createToolsModule(context, callTool, enabledTools); + const codeModeModule = createCodeModeModule(context, state); + const namespacedModules = new Map(); + + return function resolveModule(specifier) { + if (specifier === 'tools.js') { + return toolsModule; + } + if (specifier === '@openai/code_mode' || specifier === 'openai/code_mode') { + return codeModeModule; + } + const namespacedMatch = /^tools\/(.+)\.js$/.exec(specifier); + if (!namespacedMatch) { + throw new Error('Unsupported import in exec: ' + specifier); + } + + const namespace = namespacedMatch[1] + .split('/') + .filter((segment) => segment.length > 0); + if (namespace.length === 0) { + throw new Error('Unsupported import in exec: ' + specifier); + } + + const cacheKey = namespace.join('/'); + if (!namespacedModules.has(cacheKey)) { + namespacedModules.set( + cacheKey, + createNamespacedToolsModule(context, callTool, enabledTools, namespace) + ); + } + return namespacedModules.get(cacheKey); + }; + } + + async function 
runModule(context, start, state, callTool) { + const resolveModule = createModuleResolver( + context, + callTool, + start.enabled_tools ?? [], + state + ); + const mainModule = new SourceTextModule(start.source, { + context, + identifier: 'exec_main.mjs', + importModuleDynamically: async (specifier) => resolveModule(specifier), + }); + + await mainModule.link(resolveModule); + await mainModule.evaluate(); + } + + async function main() { + const start = workerData ?? {}; + const state = { + maxOutputTokensPerExecCall: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + storedValues: cloneJsonValue(start.stored_values ?? {}), + }; + const callTool = createToolCaller(); + const context = vm.createContext({ + __codexContentItems: createContentItems(), + __codex_tool_call: callTool, + }); + + try { + await runModule(context, start, state, callTool); + parentPort.postMessage({ + type: 'result', + stored_values: state.storedValues, + }); + } catch (error) { + parentPort.postMessage({ + type: 'result', + stored_values: state.storedValues, + error_text: formatErrorText(error), + }); + } + } + + void main().catch((error) => { + parentPort.postMessage({ + type: 'result', + stored_values: {}, + error_text: formatErrorText(error), + }); + }); +} + function createProtocol() { const rl = readline.createInterface({ input: process.stdin, @@ -21,11 +439,10 @@ function createProtocol() { let nextId = 0; const pending = new Map(); - let initResolve; - let initReject; - const init = new Promise((resolve, reject) => { - initResolve = resolve; - initReject = reject; + const sessions = new Map(); + let closedResolve; + const closed = new Promise((resolve) => { + closedResolve = resolve; }); rl.on('line', (line) => { @@ -37,40 +454,80 @@ function createProtocol() { try { message = JSON.parse(line); } catch (error) { - initReject(error); + process.stderr.write(formatErrorText(error) + '\n'); return; } - if (message.type === 'init') { - initResolve(message); + if (message.type === 'start') { + 
startSession(protocol, sessions, message); + return; + } + + if (message.type === 'poll') { + const session = sessions.get(message.session_id); + if (session) { + schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0)); + } else { + void protocol.send({ + type: 'result', + session_id: message.session_id, + content_items: [], + stored_values: {}, + error_text: `exec session ${message.session_id} not found`, + max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + }); + } + return; + } + + if (message.type === 'terminate') { + const session = sessions.get(message.session_id); + if (session) { + void terminateSession(protocol, sessions, session); + } else { + void protocol.send({ + type: 'result', + session_id: message.session_id, + content_items: [], + stored_values: {}, + error_text: `exec session ${message.session_id} not found`, + max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + }); + } return; } if (message.type === 'response') { - const entry = pending.get(message.id); + const entry = pending.get(message.session_id + ':' + message.id); if (!entry) { return; } - pending.delete(message.id); + pending.delete(message.session_id + ':' + message.id); entry.resolve(message.code_mode_result ?? 
''); return; } - initReject(new Error(`Unknown protocol message type: ${message.type}`)); + process.stderr.write('Unknown protocol message type: ' + message.type + '\n'); }); rl.on('close', () => { const error = new Error('stdin closed'); - initReject(error); for (const entry of pending.values()) { entry.reject(error); } pending.clear(); + for (const session of sessions.values()) { + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + void session.worker.terminate().catch(() => {}); + } + sessions.clear(); + closedResolve(); }); function send(message) { return new Promise((resolve, reject) => { - process.stdout.write(`${JSON.stringify(message)}\n`, (error) => { + process.stdout.write(JSON.stringify(message) + '\n', (error) => { if (error) { reject(error); } else { @@ -80,328 +537,223 @@ function createProtocol() { }); } - function request(type, payload) { - const id = `msg-${++nextId}`; + function request(sessionId, type, payload) { + const id = 'msg-' + ++nextId; + const pendingKey = sessionId + ':' + id; return new Promise((resolve, reject) => { - pending.set(id, { resolve, reject }); - void send({ type, id, ...payload }).catch((error) => { - pending.delete(id); + pending.set(pendingKey, { resolve, reject }); + void send({ type, session_id: sessionId, id, ...payload }).catch((error) => { + pending.delete(pendingKey); reject(error); }); }); } - return { init, request, send }; + const protocol = { closed, request, send }; + return protocol; } -function readContentItems(context) { - try { - const serialized = vm.runInContext('JSON.stringify(globalThis.__codexContentItems ?? [])', context); - const contentItems = JSON.parse(serialized); - return Array.isArray(contentItems) ? 
contentItems : []; - } catch { - return []; - } +function sessionWorkerSource() { + return '(' + codeModeWorkerMain.toString() + ')();'; } -function formatErrorText(error) { - return String(error && error.stack ? error.stack : error); -} - -function cloneJsonValue(value) { - return JSON.parse(JSON.stringify(value)); -} - -function createToolCaller(protocol) { - return (name, input) => - protocol.request('tool_call', { - name: String(name), - input, - }); -} - -function createToolsNamespace(callTool, enabledTools) { - const tools = Object.create(null); - - for (const { tool_name } of enabledTools) { - Object.defineProperty(tools, tool_name, { - value: async (args) => callTool(tool_name, args), - configurable: false, - enumerable: true, - writable: false, - }); - } - - return Object.freeze(tools); -} - -function createAllToolsMetadata(enabledTools) { - return Object.freeze( - enabledTools.map(({ module: modulePath, name, description }) => - Object.freeze({ - module: modulePath, - name, - description, - }) - ) - ); -} - -function createToolsModule(context, callTool, enabledTools) { - const tools = createToolsNamespace(callTool, enabledTools); - const allTools = createAllToolsMetadata(enabledTools); - const exportNames = ['ALL_TOOLS']; - - for (const { tool_name } of enabledTools) { - if (tool_name !== 'ALL_TOOLS') { - exportNames.push(tool_name); - } - } - - const uniqueExportNames = [...new Set(exportNames)]; - - return new SyntheticModule( - uniqueExportNames, - function initToolsModule() { - this.setExport('ALL_TOOLS', allTools); - for (const exportName of uniqueExportNames) { - if (exportName !== 'ALL_TOOLS') { - this.setExport(exportName, tools[exportName]); - } - } - }, - { context } - ); -} - -function ensureContentItems(context) { - if (!Array.isArray(context.__codexContentItems)) { - context.__codexContentItems = []; - } - return context.__codexContentItems; -} - -function serializeOutputText(value) { - if (typeof value === 'string') { - return value; - } - 
if ( - typeof value === 'undefined' || - value === null || - typeof value === 'boolean' || - typeof value === 'number' || - typeof value === 'bigint' - ) { - return String(value); - } - - const serialized = JSON.stringify(value); - if (typeof serialized === 'string') { - return serialized; - } - - return String(value); -} - -function normalizeOutputImageUrl(value) { - if (typeof value !== 'string' || !value) { - throw new TypeError('output_image expects a non-empty image URL string'); - } - if (/^(?:https?:\/\/|data:)/i.test(value)) { - return value; - } - throw new TypeError('output_image expects an http(s) or data URL'); -} - -function createCodeModeModule(context, state) { - const load = (key) => { - if (typeof key !== 'string') { - throw new TypeError('load key must be a string'); - } - if (!Object.prototype.hasOwnProperty.call(state.storedValues, key)) { - return undefined; - } - return cloneJsonValue(state.storedValues[key]); - }; - const store = (key, value) => { - if (typeof key !== 'string') { - throw new TypeError('store key must be a string'); - } - state.storedValues[key] = cloneJsonValue(value); - }; - const outputText = (value) => { - const item = { - type: 'input_text', - text: serializeOutputText(value), - }; - ensureContentItems(context).push(item); - return item; - }; - const outputImage = (value) => { - const item = { - type: 'input_image', - image_url: normalizeOutputImageUrl(value), - }; - ensureContentItems(context).push(item); - return item; +function startSession(protocol, sessions, start) { + const session = { + completed: false, + content_items: [], + id: start.session_id, + initial_yield_timer: null, + initial_yield_triggered: false, + max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + poll_yield_timer: null, + worker: new Worker(sessionWorkerSource(), { + eval: true, + workerData: start, + }), }; + sessions.set(session.id, session); - return new SyntheticModule( - ['load', 'output_text', 'output_image', 
'set_max_output_tokens_per_exec_call', 'store'], - function initCodeModeModule() { - this.setExport('load', load); - this.setExport('output_text', outputText); - this.setExport('output_image', outputImage); - this.setExport('set_max_output_tokens_per_exec_call', (value) => { - const normalized = normalizeMaxOutputTokensPerExecCall(value); - state.maxOutputTokensPerExecCall = normalized; - return normalized; + session.worker.on('message', (message) => { + void handleWorkerMessage(protocol, sessions, session, message).catch((error) => { + void completeSession(protocol, sessions, session, { + type: 'result', + stored_values: {}, + error_text: formatErrorText(error), }); - this.setExport('store', store); - }, - { context } - ); -} - -function namespacesMatch(left, right) { - if (left.length !== right.length) { - return false; - } - return left.every((segment, index) => segment === right[index]); -} - -function createNamespacedToolsNamespace(callTool, enabledTools, namespace) { - const tools = Object.create(null); - - for (const tool of enabledTools) { - const toolNamespace = Array.isArray(tool.namespace) ? 
tool.namespace : []; - if (!namespacesMatch(toolNamespace, namespace)) { - continue; - } - - Object.defineProperty(tools, tool.name, { - value: async (args) => callTool(tool.tool_name, args), - configurable: false, - enumerable: true, - writable: false, }); - } - - return Object.freeze(tools); -} - -function createNamespacedToolsModule(context, callTool, enabledTools, namespace) { - const tools = createNamespacedToolsNamespace(callTool, enabledTools, namespace); - const exportNames = []; - - for (const exportName of Object.keys(tools)) { - if (exportName !== 'ALL_TOOLS') { - exportNames.push(exportName); - } - } - - const uniqueExportNames = [...new Set(exportNames)]; - - return new SyntheticModule( - uniqueExportNames, - function initNamespacedToolsModule() { - for (const exportName of uniqueExportNames) { - this.setExport(exportName, tools[exportName]); - } - }, - { context } - ); -} - -function createModuleResolver(context, callTool, enabledTools, state) { - const toolsModule = createToolsModule(context, callTool, enabledTools); - const codeModeModule = createCodeModeModule(context, state); - const namespacedModules = new Map(); - - return function resolveModule(specifier) { - if (specifier === 'tools.js') { - return toolsModule; - } - if (specifier === '@openai/code_mode' || specifier === 'openai/code_mode') { - return codeModeModule; - } - const namespacedMatch = /^tools\/(.+)\.js$/.exec(specifier); - if (!namespacedMatch) { - throw new Error(`Unsupported import in exec: ${specifier}`); - } - - const namespace = namespacedMatch[1] - .split('/') - .filter((segment) => segment.length > 0); - if (namespace.length === 0) { - throw new Error(`Unsupported import in exec: ${specifier}`); - } - - const cacheKey = namespace.join('/'); - if (!namespacedModules.has(cacheKey)) { - namespacedModules.set( - cacheKey, - createNamespacedToolsModule(context, callTool, enabledTools, namespace) - ); - } - return namespacedModules.get(cacheKey); - }; -} - -async function 
runModule(context, request, state, callTool) { - const resolveModule = createModuleResolver( - context, - callTool, - request.enabled_tools ?? [], - state - ); - const mainModule = new SourceTextModule(request.source, { - context, - identifier: 'exec_main.mjs', - importModuleDynamically: async (specifier) => resolveModule(specifier), }); + session.worker.on('error', (error) => { + void completeSession(protocol, sessions, session, { + type: 'result', + stored_values: {}, + error_text: formatErrorText(error), + }); + }); + session.worker.on('exit', (code) => { + if (code !== 0 && !session.completed) { + void completeSession(protocol, sessions, session, { + type: 'result', + stored_values: {}, + error_text: 'exec worker exited with code ' + code, + }); + } + }); +} - await mainModule.link(resolveModule); - await mainModule.evaluate(); +async function handleWorkerMessage(protocol, sessions, session, message) { + if (session.completed) { + return; + } + + if (message.type === 'content_item') { + session.content_items.push(cloneJsonValue(message.item)); + return; + } + + if (message.type === 'set_yield_time') { + scheduleInitialYield(protocol, session, normalizeYieldTime(message.value ?? 0)); + return; + } + + if (message.type === 'set_max_output_tokens_per_exec_call') { + session.max_output_tokens_per_exec_call = normalizeMaxOutputTokensPerExecCall(message.value); + return; + } + + if (message.type === 'tool_call') { + void forwardToolCall(protocol, session, message); + return; + } + + if (message.type === 'result') { + await completeSession(protocol, sessions, session, { + type: 'result', + stored_values: cloneJsonValue(message.stored_values ?? {}), + error_text: + typeof message.error_text === 'string' ? 
message.error_text : undefined, + }); + return; + } + + process.stderr.write('Unknown worker message type: ' + message.type + '\n'); +} + +async function forwardToolCall(protocol, session, message) { + try { + const result = await protocol.request(session.id, 'tool_call', { + name: String(message.name), + input: message.input, + }); + if (session.completed) { + return; + } + try { + session.worker.postMessage({ + type: 'tool_response', + id: message.id, + result, + }); + } catch {} + } catch (error) { + if (session.completed) { + return; + } + try { + session.worker.postMessage({ + type: 'tool_response_error', + id: message.id, + error_text: formatErrorText(error), + }); + } catch {} + } +} + +async function sendYielded(protocol, session) { + if (session.completed) { + return; + } + const contentItems = takeContentItems(session); + try { + session.worker.postMessage({ type: 'clear_content' }); + } catch {} + await protocol.send({ + type: 'yielded', + session_id: session.id, + content_items: contentItems, + }); +} + +function scheduleInitialYield(protocol, session, yieldTime) { + if (session.completed || session.initial_yield_triggered) { + return yieldTime; + } + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.initial_yield_timer = setTimeout(() => { + session.initial_yield_timer = null; + session.initial_yield_triggered = true; + void sendYielded(protocol, session); + }, yieldTime); + return yieldTime; +} + +function schedulePollYield(protocol, session, yieldTime) { + if (session.completed) { + return; + } + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + session.poll_yield_timer = setTimeout(() => { + session.poll_yield_timer = null; + void sendYielded(protocol, session); + }, yieldTime); +} + +async function completeSession(protocol, sessions, session, message) { + if (session.completed) { + return; + } + session.completed = true; + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + 
session.poll_yield_timer = clearTimer(session.poll_yield_timer); + sessions.delete(session.id); + const contentItems = takeContentItems(session); + try { + session.worker.postMessage({ type: 'clear_content' }); + } catch {} + await protocol.send({ + ...message, + session_id: session.id, + content_items: contentItems, + max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, + }); +} + +async function terminateSession(protocol, sessions, session) { + if (session.completed) { + return; + } + session.completed = true; + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + sessions.delete(session.id); + const contentItems = takeContentItems(session); + try { + await session.worker.terminate(); + } catch {} + await protocol.send({ + type: 'terminated', + session_id: session.id, + content_items: contentItems, + }); } async function main() { const protocol = createProtocol(); - const request = await protocol.init; - const state = { - maxOutputTokensPerExecCall: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, - storedValues: cloneJsonValue(request.stored_values ?? 
{}), - }; - const callTool = createToolCaller(protocol); - const context = vm.createContext({ - __codexContentItems: [], - __codex_tool_call: callTool, - }); - - try { - await runModule(context, request, state, callTool); - await protocol.send({ - type: 'result', - content_items: readContentItems(context), - stored_values: state.storedValues, - max_output_tokens_per_exec_call: state.maxOutputTokensPerExecCall, - }); - process.exit(0); - } catch (error) { - await protocol.send({ - type: 'result', - content_items: readContentItems(context), - stored_values: state.storedValues, - error_text: formatErrorText(error), - max_output_tokens_per_exec_call: state.maxOutputTokensPerExecCall, - }); - process.exit(1); - } + await protocol.closed; } void main().catch(async (error) => { try { - process.stderr.write(`${formatErrorText(error)}\n`); + process.stderr.write(formatErrorText(error) + '\n'); } finally { process.exitCode = 1; } diff --git a/codex-rs/core/src/tools/handlers/code_mode.rs b/codex-rs/core/src/tools/handlers/code_mode.rs index 4763a69b4..fe4a23965 100644 --- a/codex-rs/core/src/tools/handlers/code_mode.rs +++ b/codex-rs/core/src/tools/handlers/code_mode.rs @@ -1,16 +1,35 @@ use async_trait::async_trait; +use serde::Deserialize; -use crate::features::Feature; use crate::function_tool::FunctionCallError; use crate::tools::code_mode; +use crate::tools::code_mode::DEFAULT_WAIT_YIELD_TIME_MS; use crate::tools::code_mode::PUBLIC_TOOL_NAME; +use crate::tools::code_mode::WAIT_TOOL_NAME; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; pub struct CodeModeHandler; +pub struct CodeModeWaitHandler; + +#[derive(Debug, Deserialize)] +struct ExecWaitArgs { + session_id: i32, + #[serde(default = "default_wait_yield_time_ms")] + yield_time_ms: u64, + #[serde(default)] + 
max_tokens: Option, + #[serde(default)] + terminate: bool, +} + +fn default_wait_yield_time_ms() -> u64 { + DEFAULT_WAIT_YIELD_TIME_MS +} #[async_trait] impl ToolHandler for CodeModeHandler { @@ -29,25 +48,57 @@ impl ToolHandler for CodeModeHandler { session, turn, tracker, + tool_name, payload, .. } = invocation; - if !session.features().enabled(Feature::CodeMode) { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} is disabled by feature flag" - ))); - } - - let code = match payload { - ToolPayload::Custom { input } => input, - _ => { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" - ))); + match payload { + ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => { + code_mode::execute(session, turn, tracker, input).await } - }; - - code_mode::execute(session, turn, tracker, code).await + _ => Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" + ))), + } + } +} + +#[async_trait] +impl ToolHandler for CodeModeWaitHandler { + type Output = FunctionToolOutput; + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + tool_name, + payload, + .. 
+ } = invocation; + + match payload { + ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => { + let args: ExecWaitArgs = parse_arguments(&arguments)?; + code_mode::wait( + session, + turn, + tracker, + args.session_id, + args.yield_time_ms, + args.max_tokens, + args.terminate, + ) + .await + } + _ => Err(FunctionCallError::RespondToModel(format!( + "{WAIT_TOOL_NAME} expects JSON arguments" + ))), + } } } diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index 217780046..068031b5a 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -35,6 +35,7 @@ use crate::sandboxing::normalize_additional_permissions; pub use apply_patch::ApplyPatchHandler; pub use artifacts::ArtifactsHandler; pub use code_mode::CodeModeHandler; +pub use code_mode::CodeModeWaitHandler; use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::AskForApproval; pub use dynamic::DynamicToolHandler; diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index afa4c1861..7d95afaeb 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -8,7 +8,9 @@ use crate::features::Features; use crate::mcp_connection_manager::ToolInfo; use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig; use crate::original_image_detail::can_request_original_image_detail; +use crate::tools::code_mode::DEFAULT_WAIT_YIELD_TIME_MS; use crate::tools::code_mode::PUBLIC_TOOL_NAME; +use crate::tools::code_mode::WAIT_TOOL_NAME; use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; use crate::tools::discoverable::DiscoverablePluginInfo; use crate::tools::discoverable::DiscoverableTool; @@ -589,6 +591,55 @@ fn create_write_stdin_tool() -> ToolSpec { }) } +fn create_exec_wait_tool() -> ToolSpec { + let properties = BTreeMap::from([ + ( + "session_id".to_string(), + JsonSchema::Number { + description: 
Some("Identifier of the running exec session.".to_string()), + }, + ), + ( + "yield_time_ms".to_string(), + JsonSchema::Number { + description: Some( + "How long to wait (in milliseconds) for more output before yielding again." + .to_string(), + ), + }, + ), + ( + "max_tokens".to_string(), + JsonSchema::Number { + description: Some( + "Maximum number of output tokens to return for this wait call.".to_string(), + ), + }, + ), + ( + "terminate".to_string(), + JsonSchema::Boolean { + description: Some("Whether to terminate the running exec session.".to_string()), + }, + ), + ]); + + ToolSpec::Function(ResponsesApiTool { + name: WAIT_TOOL_NAME.to_string(), + description: format!( + "Waits on a yielded `{PUBLIC_TOOL_NAME}` session and returns new output or completion." + ), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["session_id".to_string()]), + additional_properties: Some(false.into()), + }, + output_schema: None, + defer_loading: None, + }) +} + fn create_shell_tool(request_permission_enabled: bool) -> ToolSpec { let mut properties = BTreeMap::from([ ( @@ -1832,7 +1883,7 @@ source: /[\s\S]+/ enabled_tool_names.join(", ") }; let description = format!( - "Runs JavaScript in a Node-backed `node:vm` context. This is a freeform tool: send raw JavaScript source text (no JSON/quotes/markdown fences). Direct tool calls remain available while `{PUBLIC_TOOL_NAME}` is enabled. Inside JavaScript, import nested tools from `tools.js`, for example `import {{ exec_command }} from \"tools.js\"` or `import {{ ALL_TOOLS }} from \"tools.js\"` to inspect the available `{{ module, name, description }}` entries. Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import {{ append_notebook_logs_chart }} from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values. 
Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, store, load }}` from `\"@openai/code_mode\"` (or `\"openai/code_mode\"`); `output_text(value)` surfaces text back to the model and stringifies non-string objects when possible, `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs, `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, `load(key)` returns a cloned stored value or `undefined`, and `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate the final Rust-side result of the current `{PUBLIC_TOOL_NAME}` execution. The default is `10000`. This guards the overall `{PUBLIC_TOOL_NAME}` output, not individual nested tool invocations. The returned content starts with a separate `Script completed` or `Script failed` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker. Function tools require JSON object arguments. Freeform tools require raw strings. `add_content(value)` remains available for compatibility with a content item, content-item array, or string. Structured nested-tool results should be converted to text first, for example with `JSON.stringify(...)`. Only content passed to `output_text(...)`, `output_image(...)`, or `add_content(value)` is surfaced back to the model. Enabled nested tools: {enabled_list}." + "Runs JavaScript in a Node-backed `node:vm` context. This is a freeform tool: send raw JavaScript source text (no JSON/quotes/markdown fences). Direct tool calls remain available while `{PUBLIC_TOOL_NAME}` is enabled. Inside JavaScript, import nested tools from `tools.js`, for example `import {{ exec_command }} from \"tools.js\"` or `import {{ ALL_TOOLS }} from \"tools.js\"` to inspect the available `{{ module, name, description }}` entries. 
Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import {{ append_notebook_logs_chart }} from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values. Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, set_yield_time, store, load }}` from `\"@openai/code_mode\"` (or `\"openai/code_mode\"`); `output_text(value)` surfaces text back to the model and stringifies non-string objects when possible, `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs, `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, `load(key)` returns a cloned stored value or `undefined`, `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate direct `{PUBLIC_TOOL_NAME}` returns, and `{WAIT_TOOL_NAME}` uses its own `max_tokens` argument with a default of `10000`. `set_yield_time(value)` asks `{PUBLIC_TOOL_NAME}` to return early if the script is still running after that many milliseconds so `{WAIT_TOOL_NAME}` can resume it later. The default wait timeout for `{WAIT_TOOL_NAME}` is {DEFAULT_WAIT_YIELD_TIME_MS}. The returned content starts with a separate `Script completed`, `Script failed`, or `Script running with session ID …` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker. Function tools require JSON object arguments. Freeform tools require raw strings. `add_content(value)` remains available for compatibility with a content item, content-item array, or string. Structured nested-tool results should be converted to text first, for example with `JSON.stringify(...)`. Only content passed to `output_text(...)`, `output_image(...)`, or `add_content(value)` is surfaced back to the model. Enabled nested tools: {enabled_list}." 
); ToolSpec::Freeform(FreeformTool { @@ -1847,7 +1898,9 @@ source: /[\s\S]+/ } fn is_code_mode_nested_tool(spec: &ToolSpec) -> bool { - spec.name() != PUBLIC_TOOL_NAME && matches!(spec, ToolSpec::Function(_) | ToolSpec::Freeform(_)) + spec.name() != PUBLIC_TOOL_NAME + && spec.name() != WAIT_TOOL_NAME + && matches!(spec, ToolSpec::Function(_) | ToolSpec::Freeform(_)) } fn create_list_mcp_resources_tool() -> ToolSpec { @@ -2243,6 +2296,7 @@ pub(crate) fn build_specs_with_discoverable_tools( use crate::tools::handlers::ApplyPatchHandler; use crate::tools::handlers::ArtifactsHandler; use crate::tools::handlers::CodeModeHandler; + use crate::tools::handlers::CodeModeWaitHandler; use crate::tools::handlers::DynamicToolHandler; use crate::tools::handlers::GrepFilesHandler; use crate::tools::handlers::JsReplHandler; @@ -2281,6 +2335,7 @@ pub(crate) fn build_specs_with_discoverable_tools( }); let tool_suggest_handler = Arc::new(ToolSuggestHandler); let code_mode_handler = Arc::new(CodeModeHandler); + let code_mode_wait_handler = Arc::new(CodeModeWaitHandler); let js_repl_handler = Arc::new(JsReplHandler); let js_repl_reset_handler = Arc::new(JsReplResetHandler); let artifacts_handler = Arc::new(ArtifactsHandler); @@ -2311,6 +2366,13 @@ pub(crate) fn build_specs_with_discoverable_tools( config.code_mode_enabled, ); builder.register_handler(PUBLIC_TOOL_NAME, code_mode_handler); + push_tool_spec( + &mut builder, + create_exec_wait_tool(), + false, + config.code_mode_enabled, + ); + builder.register_handler(WAIT_TOOL_NAME, code_mode_wait_handler); } match &config.shell_type { diff --git a/codex-rs/core/tests/suite/code_mode.rs b/codex-rs/core/tests/suite/code_mode.rs index 07cadc343..23fcd9c08 100644 --- a/codex-rs/core/tests/suite/code_mode.rs +++ b/codex-rs/core/tests/suite/code_mode.rs @@ -21,6 +21,7 @@ use pretty_assertions::assert_eq; use serde_json::Value; use std::collections::HashMap; use std::fs; +use std::path::Path; use std::time::Duration; use wiremock::MockServer; 
@@ -32,6 +33,16 @@ fn custom_tool_output_items(req: &ResponsesRequest, call_id: &str) -> Vec<Value>
         .clone()
 }
 
+fn function_tool_output_items(req: &ResponsesRequest, call_id: &str) -> Vec<Value> {
+    match req.function_call_output(call_id).get("output") {
+        Some(Value::Array(items)) => items.clone(),
+        Some(Value::String(text)) => {
+            vec![serde_json::json!({ "type": "input_text", "text": text })]
+        }
+        _ => panic!("function tool output should be serialized as text or content items"),
+    }
+}
+
 fn text_item(items: &[Value], index: usize) -> &str {
     items[index]
         .get("text")
@@ -39,6 +50,23 @@ fn text_item(items: &[Value], index: usize) -> &str {
         .expect("content item should be input_text")
 }
 
+fn extract_running_session_id(text: &str) -> i32 {
+    text.strip_prefix("Script running with session ID ")
+        .and_then(|rest| rest.split('\n').next())
+        .expect("running header should contain a session ID")
+        .parse()
+        .expect("session ID should parse as i32")
+}
+
+fn wait_for_file_source(path: &Path) -> Result<String> {
+    let quoted_path = shlex::try_join([path.to_string_lossy().as_ref()])?;
+    let command = format!("if [ -f {quoted_path} ]; then printf ready; fi");
+    Ok(format!(
+        r#"while ((await exec_command({{ cmd: {command:?} }})).output !== "ready") {{
+}}"#
+    ))
+}
+
 fn custom_tool_output_body_and_success(
     req: &ResponsesRequest,
     call_id: &str,
@@ -289,6 +317,799 @@ Error:\ boom\n
     Ok(())
 }
 
+#[cfg_attr(windows, ignore = "no exec_command on Windows")]
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn code_mode_can_yield_and_resume_with_exec_wait() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let mut builder = test_codex().with_config(move |config| {
+        let _ = config.features.enable(Feature::CodeMode);
+    });
+    let test = builder.build(&server).await?;
+    let phase_2_gate = test.workspace_path("code-mode-phase-2.ready");
+    let phase_3_gate = test.workspace_path("code-mode-phase-3.ready");
+    let phase_2_wait = wait_for_file_source(&phase_2_gate)?;
+    let phase_3_wait = wait_for_file_source(&phase_3_gate)?;
+
+    let code = format!(
+        r#"
+import {{ output_text, set_yield_time }} from "@openai/code_mode";
+import {{ exec_command }} from "tools.js";
+
+output_text("phase 1");
+set_yield_time(10);
+{phase_2_wait}
+output_text("phase 2");
+{phase_3_wait}
+output_text("phase 3");
+"#
+    );
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-1"),
+            ev_custom_tool_call("call-1", "exec", &code),
+            ev_completed("resp-1"),
+        ]),
+    )
+    .await;
+    let first_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-1", "waiting"),
+            ev_completed("resp-2"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("start the long exec").await?;
+
+    let first_request = first_completion.single_request();
+    let first_items = custom_tool_output_items(&first_request, "call-1");
+    assert_eq!(first_items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script running with session ID \d+\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&first_items, 0),
+    );
+    assert_eq!(text_item(&first_items, 1), "phase 1");
+    let session_id = extract_running_session_id(text_item(&first_items, 0));
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-3"),
+            responses::ev_function_call(
+                "call-2",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_id,
+                    "yield_time_ms": 1_000,
+                }))?,
+            ),
+            ev_completed("resp-3"),
+        ]),
+    )
+    .await;
+    let second_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-2", "still waiting"),
+            ev_completed("resp-4"),
+        ]),
+    )
+    .await;
+
+    fs::write(&phase_2_gate, "ready")?;
+    test.submit_turn("wait again").await?;
+
+    let second_request = second_completion.single_request();
+    let second_items = function_tool_output_items(&second_request, "call-2");
+    assert_eq!(second_items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script running with session ID \d+\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&second_items, 0),
+    );
+    assert_eq!(
+        extract_running_session_id(text_item(&second_items, 0)),
+        session_id
+    );
+    assert_eq!(text_item(&second_items, 1), "phase 2");
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-5"),
+            responses::ev_function_call(
+                "call-3",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_id,
+                    "yield_time_ms": 1_000,
+                }))?,
+            ),
+            ev_completed("resp-5"),
+        ]),
+    )
+    .await;
+    let third_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-3", "done"),
+            ev_completed("resp-6"),
+        ]),
+    )
+    .await;
+
+    fs::write(&phase_3_gate, "ready")?;
+    test.submit_turn("wait for completion").await?;
+
+    let third_request = third_completion.single_request();
+    let third_items = function_tool_output_items(&third_request, "call-3");
+    assert_eq!(third_items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&third_items, 0),
+    );
+    assert_eq!(text_item(&third_items, 1), "phase 3");
+
+    Ok(())
+}
+
+#[cfg_attr(windows, ignore = "no exec_command on Windows")]
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn code_mode_can_run_multiple_yielded_sessions() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let mut builder = test_codex().with_config(move |config| {
+        let _ = config.features.enable(Feature::CodeMode);
+    });
+    let test = builder.build(&server).await?;
+    let session_a_gate = test.workspace_path("code-mode-session-a.ready");
+    let session_b_gate = test.workspace_path("code-mode-session-b.ready");
+    let session_a_wait = wait_for_file_source(&session_a_gate)?;
+    let session_b_wait = wait_for_file_source(&session_b_gate)?;
+
+    let session_a_code = format!(
+        r#"
+import {{ output_text, set_yield_time }} from "@openai/code_mode";
+import {{ exec_command }} from "tools.js";
+
+output_text("session a start");
+set_yield_time(10);
+{session_a_wait}
+output_text("session a done");
+"#
+    );
+    let session_b_code = format!(
+        r#"
+import {{ output_text, set_yield_time }} from "@openai/code_mode";
+import {{ exec_command }} from "tools.js";
+
+output_text("session b start");
+set_yield_time(10);
+{session_b_wait}
+output_text("session b done");
+"#
+    );
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-1"),
+            ev_custom_tool_call("call-1", "exec", &session_a_code),
+            ev_completed("resp-1"),
+        ]),
+    )
+    .await;
+    let first_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-1", "session a waiting"),
+            ev_completed("resp-2"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("start session a").await?;
+
+    let first_request = first_completion.single_request();
+    let first_items = custom_tool_output_items(&first_request, "call-1");
+    assert_eq!(first_items.len(), 2);
+    let session_a_id = extract_running_session_id(text_item(&first_items, 0));
+    assert_eq!(text_item(&first_items, 1), "session a start");
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-3"),
+            ev_custom_tool_call("call-2", "exec", &session_b_code),
+            ev_completed("resp-3"),
+        ]),
+    )
+    .await;
+    let second_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-2", "session b waiting"),
+            ev_completed("resp-4"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("start session b").await?;
+
+    let second_request = second_completion.single_request();
+    let second_items = custom_tool_output_items(&second_request, "call-2");
+    assert_eq!(second_items.len(), 2);
+    let session_b_id = extract_running_session_id(text_item(&second_items, 0));
+    assert_eq!(text_item(&second_items, 1), "session b start");
+    assert_ne!(session_a_id, session_b_id);
+
+    fs::write(&session_a_gate, "ready")?;
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-5"),
+            responses::ev_function_call(
+                "call-3",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_a_id,
+                    "yield_time_ms": 1_000,
+                }))?,
+            ),
+            ev_completed("resp-5"),
+        ]),
+    )
+    .await;
+    let third_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-3", "session a done"),
+            ev_completed("resp-6"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("wait session a").await?;
+
+    let third_request = third_completion.single_request();
+    let third_items = function_tool_output_items(&third_request, "call-3");
+    assert_eq!(third_items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&third_items, 0),
+    );
+    assert_eq!(text_item(&third_items, 1), "session a done");
+
+    fs::write(&session_b_gate, "ready")?;
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-7"),
+            responses::ev_function_call(
+                "call-4",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_b_id,
+                    "yield_time_ms": 1_000,
+                }))?,
+            ),
+            ev_completed("resp-7"),
+        ]),
+    )
+    .await;
+    let fourth_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-4", "session b done"),
+            ev_completed("resp-8"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("wait session b").await?;
+
+    let fourth_request = fourth_completion.single_request();
+    let fourth_items = function_tool_output_items(&fourth_request, "call-4");
+    assert_eq!(fourth_items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&fourth_items, 0),
+    );
+    assert_eq!(text_item(&fourth_items, 1), "session b done");
+
+    Ok(())
+}
+
+#[cfg_attr(windows, ignore = "no exec_command on Windows")]
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn code_mode_exec_wait_can_terminate_and_continue() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let mut builder = test_codex().with_config(move |config| {
+        let _ = config.features.enable(Feature::CodeMode);
+    });
+    let test = builder.build(&server).await?;
+    let termination_gate = test.workspace_path("code-mode-terminate.ready");
+    let termination_wait = wait_for_file_source(&termination_gate)?;
+
+    let code = format!(
+        r#"
+import {{ output_text, set_yield_time }} from "@openai/code_mode";
+import {{ exec_command }} from "tools.js";
+
+output_text("phase 1");
+set_yield_time(10);
+{termination_wait}
+output_text("phase 2");
+"#
+    );
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-1"),
+            ev_custom_tool_call("call-1", "exec", &code),
+            ev_completed("resp-1"),
+        ]),
+    )
+    .await;
+    let first_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-1", "waiting"),
+            ev_completed("resp-2"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("start the long exec").await?;
+
+    let first_request = first_completion.single_request();
+    let first_items = custom_tool_output_items(&first_request, "call-1");
+    assert_eq!(first_items.len(), 2);
+    let session_id = extract_running_session_id(text_item(&first_items, 0));
+    assert_eq!(text_item(&first_items, 1), "phase 1");
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-3"),
+            responses::ev_function_call(
+                "call-2",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_id,
+                    "terminate": true,
+                }))?,
+            ),
+            ev_completed("resp-3"),
+        ]),
+    )
+    .await;
+    let second_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-2", "terminated"),
+            ev_completed("resp-4"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("terminate it").await?;
+
+    let second_request = second_completion.single_request();
+    let second_items = function_tool_output_items(&second_request, "call-2");
+    assert_eq!(second_items.len(), 1);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script terminated\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&second_items, 0),
+    );
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-5"),
+            ev_custom_tool_call(
+                "call-3",
+                "exec",
+                r#"
+import { output_text } from "@openai/code_mode";
+
+output_text("after terminate");
+"#,
+            ),
+            ev_completed("resp-5"),
+        ]),
+    )
+    .await;
+    let third_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-3", "done"),
+            ev_completed("resp-6"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("run another exec").await?;
+
+    let third_request = third_completion.single_request();
+    let third_items = custom_tool_output_items(&third_request, "call-3");
+    assert_eq!(third_items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&third_items, 0),
+    );
+    assert_eq!(text_item(&third_items, 1), "after terminate");
+
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn code_mode_exec_wait_returns_error_for_unknown_session() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let mut builder = test_codex().with_config(move |config| {
+        let _ = config.features.enable(Feature::CodeMode);
+    });
+    let test = builder.build(&server).await?;
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-1"),
+            responses::ev_function_call(
+                "call-1",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": 999_999,
+                    "yield_time_ms": 1_000,
+                }))?,
+            ),
+            ev_completed("resp-1"),
+        ]),
+    )
+    .await;
+    let completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-1", "done"),
+            ev_completed("resp-2"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("wait on an unknown exec session").await?;
+
+    let request = completion.single_request();
+    let (_, success) = request
+        .function_call_output_content_and_success("call-1")
+        .expect("function tool output should be present");
+    assert_ne!(success, Some(true));
+
+    let items = function_tool_output_items(&request, "call-1");
+    assert_eq!(items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script failed\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&items, 0),
+    );
+    assert_eq!(
+        text_item(&items, 1),
+        "Script error:\nexec session 999999 not found"
+    );
+
+    Ok(())
+}
+
+#[cfg_attr(windows, ignore = "no exec_command on Windows")]
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn code_mode_exec_wait_terminate_returns_completed_session_if_it_finished_in_background()
+-> Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let mut builder = test_codex().with_config(move |config| {
+        let _ = config.features.enable(Feature::CodeMode);
+    });
+    let test = builder.build(&server).await?;
+    let session_a_gate = test.workspace_path("code-mode-session-a-finished.ready");
+    let session_b_gate = test.workspace_path("code-mode-session-b-blocked.ready");
+    let session_a_wait = wait_for_file_source(&session_a_gate)?;
+    let session_b_wait = wait_for_file_source(&session_b_gate)?;
+
+    let session_a_code = format!(
+        r#"
+import {{ output_text, set_yield_time }} from "@openai/code_mode";
+import {{ exec_command }} from "tools.js";
+
+output_text("session a start");
+set_yield_time(10);
+{session_a_wait}
+output_text("session a done");
+"#
+    );
+    let session_b_code = format!(
+        r#"
+import {{ output_text, set_yield_time }} from "@openai/code_mode";
+import {{ exec_command }} from "tools.js";
+
+output_text("session b start");
+set_yield_time(10);
+{session_b_wait}
+output_text("session b done");
+"#
+    );
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-1"),
+            ev_custom_tool_call("call-1", "exec", &session_a_code),
+            ev_completed("resp-1"),
+        ]),
+    )
+    .await;
+    let first_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-1", "session a waiting"),
+            ev_completed("resp-2"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("start session a").await?;
+
+    let first_request = first_completion.single_request();
+    let first_items = custom_tool_output_items(&first_request, "call-1");
+    assert_eq!(first_items.len(), 2);
+    let session_a_id = extract_running_session_id(text_item(&first_items, 0));
+    assert_eq!(text_item(&first_items, 1), "session a start");
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-3"),
+            ev_custom_tool_call("call-2", "exec", &session_b_code),
+            ev_completed("resp-3"),
+        ]),
+    )
+    .await;
+    let second_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-2", "session b waiting"),
+            ev_completed("resp-4"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("start session b").await?;
+
+    let second_request = second_completion.single_request();
+    let second_items = custom_tool_output_items(&second_request, "call-2");
+    assert_eq!(second_items.len(), 2);
+    let session_b_id = extract_running_session_id(text_item(&second_items, 0));
+    assert_eq!(text_item(&second_items, 1), "session b start");
+
+    fs::write(&session_a_gate, "ready")?;
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-5"),
+            responses::ev_function_call(
+                "call-3",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_b_id,
+                    "yield_time_ms": 1_000,
+                }))?,
+            ),
+            ev_completed("resp-5"),
+        ]),
+    )
+    .await;
+    let third_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-3", "session b still waiting"),
+            ev_completed("resp-6"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("wait session b").await?;
+
+    let third_request = third_completion.single_request();
+    let third_items = function_tool_output_items(&third_request, "call-3");
+    assert_eq!(third_items.len(), 1);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script running with session ID \d+\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&third_items, 0),
+    );
+    assert_eq!(
+        extract_running_session_id(text_item(&third_items, 0)),
+        session_b_id
+    );
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-7"),
+            responses::ev_function_call(
+                "call-4",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_a_id,
+                    "terminate": true,
+                }))?,
+            ),
+            ev_completed("resp-7"),
+        ]),
+    )
+    .await;
+    let fourth_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-4", "session a already done"),
+            ev_completed("resp-8"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("terminate session a").await?;
+
+    let fourth_request = fourth_completion.single_request();
+    let fourth_items = function_tool_output_items(&fourth_request, "call-4");
+    assert_eq!(fourth_items.len(), 1);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script terminated\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&fourth_items, 0),
+    );
+
+    Ok(())
+}
+
+#[cfg_attr(windows, ignore = "no exec_command on Windows")]
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn code_mode_exec_wait_uses_its_own_max_tokens_budget() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let mut builder = test_codex().with_config(move |config| {
+        let _ = config.features.enable(Feature::CodeMode);
+    });
+    let test = builder.build(&server).await?;
+    let completion_gate = test.workspace_path("code-mode-max-tokens.ready");
+    let completion_wait = wait_for_file_source(&completion_gate)?;
+
+    let code = format!(
+        r#"
+import {{ output_text, set_max_output_tokens_per_exec_call, set_yield_time }} from "@openai/code_mode";
+import {{ exec_command }} from "tools.js";
+
+output_text("phase 1");
+set_max_output_tokens_per_exec_call(100);
+set_yield_time(10);
+{completion_wait}
+output_text("token one token two token three token four token five token six token seven");
+"#
+    );
+
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-1"),
+            ev_custom_tool_call("call-1", "exec", &code),
+            ev_completed("resp-1"),
+        ]),
+    )
+    .await;
+    let first_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-1", "waiting"),
+            ev_completed("resp-2"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("start the long exec").await?;
+
+    let first_request = first_completion.single_request();
+    let first_items = custom_tool_output_items(&first_request, "call-1");
+    assert_eq!(first_items.len(), 2);
+    assert_eq!(text_item(&first_items, 1), "phase 1");
+    let session_id = extract_running_session_id(text_item(&first_items, 0));
+
+    fs::write(&completion_gate, "ready")?;
+    responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_response_created("resp-3"),
+            responses::ev_function_call(
+                "call-2",
+                "exec_wait",
+                &serde_json::to_string(&serde_json::json!({
+                    "session_id": session_id,
+                    "yield_time_ms": 1_000,
+                    "max_tokens": 6,
+                }))?,
+            ),
+            ev_completed("resp-3"),
+        ]),
+    )
+    .await;
+    let second_completion = responses::mount_sse_once(
+        &server,
+        sse(vec![
+            ev_assistant_message("msg-2", "done"),
+            ev_completed("resp-4"),
+        ]),
+    )
+    .await;
+
+    test.submit_turn("wait for completion").await?;
+
+    let second_request = second_completion.single_request();
+    let second_items = function_tool_output_items(&second_request, "call-2");
+    assert_eq!(second_items.len(), 2);
+    assert_regex_match(
+        concat!(
+            r"(?s)\A",
+            r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z"
+        ),
+        text_item(&second_items, 0),
+    );
+    let expected_pattern = r#"(?sx)
+\A
+Total\ output\ lines:\ 1\n
+\n
+.*…\d+\ tokens\ truncated….*
+\z
+"#;
+    assert_regex_match(expected_pattern, text_item(&second_items, 1));
+
+    Ok(())
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn code_mode_can_output_serialized_text_via_openai_code_mode_module() -> Result<()> {
     skip_if_no_network!(Ok(()));