From e4eedd6170580d5b06fb539635a78f261a6b7369 Mon Sep 17 00:00:00 2001 From: Channing Conger Date: Fri, 20 Mar 2026 23:36:58 -0700 Subject: [PATCH] Code mode on v8 (#15276) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moves Code Mode to a new crate with no dependencies on codex. This crate encodes the code mode semantics that we want for lifetime, mounting, and tool calling. The model-facing surface is mostly unchanged. `exec` still runs raw JavaScript, `wait` still resumes or terminates a `cell_id`, nested tools are still available through `tools.*`, and helpers like `text`, `image`, `store`, `load`, `notify`, `yield_control`, and `exit` still exist. The major change is underneath that surface: - Old code mode was an external Node runtime. - New code mode is an in-process V8 runtime embedded directly in Rust. - Old code mode managed cells inside a long-lived Node runner process. - New code mode manages cells in Rust, with one V8 runtime thread per active `exec`. - Old code mode used JSON protocol messages over child stdin/stdout plus Node worker-thread messages. - New code mode uses Rust channels and direct V8 callbacks/events. This PR also fixes the two migration regressions that fell out of that substrate change: - `wait { terminate: true }` now waits for the V8 runtime to actually stop before reporting termination. - synchronous top-level `exit()` now succeeds again instead of surfacing as a script error. --- - `core/src/tools/code_mode/*` is now mostly an adapter layer for the public `exec` / `wait` tools. - `code-mode/src/service.rs` owns cell sessions and async control flow in Rust. - `code-mode/src/runtime/*.rs` owns the embedded V8 isolate and JavaScript execution. - each `exec` spawns a dedicated runtime thread plus a Rust session-control task. - helper globals are installed directly into the V8 context instead of being injected through a source prelude. 
- helper modules like `tools.js` and `@openai/code_mode` are synthesized through V8 module resolution callbacks in Rust. --- Also added a benchmark for showing the speed of init and use of a code mode env: ``` $ cargo bench -p codex-code-mode --bench exec_overhead -- --samples 30 --warm-iterations 25 --tool-counts 0,32,128 Finished [`bench` profile [optimized]](https://doc.rust-lang.org/cargo/reference/profiles.html#default-profiles) target(s) in 0.18s Running benches/exec_overhead.rs (target/release/deps/exec_overhead-008c440d800545ae) exec_overhead: samples=30, warm_iterations=25, tool_counts=[0, 32, 128] scenario tools samples warmups iters mean/exec p95/exec rssΔ p50 rssΔ max cold_exec 0 30 0 1 1.13ms 1.20ms 8.05MiB 8.06MiB warm_exec 0 30 1 25 473.43us 512.49us 912.00KiB 1.33MiB cold_exec 32 30 0 1 1.03ms 1.15ms 8.08MiB 8.11MiB warm_exec 32 30 1 25 509.73us 545.76us 960.00KiB 1.30MiB cold_exec 128 30 0 1 1.14ms 1.19ms 8.30MiB 8.34MiB warm_exec 128 30 1 25 575.08us 591.03us 736.00KiB 864.00KiB memory uses a fresh-process max RSS delta for each scenario ``` --------- Co-authored-by: Codex --- codex-rs/Cargo.lock | 15 + codex-rs/Cargo.toml | 4 +- codex-rs/code-mode/BUILD.bazel | 6 + codex-rs/code-mode/Cargo.toml | 25 + codex-rs/code-mode/src/description.rs | 555 +++++++++++ codex-rs/code-mode/src/lib.rs | 30 + codex-rs/code-mode/src/response.rs | 24 + codex-rs/code-mode/src/runtime/callbacks.rs | 209 ++++ codex-rs/code-mode/src/runtime/globals.rs | 138 +++ codex-rs/code-mode/src/runtime/mod.rs | 349 +++++++ .../code-mode/src/runtime/module_loader.rs | 235 +++++ codex-rs/code-mode/src/runtime/value.rs | 163 +++ codex-rs/code-mode/src/service.rs | 673 +++++++++++++ codex-rs/core/Cargo.toml | 1 + codex-rs/core/src/tools/code_mode/bridge.js | 51 - .../core/src/tools/code_mode/description.md | 19 - .../src/tools/code_mode/execute_handler.rs | 189 +--- codex-rs/core/src/tools/code_mode/mod.rs | 360 +++---- codex-rs/core/src/tools/code_mode/process.rs | 173 ---- 
codex-rs/core/src/tools/code_mode/protocol.rs | 169 ---- .../src/tools/code_mode/response_adapter.rs | 44 + codex-rs/core/src/tools/code_mode/runner.cjs | 938 ------------------ codex-rs/core/src/tools/code_mode/service.rs | 108 -- .../src/tools/code_mode/wait_description.md | 8 - .../core/src/tools/code_mode/wait_handler.rs | 68 +- codex-rs/core/src/tools/code_mode/worker.rs | 116 --- .../core/src/tools/code_mode_description.rs | 298 +----- codex-rs/core/src/tools/router.rs | 3 +- codex-rs/core/src/tools/spec.rs | 12 +- codex-rs/core/tests/suite/code_mode.rs | 8 +- codex-rs/core/tests/suite/unified_exec.rs | 4 +- 31 files changed, 2730 insertions(+), 2265 deletions(-) create mode 100644 codex-rs/code-mode/BUILD.bazel create mode 100644 codex-rs/code-mode/Cargo.toml create mode 100644 codex-rs/code-mode/src/description.rs create mode 100644 codex-rs/code-mode/src/lib.rs create mode 100644 codex-rs/code-mode/src/response.rs create mode 100644 codex-rs/code-mode/src/runtime/callbacks.rs create mode 100644 codex-rs/code-mode/src/runtime/globals.rs create mode 100644 codex-rs/code-mode/src/runtime/mod.rs create mode 100644 codex-rs/code-mode/src/runtime/module_loader.rs create mode 100644 codex-rs/code-mode/src/runtime/value.rs create mode 100644 codex-rs/code-mode/src/service.rs delete mode 100644 codex-rs/core/src/tools/code_mode/bridge.js delete mode 100644 codex-rs/core/src/tools/code_mode/description.md delete mode 100644 codex-rs/core/src/tools/code_mode/process.rs delete mode 100644 codex-rs/core/src/tools/code_mode/protocol.rs create mode 100644 codex-rs/core/src/tools/code_mode/response_adapter.rs delete mode 100644 codex-rs/core/src/tools/code_mode/runner.cjs delete mode 100644 codex-rs/core/src/tools/code_mode/service.rs delete mode 100644 codex-rs/core/src/tools/code_mode/wait_description.md delete mode 100644 codex-rs/core/src/tools/code_mode/worker.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index b3f8d8802..d0917ee38 100644 --- 
a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1800,6 +1800,20 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "codex-code-mode" +version = "0.0.0" +dependencies = [ + "async-trait", + "pretty_assertions", + "serde", + "serde_json", + "tokio", + "tokio-util", + "tracing", + "v8", +] + [[package]] name = "codex-config" version = "0.0.0" @@ -1857,6 +1871,7 @@ dependencies = [ "codex-arg0", "codex-artifacts", "codex-async-utils", + "codex-code-mode", "codex-config", "codex-connectors", "codex-exec-server", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 6d768d696..524b61e3b 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -13,6 +13,7 @@ members = [ "feedback", "features", "codex-backend-openapi-models", + "code-mode", "cloud-requirements", "cloud-tasks", "cloud-tasks-client", @@ -91,6 +92,7 @@ app_test_support = { path = "app-server/tests/common" } codex-ansi-escape = { path = "ansi-escape" } codex-api = { path = "codex-api" } codex-artifacts = { path = "artifacts" } +codex-code-mode = { path = "code-mode" } codex-package-manager = { path = "package-manager" } codex-app-server = { path = "app-server" } codex-app-server-client = { path = "app-server-client" } @@ -374,7 +376,7 @@ ignored = [ "openssl-sys", "codex-utils-readiness", "codex-secrets", - "codex-v8-poc" + "codex-v8-poc", ] [profile.release] diff --git a/codex-rs/code-mode/BUILD.bazel b/codex-rs/code-mode/BUILD.bazel new file mode 100644 index 000000000..bf39d9d5a --- /dev/null +++ b/codex-rs/code-mode/BUILD.bazel @@ -0,0 +1,6 @@ +load("//:defs.bzl", "codex_rust_crate") + +codex_rust_crate( + name = "code-mode", + crate_name = "codex_code_mode", +) diff --git a/codex-rs/code-mode/Cargo.toml b/codex-rs/code-mode/Cargo.toml new file mode 100644 index 000000000..e821ca0e4 --- /dev/null +++ b/codex-rs/code-mode/Cargo.toml @@ -0,0 +1,25 @@ +[package] +edition.workspace = true +license.workspace = true +name = "codex-code-mode" +version.workspace = true + +[lib] +doctest 
= false +name = "codex_code_mode" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +async-trait = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] } +tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } +v8 = { workspace = true } + +[dev-dependencies] +pretty_assertions = { workspace = true } diff --git a/codex-rs/code-mode/src/description.rs b/codex-rs/code-mode/src/description.rs new file mode 100644 index 000000000..c875e2a1b --- /dev/null +++ b/codex-rs/code-mode/src/description.rs @@ -0,0 +1,555 @@ +use serde::Deserialize; +use serde::Serialize; +use serde_json::Value as JsonValue; + +use crate::PUBLIC_TOOL_NAME; + +const MAX_JS_SAFE_INTEGER: u64 = (1_u64 << 53) - 1; +const CODE_MODE_ONLY_PREFACE: &str = + "Use `exec/wait` tool to run all other tools, do not attempt to use any other tools directly"; +const EXEC_DESCRIPTION_TEMPLATE: &str = r#"## exec +- Runs raw JavaScript in an isolated context (no Node, no file system, or network access, no console). +- Send raw JavaScript source text, not JSON, quoted strings, or markdown code fences. +- You may optionally start the tool input with a first-line pragma like `// @exec: {"yield_time_ms": 10000, "max_output_tokens": 1000}`. +- `yield_time_ms` asks `exec` to yield early after that many milliseconds if the script is still running. +- `max_output_tokens` sets the token budget for direct `exec` results. By default the result is truncated to 10000 tokens. +- All nested tools are available on the global `tools` object, for example `await tools.exec_command(...)`. Tool names are exposed as normalized JavaScript identifiers, for example `await tools.mcp__ologs__get_profile(...)`. +- Tool methods take either string or object as parameter. +- They return either a structured value or a string based on the description above. 
+ +- Global helpers: +- `exit()`: Immediately ends the current script successfully (like an early return from the top level). +- `text(value: string | number | boolean | undefined | null)`: Appends a text item. Non-string values are stringified with `JSON.stringify(...)` when possible. +- `image(imageUrlOrItem: string | { image_url: string; detail?: "auto" | "low" | "high" | "original" | null })`: Appends an image item. `image_url` can be an HTTPS URL or a base64-encoded `data:` URL. +- `store(key: string, value: any)`: stores a serializable value under a string key for later `exec` calls in the same session. +- `load(key: string)`: returns the stored value for a string key, or `undefined` if it is missing. +- `notify(value: string | number | boolean | undefined | null)`: immediately injects an extra `custom_tool_call_output` for the current `exec` call. Values are stringified like `text(...)`. +- `ALL_TOOLS`: metadata for the enabled nested tools as `{ name, description }` entries. +- `yield_control()`: yields the accumulated output to the model immediately while the script keeps running."#; +const WAIT_DESCRIPTION_TEMPLATE: &str = r#"- Use `wait` only after `exec` returns `Script running with cell ID ...`. +- `cell_id` identifies the running `exec` cell to resume. +- `yield_time_ms` controls how long to wait for more output before yielding again. If omitted, `wait` uses its default wait timeout. +- `max_tokens` limits how much new output this wait call returns. +- `terminate: true` stops the running cell instead of waiting for more output. +- `wait` returns only the new output since the last yield, or the final completion or termination result for that cell. +- If the cell is still running, `wait` may yield again with the same `cell_id`. 
+- If the cell has already finished, `wait` returns the completed result and closes the cell."#; + +pub const CODE_MODE_PRAGMA_PREFIX: &str = "// @exec:"; + +#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum CodeModeToolKind { + Function, + Freeform, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ToolDefinition { + pub name: String, + pub description: String, + pub kind: CodeModeToolKind, + pub input_schema: Option, + pub output_schema: Option, +} + +#[derive(Debug, Default, Deserialize, PartialEq, Eq)] +#[serde(deny_unknown_fields)] +struct CodeModeExecPragma { + #[serde(default)] + yield_time_ms: Option, + #[serde(default)] + max_output_tokens: Option, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ParsedExecSource { + pub code: String, + pub yield_time_ms: Option, + pub max_output_tokens: Option, +} + +pub fn parse_exec_source(input: &str) -> Result { + if input.trim().is_empty() { + return Err( + "exec expects raw JavaScript source text (non-empty). 
Provide JS only, optionally with first-line `// @exec: {\"yield_time_ms\": 10000, \"max_output_tokens\": 1000}`.".to_string(), + ); + } + + let mut args = ParsedExecSource { + code: input.to_string(), + yield_time_ms: None, + max_output_tokens: None, + }; + + let mut lines = input.splitn(2, '\n'); + let first_line = lines.next().unwrap_or_default(); + let rest = lines.next().unwrap_or_default(); + let trimmed = first_line.trim_start(); + let Some(pragma) = trimmed.strip_prefix(CODE_MODE_PRAGMA_PREFIX) else { + return Ok(args); + }; + + if rest.trim().is_empty() { + return Err( + "exec pragma must be followed by JavaScript source on subsequent lines".to_string(), + ); + } + + let directive = pragma.trim(); + if directive.is_empty() { + return Err( + "exec pragma must be a JSON object with supported fields `yield_time_ms` and `max_output_tokens`" + .to_string(), + ); + } + + let value: serde_json::Value = serde_json::from_str(directive).map_err(|err| { + format!( + "exec pragma must be valid JSON with supported fields `yield_time_ms` and `max_output_tokens`: {err}" + ) + })?; + let object = value.as_object().ok_or_else(|| { + "exec pragma must be a JSON object with supported fields `yield_time_ms` and `max_output_tokens`" + .to_string() + })?; + for key in object.keys() { + match key.as_str() { + "yield_time_ms" | "max_output_tokens" => {} + _ => { + return Err(format!( + "exec pragma only supports `yield_time_ms` and `max_output_tokens`; got `{key}`" + )); + } + } + } + + let pragma: CodeModeExecPragma = serde_json::from_value(value).map_err(|err| { + format!( + "exec pragma fields `yield_time_ms` and `max_output_tokens` must be non-negative safe integers: {err}" + ) + })?; + if pragma + .yield_time_ms + .is_some_and(|yield_time_ms| yield_time_ms > MAX_JS_SAFE_INTEGER) + { + return Err( + "exec pragma field `yield_time_ms` must be a non-negative safe integer".to_string(), + ); + } + if pragma.max_output_tokens.is_some_and(|max_output_tokens| { + 
u64::try_from(max_output_tokens) + .map(|max_output_tokens| max_output_tokens > MAX_JS_SAFE_INTEGER) + .unwrap_or(true) + }) { + return Err( + "exec pragma field `max_output_tokens` must be a non-negative safe integer".to_string(), + ); + } + + args.code = rest.to_string(); + args.yield_time_ms = pragma.yield_time_ms; + args.max_output_tokens = pragma.max_output_tokens; + Ok(args) +} + +pub fn is_code_mode_nested_tool(tool_name: &str) -> bool { + tool_name != crate::PUBLIC_TOOL_NAME && tool_name != crate::WAIT_TOOL_NAME +} + +pub fn build_exec_tool_description( + enabled_tools: &[(String, String)], + code_mode_only: bool, +) -> String { + if !code_mode_only { + return EXEC_DESCRIPTION_TEMPLATE.to_string(); + } + + let mut sections = vec![ + CODE_MODE_ONLY_PREFACE.to_string(), + EXEC_DESCRIPTION_TEMPLATE.to_string(), + ]; + + if !enabled_tools.is_empty() { + let nested_tool_reference = enabled_tools + .iter() + .map(|(name, nested_description)| { + let global_name = normalize_code_mode_identifier(name); + format!( + "### `{global_name}` (`{name}`)\n{}", + nested_description.trim() + ) + }) + .collect::>() + .join("\n\n"); + sections.push(nested_tool_reference); + } + + sections.join("\n\n") +} + +pub fn build_wait_tool_description() -> &'static str { + WAIT_DESCRIPTION_TEMPLATE +} + +pub fn normalize_code_mode_identifier(tool_key: &str) -> String { + let mut identifier = String::new(); + + for (index, ch) in tool_key.chars().enumerate() { + let is_valid = if index == 0 { + ch == '_' || ch == '$' || ch.is_ascii_alphabetic() + } else { + ch == '_' || ch == '$' || ch.is_ascii_alphanumeric() + }; + + if is_valid { + identifier.push(ch); + } else { + identifier.push('_'); + } + } + + if identifier.is_empty() { + "_".to_string() + } else { + identifier + } +} + +pub fn augment_tool_definition(mut definition: ToolDefinition) -> ToolDefinition { + if definition.name != PUBLIC_TOOL_NAME { + definition.description = append_code_mode_sample_for_definition(&definition); + } + 
definition +} + +pub fn enabled_tool_metadata(definition: &ToolDefinition) -> EnabledToolMetadata { + EnabledToolMetadata { + tool_name: definition.name.clone(), + global_name: normalize_code_mode_identifier(&definition.name), + description: definition.description.clone(), + kind: definition.kind, + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize)] +pub struct EnabledToolMetadata { + pub tool_name: String, + pub global_name: String, + pub description: String, + pub kind: CodeModeToolKind, +} + +pub fn append_code_mode_sample( + description: &str, + tool_name: &str, + input_name: &str, + input_type: String, + output_type: String, +) -> String { + let declaration = format!( + "declare const tools: {{ {} }};", + render_code_mode_tool_declaration(tool_name, input_name, input_type, output_type) + ); + format!("{description}\n\nexec tool declaration:\n```ts\n{declaration}\n```") +} + +fn append_code_mode_sample_for_definition(definition: &ToolDefinition) -> String { + let input_name = match definition.kind { + CodeModeToolKind::Function => "args", + CodeModeToolKind::Freeform => "input", + }; + let input_type = match definition.kind { + CodeModeToolKind::Function => definition + .input_schema + .as_ref() + .map(render_json_schema_to_typescript) + .unwrap_or_else(|| "unknown".to_string()), + CodeModeToolKind::Freeform => "string".to_string(), + }; + let output_type = definition + .output_schema + .as_ref() + .map(render_json_schema_to_typescript) + .unwrap_or_else(|| "unknown".to_string()); + append_code_mode_sample( + &definition.description, + &definition.name, + input_name, + input_type, + output_type, + ) +} + +fn render_code_mode_tool_declaration( + tool_name: &str, + input_name: &str, + input_type: String, + output_type: String, +) -> String { + let tool_name = normalize_code_mode_identifier(tool_name); + format!("{tool_name}({input_name}: {input_type}): Promise<{output_type}>;") +} + +pub fn render_json_schema_to_typescript(schema: &JsonValue) -> String { + 
render_json_schema_to_typescript_inner(schema) +} + +fn render_json_schema_to_typescript_inner(schema: &JsonValue) -> String { + match schema { + JsonValue::Bool(true) => "unknown".to_string(), + JsonValue::Bool(false) => "never".to_string(), + JsonValue::Object(map) => { + if let Some(value) = map.get("const") { + return render_json_schema_literal(value); + } + + if let Some(values) = map.get("enum").and_then(JsonValue::as_array) { + let rendered = values + .iter() + .map(render_json_schema_literal) + .collect::>(); + if !rendered.is_empty() { + return rendered.join(" | "); + } + } + + for key in ["anyOf", "oneOf"] { + if let Some(variants) = map.get(key).and_then(JsonValue::as_array) { + let rendered = variants + .iter() + .map(render_json_schema_to_typescript_inner) + .collect::>(); + if !rendered.is_empty() { + return rendered.join(" | "); + } + } + } + + if let Some(variants) = map.get("allOf").and_then(JsonValue::as_array) { + let rendered = variants + .iter() + .map(render_json_schema_to_typescript_inner) + .collect::>(); + if !rendered.is_empty() { + return rendered.join(" & "); + } + } + + if let Some(schema_type) = map.get("type") { + if let Some(types) = schema_type.as_array() { + let rendered = types + .iter() + .filter_map(JsonValue::as_str) + .map(|schema_type| render_json_schema_type_keyword(map, schema_type)) + .collect::>(); + if !rendered.is_empty() { + return rendered.join(" | "); + } + } + + if let Some(schema_type) = schema_type.as_str() { + return render_json_schema_type_keyword(map, schema_type); + } + } + + if map.contains_key("properties") + || map.contains_key("additionalProperties") + || map.contains_key("required") + { + return render_json_schema_object(map); + } + + if map.contains_key("items") || map.contains_key("prefixItems") { + return render_json_schema_array(map); + } + + "unknown".to_string() + } + _ => "unknown".to_string(), + } +} + +fn render_json_schema_type_keyword( + map: &serde_json::Map, + schema_type: &str, +) -> String 
{ + match schema_type { + "string" => "string".to_string(), + "number" | "integer" => "number".to_string(), + "boolean" => "boolean".to_string(), + "null" => "null".to_string(), + "array" => render_json_schema_array(map), + "object" => render_json_schema_object(map), + _ => "unknown".to_string(), + } +} + +fn render_json_schema_array(map: &serde_json::Map) -> String { + if let Some(items) = map.get("items") { + let item_type = render_json_schema_to_typescript_inner(items); + return format!("Array<{item_type}>"); + } + + if let Some(items) = map.get("prefixItems").and_then(JsonValue::as_array) { + let item_types = items + .iter() + .map(render_json_schema_to_typescript_inner) + .collect::>(); + if !item_types.is_empty() { + return format!("[{}]", item_types.join(", ")); + } + } + + "unknown[]".to_string() +} + +fn render_json_schema_object(map: &serde_json::Map) -> String { + let required = map + .get("required") + .and_then(JsonValue::as_array) + .map(|items| { + items + .iter() + .filter_map(JsonValue::as_str) + .collect::>() + }) + .unwrap_or_default(); + let properties = map + .get("properties") + .and_then(JsonValue::as_object) + .cloned() + .unwrap_or_default(); + + let mut sorted_properties = properties.iter().collect::>(); + sorted_properties.sort_unstable_by(|(name_a, _), (name_b, _)| name_a.cmp(name_b)); + let mut lines = sorted_properties + .into_iter() + .map(|(name, value)| { + let optional = if required.iter().any(|required_name| required_name == name) { + "" + } else { + "?" 
+ }; + let property_name = render_json_schema_property_name(name); + let property_type = render_json_schema_to_typescript_inner(value); + format!("{property_name}{optional}: {property_type};") + }) + .collect::>(); + + if let Some(additional_properties) = map.get("additionalProperties") { + let property_type = match additional_properties { + JsonValue::Bool(true) => Some("unknown".to_string()), + JsonValue::Bool(false) => None, + value => Some(render_json_schema_to_typescript_inner(value)), + }; + + if let Some(property_type) = property_type { + lines.push(format!("[key: string]: {property_type};")); + } + } else if properties.is_empty() { + lines.push("[key: string]: unknown;".to_string()); + } + + if lines.is_empty() { + return "{}".to_string(); + } + + format!("{{ {} }}", lines.join(" ")) +} + +fn render_json_schema_property_name(name: &str) -> String { + if normalize_code_mode_identifier(name) == name { + name.to_string() + } else { + serde_json::to_string(name).unwrap_or_else(|_| format!("\"{}\"", name.replace('"', "\\\""))) + } +} + +fn render_json_schema_literal(value: &JsonValue) -> String { + serde_json::to_string(value).unwrap_or_else(|_| "unknown".to_string()) +} + +#[cfg(test)] +mod tests { + use super::CodeModeToolKind; + use super::ParsedExecSource; + use super::ToolDefinition; + use super::augment_tool_definition; + use super::build_exec_tool_description; + use super::normalize_code_mode_identifier; + use super::parse_exec_source; + use pretty_assertions::assert_eq; + use serde_json::json; + + #[test] + fn parse_exec_source_without_pragma() { + assert_eq!( + parse_exec_source("text('hi')").unwrap(), + ParsedExecSource { + code: "text('hi')".to_string(), + yield_time_ms: None, + max_output_tokens: None, + } + ); + } + + #[test] + fn parse_exec_source_with_pragma() { + assert_eq!( + parse_exec_source("// @exec: {\"yield_time_ms\": 10}\ntext('hi')").unwrap(), + ParsedExecSource { + code: "text('hi')".to_string(), + yield_time_ms: Some(10), + 
max_output_tokens: None, + } + ); + } + + #[test] + fn normalize_identifier_rewrites_invalid_characters() { + assert_eq!( + "mcp__ologs__get_profile", + normalize_code_mode_identifier("mcp__ologs__get_profile") + ); + assert_eq!( + "hidden_dynamic_tool", + normalize_code_mode_identifier("hidden-dynamic-tool") + ); + } + + #[test] + fn augment_tool_definition_appends_typed_declaration() { + let definition = ToolDefinition { + name: "hidden_dynamic_tool".to_string(), + description: "Test tool".to_string(), + kind: CodeModeToolKind::Function, + input_schema: Some(json!({ + "type": "object", + "properties": { "city": { "type": "string" } }, + "required": ["city"], + "additionalProperties": false + })), + output_schema: Some(json!({ + "type": "object", + "properties": { "ok": { "type": "boolean" } }, + "required": ["ok"] + })), + }; + + let description = augment_tool_definition(definition).description; + assert!(description.contains("declare const tools")); + assert!( + description.contains( + "hidden_dynamic_tool(args: { city: string; }): Promise<{ ok: boolean; }>;" + ) + ); + } + + #[test] + fn code_mode_only_description_includes_nested_tools() { + let description = + build_exec_tool_description(&[("foo".to_string(), "bar".to_string())], true); + assert!(description.contains("### `foo` (`foo`)")); + } +} diff --git a/codex-rs/code-mode/src/lib.rs b/codex-rs/code-mode/src/lib.rs new file mode 100644 index 000000000..841e568be --- /dev/null +++ b/codex-rs/code-mode/src/lib.rs @@ -0,0 +1,30 @@ +mod description; +mod response; +mod runtime; +mod service; + +pub use description::CODE_MODE_PRAGMA_PREFIX; +pub use description::CodeModeToolKind; +pub use description::ToolDefinition; +pub use description::append_code_mode_sample; +pub use description::augment_tool_definition; +pub use description::build_exec_tool_description; +pub use description::build_wait_tool_description; +pub use description::is_code_mode_nested_tool; +pub use description::normalize_code_mode_identifier; 
+pub use description::parse_exec_source; +pub use description::render_json_schema_to_typescript; +pub use response::FunctionCallOutputContentItem; +pub use response::ImageDetail; +pub use runtime::DEFAULT_EXEC_YIELD_TIME_MS; +pub use runtime::DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL; +pub use runtime::DEFAULT_WAIT_YIELD_TIME_MS; +pub use runtime::ExecuteRequest; +pub use runtime::RuntimeResponse; +pub use runtime::WaitRequest; +pub use service::CodeModeService; +pub use service::CodeModeTurnHost; +pub use service::CodeModeTurnWorker; + +pub const PUBLIC_TOOL_NAME: &str = "exec"; +pub const WAIT_TOOL_NAME: &str = "wait"; diff --git a/codex-rs/code-mode/src/response.rs b/codex-rs/code-mode/src/response.rs new file mode 100644 index 000000000..43579fac8 --- /dev/null +++ b/codex-rs/code-mode/src/response.rs @@ -0,0 +1,24 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ImageDetail { + Auto, + Low, + High, + Original, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum FunctionCallOutputContentItem { + InputText { + text: String, + }, + InputImage { + image_url: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + detail: Option, + }, +} diff --git a/codex-rs/code-mode/src/runtime/callbacks.rs b/codex-rs/code-mode/src/runtime/callbacks.rs new file mode 100644 index 000000000..b77ae82d6 --- /dev/null +++ b/codex-rs/code-mode/src/runtime/callbacks.rs @@ -0,0 +1,209 @@ +use crate::response::FunctionCallOutputContentItem; + +use super::EXIT_SENTINEL; +use super::RuntimeEvent; +use super::RuntimeState; +use super::value::json_to_v8; +use super::value::normalize_output_image; +use super::value::serialize_output_text; +use super::value::throw_type_error; +use super::value::v8_value_to_json; + +pub(super) fn tool_callback( + scope: &mut v8::PinScope<'_, '_>, + args: 
v8::FunctionCallbackArguments, + mut retval: v8::ReturnValue, +) { + let tool_name = args.data().to_rust_string_lossy(scope); + let input = if args.length() == 0 { + Ok(None) + } else { + v8_value_to_json(scope, args.get(0)) + }; + let input = match input { + Ok(input) => input, + Err(error_text) => { + throw_type_error(scope, &error_text); + return; + } + }; + + let Some(resolver) = v8::PromiseResolver::new(scope) else { + throw_type_error(scope, "failed to create tool promise"); + return; + }; + let promise = resolver.get_promise(scope); + + let resolver = v8::Global::new(scope, resolver); + let Some(state) = scope.get_slot_mut::() else { + throw_type_error(scope, "runtime state unavailable"); + return; + }; + let id = format!("tool-{}", state.next_tool_call_id); + state.next_tool_call_id = state.next_tool_call_id.saturating_add(1); + let event_tx = state.event_tx.clone(); + state.pending_tool_calls.insert(id.clone(), resolver); + let _ = event_tx.send(RuntimeEvent::ToolCall { + id, + name: tool_name, + input, + }); + retval.set(promise.into()); +} + +pub(super) fn text_callback( + scope: &mut v8::PinScope<'_, '_>, + args: v8::FunctionCallbackArguments, + mut retval: v8::ReturnValue, +) { + let value = if args.length() == 0 { + v8::undefined(scope).into() + } else { + args.get(0) + }; + let text = match serialize_output_text(scope, value) { + Ok(text) => text, + Err(error_text) => { + throw_type_error(scope, &error_text); + return; + } + }; + if let Some(state) = scope.get_slot::() { + let _ = state.event_tx.send(RuntimeEvent::ContentItem( + FunctionCallOutputContentItem::InputText { text }, + )); + } + retval.set(v8::undefined(scope).into()); +} + +pub(super) fn image_callback( + scope: &mut v8::PinScope<'_, '_>, + args: v8::FunctionCallbackArguments, + mut retval: v8::ReturnValue, +) { + let value = if args.length() == 0 { + v8::undefined(scope).into() + } else { + args.get(0) + }; + let image_item = match normalize_output_image(scope, value) { + Ok(image_item) 
=> image_item, + Err(()) => return, + }; + if let Some(state) = scope.get_slot::() { + let _ = state.event_tx.send(RuntimeEvent::ContentItem(image_item)); + } + retval.set(v8::undefined(scope).into()); +} + +pub(super) fn store_callback( + scope: &mut v8::PinScope<'_, '_>, + args: v8::FunctionCallbackArguments, + _retval: v8::ReturnValue, +) { + let key = match args.get(0).to_string(scope) { + Some(key) => key.to_rust_string_lossy(scope), + None => { + throw_type_error(scope, "store key must be a string"); + return; + } + }; + let value = args.get(1); + let serialized = match v8_value_to_json(scope, value) { + Ok(Some(value)) => value, + Ok(None) => { + throw_type_error( + scope, + &format!("Unable to store {key:?}. Only plain serializable objects can be stored."), + ); + return; + } + Err(error_text) => { + throw_type_error(scope, &error_text); + return; + } + }; + if let Some(state) = scope.get_slot_mut::() { + state.stored_values.insert(key, serialized); + } +} + +pub(super) fn load_callback( + scope: &mut v8::PinScope<'_, '_>, + args: v8::FunctionCallbackArguments, + mut retval: v8::ReturnValue, +) { + let key = match args.get(0).to_string(scope) { + Some(key) => key.to_rust_string_lossy(scope), + None => { + throw_type_error(scope, "load key must be a string"); + return; + } + }; + let value = scope + .get_slot::() + .and_then(|state| state.stored_values.get(&key)) + .cloned(); + let Some(value) = value else { + retval.set(v8::undefined(scope).into()); + return; + }; + let Some(value) = json_to_v8(scope, &value) else { + throw_type_error(scope, "failed to load stored value"); + return; + }; + retval.set(value); +} + +pub(super) fn notify_callback( + scope: &mut v8::PinScope<'_, '_>, + args: v8::FunctionCallbackArguments, + mut retval: v8::ReturnValue, +) { + let value = if args.length() == 0 { + v8::undefined(scope).into() + } else { + args.get(0) + }; + let text = match serialize_output_text(scope, value) { + Ok(text) => text, + Err(error_text) => { + 
throw_type_error(scope, &error_text); + return; + } + }; + if text.trim().is_empty() { + throw_type_error(scope, "notify expects non-empty text"); + return; + } + if let Some(state) = scope.get_slot::() { + let _ = state.event_tx.send(RuntimeEvent::Notify { + call_id: state.tool_call_id.clone(), + text, + }); + } + retval.set(v8::undefined(scope).into()); +} + +pub(super) fn yield_control_callback( + scope: &mut v8::PinScope<'_, '_>, + _args: v8::FunctionCallbackArguments, + _retval: v8::ReturnValue, +) { + if let Some(state) = scope.get_slot::() { + let _ = state.event_tx.send(RuntimeEvent::YieldRequested); + } +} + +pub(super) fn exit_callback( + scope: &mut v8::PinScope<'_, '_>, + _args: v8::FunctionCallbackArguments, + _retval: v8::ReturnValue, +) { + if let Some(state) = scope.get_slot_mut::() { + state.exit_requested = true; + } + if let Some(error) = v8::String::new(scope, EXIT_SENTINEL) { + scope.throw_exception(error.into()); + } +} diff --git a/codex-rs/code-mode/src/runtime/globals.rs b/codex-rs/code-mode/src/runtime/globals.rs new file mode 100644 index 000000000..371479497 --- /dev/null +++ b/codex-rs/code-mode/src/runtime/globals.rs @@ -0,0 +1,138 @@ +use super::RuntimeState; +use super::callbacks::exit_callback; +use super::callbacks::image_callback; +use super::callbacks::load_callback; +use super::callbacks::notify_callback; +use super::callbacks::store_callback; +use super::callbacks::text_callback; +use super::callbacks::tool_callback; +use super::callbacks::yield_control_callback; + +pub(super) fn install_globals(scope: &mut v8::PinScope<'_, '_>) -> Result<(), String> { + let global = scope.get_current_context().global(scope); + let console = v8::String::new(scope, "console") + .ok_or_else(|| "failed to allocate global `console`".to_string())?; + if global.delete(scope, console.into()) != Some(true) { + return Err("failed to remove global `console`".to_string()); + } + + let tools = build_tools_object(scope)?; + let all_tools = 
build_all_tools_value(scope)?; + let text = helper_function(scope, "text", text_callback)?; + let image = helper_function(scope, "image", image_callback)?; + let store = helper_function(scope, "store", store_callback)?; + let load = helper_function(scope, "load", load_callback)?; + let notify = helper_function(scope, "notify", notify_callback)?; + let yield_control = helper_function(scope, "yield_control", yield_control_callback)?; + let exit = helper_function(scope, "exit", exit_callback)?; + + set_global(scope, global, "tools", tools.into())?; + set_global(scope, global, "ALL_TOOLS", all_tools)?; + set_global(scope, global, "text", text.into())?; + set_global(scope, global, "image", image.into())?; + set_global(scope, global, "store", store.into())?; + set_global(scope, global, "load", load.into())?; + set_global(scope, global, "notify", notify.into())?; + set_global(scope, global, "yield_control", yield_control.into())?; + set_global(scope, global, "exit", exit.into())?; + Ok(()) +} + +fn build_tools_object<'s>( + scope: &mut v8::PinScope<'s, '_>, +) -> Result, String> { + let tools = v8::Object::new(scope); + let enabled_tools = scope + .get_slot::() + .map(|state| state.enabled_tools.clone()) + .unwrap_or_default(); + + for tool in enabled_tools { + let name = v8::String::new(scope, &tool.global_name) + .ok_or_else(|| "failed to allocate tool name".to_string())?; + let function = tool_function(scope, &tool.tool_name)?; + tools.set(scope, name.into(), function.into()); + } + Ok(tools) +} + +fn build_all_tools_value<'s>( + scope: &mut v8::PinScope<'s, '_>, +) -> Result, String> { + let enabled_tools = scope + .get_slot::() + .map(|state| state.enabled_tools.clone()) + .unwrap_or_default(); + let array = v8::Array::new(scope, enabled_tools.len() as i32); + let name_key = v8::String::new(scope, "name") + .ok_or_else(|| "failed to allocate ALL_TOOLS name key".to_string())?; + let description_key = v8::String::new(scope, "description") + .ok_or_else(|| "failed to 
allocate ALL_TOOLS description key".to_string())?; + + for (index, tool) in enabled_tools.iter().enumerate() { + let item = v8::Object::new(scope); + let name = v8::String::new(scope, &tool.global_name) + .ok_or_else(|| "failed to allocate ALL_TOOLS name".to_string())?; + let description = v8::String::new(scope, &tool.description) + .ok_or_else(|| "failed to allocate ALL_TOOLS description".to_string())?; + + if item.set(scope, name_key.into(), name.into()) != Some(true) { + return Err("failed to set ALL_TOOLS name".to_string()); + } + if item.set(scope, description_key.into(), description.into()) != Some(true) { + return Err("failed to set ALL_TOOLS description".to_string()); + } + if array.set_index(scope, index as u32, item.into()) != Some(true) { + return Err("failed to append ALL_TOOLS metadata".to_string()); + } + } + + Ok(array.into()) +} + +fn helper_function<'s, F>( + scope: &mut v8::PinScope<'s, '_>, + name: &str, + callback: F, +) -> Result, String> +where + F: v8::MapFnTo, +{ + let name = + v8::String::new(scope, name).ok_or_else(|| "failed to allocate helper name".to_string())?; + let template = v8::FunctionTemplate::builder(callback) + .data(name.into()) + .build(scope); + template + .get_function(scope) + .ok_or_else(|| "failed to create helper function".to_string()) +} + +fn tool_function<'s>( + scope: &mut v8::PinScope<'s, '_>, + tool_name: &str, +) -> Result, String> { + let data = v8::String::new(scope, tool_name) + .ok_or_else(|| "failed to allocate tool callback data".to_string())?; + let template = v8::FunctionTemplate::builder(tool_callback) + .data(data.into()) + .build(scope); + template + .get_function(scope) + .ok_or_else(|| "failed to create tool function".to_string()) +} + +fn set_global<'s>( + scope: &mut v8::PinScope<'s, '_>, + global: v8::Local<'s, v8::Object>, + name: &str, + value: v8::Local<'s, v8::Value>, +) -> Result<(), String> { + let key = v8::String::new(scope, name) + .ok_or_else(|| format!("failed to allocate global 
`{name}`"))?; + if global.set(scope, key.into(), value) == Some(true) { + Ok(()) + } else { + Err(format!("failed to set global `{name}`")) + } +} diff --git a/codex-rs/code-mode/src/runtime/mod.rs b/codex-rs/code-mode/src/runtime/mod.rs new file mode 100644 index 000000000..df90eda67 --- /dev/null +++ b/codex-rs/code-mode/src/runtime/mod.rs @@ -0,0 +1,349 @@ +mod callbacks; +mod globals; +mod module_loader; +mod value; + +use std::collections::HashMap; +use std::sync::OnceLock; +use std::sync::mpsc as std_mpsc; +use std::thread; + +use serde_json::Value as JsonValue; +use tokio::sync::mpsc; + +use crate::description::EnabledToolMetadata; +use crate::description::ToolDefinition; +use crate::description::enabled_tool_metadata; +use crate::response::FunctionCallOutputContentItem; + +pub const DEFAULT_EXEC_YIELD_TIME_MS: u64 = 10_000; +pub const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000; +pub const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL: usize = 10_000; +const EXIT_SENTINEL: &str = "__codex_code_mode_exit__"; + +#[derive(Clone, Debug)] +pub struct ExecuteRequest { + pub tool_call_id: String, + pub enabled_tools: Vec, + pub source: String, + pub stored_values: HashMap, + pub yield_time_ms: Option, + pub max_output_tokens: Option, +} + +#[derive(Clone, Debug)] +pub struct WaitRequest { + pub cell_id: String, + pub yield_time_ms: u64, + pub terminate: bool, +} + +#[derive(Debug, PartialEq)] +pub enum RuntimeResponse { + Yielded { + cell_id: String, + content_items: Vec, + }, + Terminated { + cell_id: String, + content_items: Vec, + }, + Result { + cell_id: String, + content_items: Vec, + stored_values: HashMap, + error_text: Option, + }, +} + +#[derive(Debug)] +pub(crate) enum TurnMessage { + ToolCall { + cell_id: String, + id: String, + name: String, + input: Option, + }, + Notify { + cell_id: String, + call_id: String, + text: String, + }, +} + +#[derive(Debug)] +pub(crate) enum RuntimeCommand { + ToolResponse { id: String, result: JsonValue }, + ToolError { id: 
String, error_text: String }, + Terminate, +} + +#[derive(Debug)] +pub(crate) enum RuntimeEvent { + Started, + ContentItem(FunctionCallOutputContentItem), + YieldRequested, + ToolCall { + id: String, + name: String, + input: Option, + }, + Notify { + call_id: String, + text: String, + }, + Result { + stored_values: HashMap, + error_text: Option, + }, +} + +pub(crate) fn spawn_runtime( + request: ExecuteRequest, + event_tx: mpsc::UnboundedSender, +) -> Result<(std_mpsc::Sender, v8::IsolateHandle), String> { + let (command_tx, command_rx) = std_mpsc::channel(); + let (isolate_handle_tx, isolate_handle_rx) = std_mpsc::sync_channel(1); + let enabled_tools = request + .enabled_tools + .iter() + .map(enabled_tool_metadata) + .collect::>(); + let config = RuntimeConfig { + tool_call_id: request.tool_call_id, + enabled_tools, + source: request.source, + stored_values: request.stored_values, + }; + + thread::spawn(move || { + run_runtime(config, event_tx, command_rx, isolate_handle_tx); + }); + + let isolate_handle = isolate_handle_rx + .recv() + .map_err(|_| "failed to initialize code mode runtime".to_string())?; + Ok((command_tx, isolate_handle)) +} + +#[derive(Clone)] +struct RuntimeConfig { + tool_call_id: String, + enabled_tools: Vec, + source: String, + stored_values: HashMap, +} + +pub(super) struct RuntimeState { + event_tx: mpsc::UnboundedSender, + pending_tool_calls: HashMap>, + stored_values: HashMap, + enabled_tools: Vec, + next_tool_call_id: u64, + tool_call_id: String, + exit_requested: bool, +} + +pub(super) enum CompletionState { + Pending, + Completed { + stored_values: HashMap, + error_text: Option, + }, +} + +fn initialize_v8() { + static PLATFORM: OnceLock> = OnceLock::new(); + + let _ = PLATFORM.get_or_init(|| { + let platform = v8::new_default_platform(0, false).make_shared(); + v8::V8::initialize_platform(platform.clone()); + v8::V8::initialize(); + platform + }); +} + +fn run_runtime( + config: RuntimeConfig, + event_tx: mpsc::UnboundedSender, + 
command_rx: std_mpsc::Receiver, + isolate_handle_tx: std_mpsc::SyncSender, +) { + initialize_v8(); + + let isolate = &mut v8::Isolate::new(v8::CreateParams::default()); + let isolate_handle = isolate.thread_safe_handle(); + if isolate_handle_tx.send(isolate_handle).is_err() { + return; + } + isolate.set_host_import_module_dynamically_callback(module_loader::dynamic_import_callback); + + v8::scope!(let scope, isolate); + let context = v8::Context::new(scope, Default::default()); + let scope = &mut v8::ContextScope::new(scope, context); + + scope.set_slot(RuntimeState { + event_tx: event_tx.clone(), + pending_tool_calls: HashMap::new(), + stored_values: config.stored_values, + enabled_tools: config.enabled_tools, + next_tool_call_id: 1, + tool_call_id: config.tool_call_id, + exit_requested: false, + }); + + if let Err(error_text) = globals::install_globals(scope) { + send_result(&event_tx, HashMap::new(), Some(error_text)); + return; + } + + let _ = event_tx.send(RuntimeEvent::Started); + + let pending_promise = match module_loader::evaluate_main_module(scope, &config.source) { + Ok(pending_promise) => pending_promise, + Err(error_text) => { + capture_scope_send_error(scope, &event_tx, Some(error_text)); + return; + } + }; + + match module_loader::completion_state(scope, pending_promise.as_ref()) { + CompletionState::Completed { + stored_values, + error_text, + } => { + send_result(&event_tx, stored_values, error_text); + return; + } + CompletionState::Pending => {} + } + + let mut pending_promise = pending_promise; + loop { + let Ok(command) = command_rx.recv() else { + break; + }; + match command { + RuntimeCommand::Terminate => break, + RuntimeCommand::ToolResponse { id, result } => { + if let Err(error_text) = + module_loader::resolve_tool_response(scope, &id, Ok(result)) + { + capture_scope_send_error(scope, &event_tx, Some(error_text)); + return; + } + } + RuntimeCommand::ToolError { id, error_text } => { + if let Err(runtime_error) = + 
module_loader::resolve_tool_response(scope, &id, Err(error_text)) + { + capture_scope_send_error(scope, &event_tx, Some(runtime_error)); + return; + } + } + } + + scope.perform_microtask_checkpoint(); + match module_loader::completion_state(scope, pending_promise.as_ref()) { + CompletionState::Completed { + stored_values, + error_text, + } => { + send_result(&event_tx, stored_values, error_text); + return; + } + CompletionState::Pending => {} + } + + if let Some(promise) = pending_promise.as_ref() { + let promise = v8::Local::new(scope, promise); + if promise.state() != v8::PromiseState::Pending { + pending_promise = None; + } + } + } +} + +fn capture_scope_send_error( + scope: &mut v8::PinScope<'_, '_>, + event_tx: &mpsc::UnboundedSender, + error_text: Option, +) { + let stored_values = scope + .get_slot::() + .map(|state| state.stored_values.clone()) + .unwrap_or_default(); + + send_result(event_tx, stored_values, error_text); +} + +fn send_result( + event_tx: &mpsc::UnboundedSender, + stored_values: HashMap, + error_text: Option, +) { + let _ = event_tx.send(RuntimeEvent::Result { + stored_values, + error_text, + }); +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::time::Duration; + + use pretty_assertions::assert_eq; + use tokio::sync::mpsc; + + use super::ExecuteRequest; + use super::RuntimeEvent; + use super::spawn_runtime; + + fn execute_request(source: &str) -> ExecuteRequest { + ExecuteRequest { + tool_call_id: "call_1".to_string(), + enabled_tools: Vec::new(), + source: source.to_string(), + stored_values: HashMap::new(), + yield_time_ms: Some(1), + max_output_tokens: None, + } + } + + #[tokio::test] + async fn terminate_execution_stops_cpu_bound_module() { + let (event_tx, mut event_rx) = mpsc::unbounded_channel(); + let (_runtime_tx, runtime_terminate_handle) = + spawn_runtime(execute_request("while (true) {}"), event_tx).unwrap(); + + let started_event = tokio::time::timeout(Duration::from_secs(1), event_rx.recv()) + .await + 
.unwrap() + .unwrap(); + assert!(matches!(started_event, RuntimeEvent::Started)); + + assert!(runtime_terminate_handle.terminate_execution()); + + let result_event = tokio::time::timeout(Duration::from_secs(1), event_rx.recv()) + .await + .unwrap() + .unwrap(); + let RuntimeEvent::Result { + stored_values, + error_text, + } = result_event + else { + panic!("expected runtime result after termination"); + }; + assert_eq!(stored_values, HashMap::new()); + assert!(error_text.is_some()); + + assert!( + tokio::time::timeout(Duration::from_secs(1), event_rx.recv()) + .await + .unwrap() + .is_none() + ); + } +} diff --git a/codex-rs/code-mode/src/runtime/module_loader.rs b/codex-rs/code-mode/src/runtime/module_loader.rs new file mode 100644 index 000000000..83ce3d347 --- /dev/null +++ b/codex-rs/code-mode/src/runtime/module_loader.rs @@ -0,0 +1,235 @@ +use serde_json::Value as JsonValue; + +use super::CompletionState; +use super::EXIT_SENTINEL; +use super::RuntimeState; +use super::value::json_to_v8; +use super::value::value_to_error_text; + +pub(super) fn evaluate_main_module( + scope: &mut v8::PinScope<'_, '_>, + source_text: &str, +) -> Result>, String> { + let tc = std::pin::pin!(v8::TryCatch::new(scope)); + let mut tc = tc.init(); + let source = v8::String::new(&tc, source_text) + .ok_or_else(|| "failed to allocate exec source".to_string())?; + let origin = script_origin(&mut tc, "exec_main.mjs")?; + let mut source = v8::script_compiler::Source::new(source, Some(&origin)); + let module = v8::script_compiler::compile_module(&tc, &mut source).ok_or_else(|| { + tc.exception() + .map(|exception| value_to_error_text(&mut tc, exception)) + .unwrap_or_else(|| "unknown code mode exception".to_string()) + })?; + module + .instantiate_module(&tc, resolve_module_callback) + .ok_or_else(|| { + tc.exception() + .map(|exception| value_to_error_text(&mut tc, exception)) + .unwrap_or_else(|| "unknown code mode exception".to_string()) + })?; + let result = match module.evaluate(&tc) { 
+ Some(result) => result, + None => { + if let Some(exception) = tc.exception() { + if is_exit_exception(&mut tc, exception) { + return Ok(None); + } + return Err(value_to_error_text(&mut tc, exception)); + } + return Err("unknown code mode exception".to_string()); + } + }; + tc.perform_microtask_checkpoint(); + + if result.is_promise() { + let promise = v8::Local::::try_from(result) + .map_err(|_| "failed to read exec promise".to_string())?; + return Ok(Some(v8::Global::new(&tc, promise))); + } + + Ok(None) +} + +fn is_exit_exception( + scope: &mut v8::PinScope<'_, '_>, + exception: v8::Local<'_, v8::Value>, +) -> bool { + scope + .get_slot::() + .map(|state| state.exit_requested) + .unwrap_or(false) + && exception.is_string() + && exception.to_rust_string_lossy(scope) == EXIT_SENTINEL +} + +pub(super) fn resolve_tool_response( + scope: &mut v8::PinScope<'_, '_>, + id: &str, + response: Result, +) -> Result<(), String> { + let resolver = { + let state = scope + .get_slot_mut::() + .ok_or_else(|| "runtime state unavailable".to_string())?; + state.pending_tool_calls.remove(id) + } + .ok_or_else(|| format!("unknown tool call `{id}`"))?; + + let tc = std::pin::pin!(v8::TryCatch::new(scope)); + let mut tc = tc.init(); + let resolver = v8::Local::new(&tc, &resolver); + match response { + Ok(result) => { + let value = json_to_v8(&mut tc, &result) + .ok_or_else(|| "failed to serialize tool response".to_string())?; + resolver.resolve(&tc, value); + } + Err(error_text) => { + let value = v8::String::new(&tc, &error_text) + .ok_or_else(|| "failed to allocate tool error".to_string())?; + resolver.reject(&tc, value.into()); + } + } + if tc.has_caught() { + return Err(tc + .exception() + .map(|exception| value_to_error_text(&mut tc, exception)) + .unwrap_or_else(|| "unknown code mode exception".to_string())); + } + Ok(()) +} + +pub(super) fn completion_state( + scope: &mut v8::PinScope<'_, '_>, + pending_promise: Option<&v8::Global>, +) -> CompletionState { + let stored_values 
= scope + .get_slot::() + .map(|state| state.stored_values.clone()) + .unwrap_or_default(); + + let Some(pending_promise) = pending_promise else { + return CompletionState::Completed { + stored_values, + error_text: None, + }; + }; + + let promise = v8::Local::new(scope, pending_promise); + match promise.state() { + v8::PromiseState::Pending => CompletionState::Pending, + v8::PromiseState::Fulfilled => CompletionState::Completed { + stored_values, + error_text: None, + }, + v8::PromiseState::Rejected => { + let result = promise.result(scope); + let error_text = if is_exit_exception(scope, result) { + None + } else { + Some(value_to_error_text(scope, result)) + }; + CompletionState::Completed { + stored_values, + error_text, + } + } + } +} + +fn script_origin<'s>( + scope: &mut v8::PinScope<'s, '_>, + resource_name_: &str, +) -> Result, String> { + let resource_name = v8::String::new(scope, resource_name_) + .ok_or_else(|| "failed to allocate script origin".to_string())?; + let source_map_url = v8::String::new(scope, resource_name_) + .ok_or_else(|| "failed to allocate source map url".to_string())?; + Ok(v8::ScriptOrigin::new( + scope, + resource_name.into(), + 0, + 0, + true, + 0, + Some(source_map_url.into()), + true, + false, + true, + None, + )) +} + +fn resolve_module_callback<'s>( + context: v8::Local<'s, v8::Context>, + specifier: v8::Local<'s, v8::String>, + _import_attributes: v8::Local<'s, v8::FixedArray>, + _referrer: v8::Local<'s, v8::Module>, +) -> Option> { + v8::callback_scope!(unsafe scope, context); + let specifier = specifier.to_rust_string_lossy(scope); + resolve_module(scope, &specifier) +} + +pub(super) fn dynamic_import_callback<'s>( + scope: &mut v8::PinScope<'s, '_>, + _host_defined_options: v8::Local<'s, v8::Data>, + _resource_name: v8::Local<'s, v8::Value>, + specifier: v8::Local<'s, v8::String>, + _import_attributes: v8::Local<'s, v8::FixedArray>, +) -> Option> { + let specifier = specifier.to_rust_string_lossy(scope); + let resolver = 
v8::PromiseResolver::new(scope)?; + + match resolve_module(scope, &specifier) { + Some(module) => { + if module.get_status() == v8::ModuleStatus::Uninstantiated + && module + .instantiate_module(scope, resolve_module_callback) + .is_none() + { + let error = v8::String::new(scope, "failed to instantiate module") + .map(Into::into) + .unwrap_or_else(|| v8::undefined(scope).into()); + resolver.reject(scope, error); + return Some(resolver.get_promise(scope)); + } + if matches!( + module.get_status(), + v8::ModuleStatus::Instantiated | v8::ModuleStatus::Evaluated + ) && module.evaluate(scope).is_none() + { + let error = v8::String::new(scope, "failed to evaluate module") + .map(Into::into) + .unwrap_or_else(|| v8::undefined(scope).into()); + resolver.reject(scope, error); + return Some(resolver.get_promise(scope)); + } + let namespace = module.get_module_namespace(); + resolver.resolve(scope, namespace); + Some(resolver.get_promise(scope)) + } + None => { + let error = v8::String::new(scope, "unsupported import in exec") + .map(Into::into) + .unwrap_or_else(|| v8::undefined(scope).into()); + resolver.reject(scope, error); + Some(resolver.get_promise(scope)) + } + } +} + +fn resolve_module<'s>( + scope: &mut v8::PinScope<'s, '_>, + specifier: &str, +) -> Option> { + if let Some(message) = + v8::String::new(scope, &format!("Unsupported import in exec: {specifier}")) + { + scope.throw_exception(message.into()); + } else { + scope.throw_exception(v8::undefined(scope).into()); + } + None +} diff --git a/codex-rs/code-mode/src/runtime/value.rs b/codex-rs/code-mode/src/runtime/value.rs new file mode 100644 index 000000000..eb0280142 --- /dev/null +++ b/codex-rs/code-mode/src/runtime/value.rs @@ -0,0 +1,163 @@ +use serde_json::Value as JsonValue; + +use crate::response::FunctionCallOutputContentItem; +use crate::response::ImageDetail; + +pub(super) fn serialize_output_text( + scope: &mut v8::PinScope<'_, '_>, + value: v8::Local<'_, v8::Value>, +) -> Result { + if 
value.is_undefined() + || value.is_null() + || value.is_boolean() + || value.is_number() + || value.is_big_int() + || value.is_string() + { + return Ok(value.to_rust_string_lossy(scope)); + } + + let tc = std::pin::pin!(v8::TryCatch::new(scope)); + let mut tc = tc.init(); + if let Some(stringified) = v8::json::stringify(&tc, value) { + return Ok(stringified.to_rust_string_lossy(&tc)); + } + if tc.has_caught() { + return Err(tc + .exception() + .map(|exception| value_to_error_text(&mut tc, exception)) + .unwrap_or_else(|| "unknown code mode exception".to_string())); + } + Ok(value.to_rust_string_lossy(&tc)) +} + +pub(super) fn normalize_output_image( + scope: &mut v8::PinScope<'_, '_>, + value: v8::Local<'_, v8::Value>, +) -> Result { + let result = (|| -> Result { + let (image_url, detail) = if value.is_string() { + (value.to_rust_string_lossy(scope), None) + } else if value.is_object() && !value.is_array() { + let object = v8::Local::::try_from(value).map_err(|_| { + "image expects a non-empty image URL string or an object with image_url and optional detail".to_string() + })?; + let image_url_key = v8::String::new(scope, "image_url") + .ok_or_else(|| "failed to allocate image helper keys".to_string())?; + let detail_key = v8::String::new(scope, "detail") + .ok_or_else(|| "failed to allocate image helper keys".to_string())?; + let image_url = object + .get(scope, image_url_key.into()) + .filter(|value| value.is_string()) + .map(|value| value.to_rust_string_lossy(scope)) + .ok_or_else(|| { + "image expects a non-empty image URL string or an object with image_url and optional detail" + .to_string() + })?; + let detail = match object.get(scope, detail_key.into()) { + Some(value) if value.is_string() => Some(value.to_rust_string_lossy(scope)), + Some(value) if value.is_null() || value.is_undefined() => None, + Some(_) => return Err("image detail must be a string when provided".to_string()), + None => None, + }; + (image_url, detail) + } else { + return Err( + "image 
expects a non-empty image URL string or an object with image_url and optional detail" + .to_string(), + ); + }; + + if image_url.is_empty() { + return Err( + "image expects a non-empty image URL string or an object with image_url and optional detail" + .to_string(), + ); + } + let lower = image_url.to_ascii_lowercase(); + if !(lower.starts_with("http://") + || lower.starts_with("https://") + || lower.starts_with("data:")) + { + return Err("image expects an http(s) or data URL".to_string()); + } + + let detail = match detail { + Some(detail) => { + let normalized = detail.to_ascii_lowercase(); + Some(match normalized.as_str() { + "auto" => ImageDetail::Auto, + "low" => ImageDetail::Low, + "high" => ImageDetail::High, + "original" => ImageDetail::Original, + _ => { + return Err( + "image detail must be one of: auto, low, high, original".to_string() + ); + } + }) + } + None => None, + }; + + Ok(FunctionCallOutputContentItem::InputImage { image_url, detail }) + })(); + + match result { + Ok(item) => Ok(item), + Err(error_text) => { + throw_type_error(scope, &error_text); + Err(()) + } + } +} + +pub(super) fn v8_value_to_json( + scope: &mut v8::PinScope<'_, '_>, + value: v8::Local<'_, v8::Value>, +) -> Result, String> { + let tc = std::pin::pin!(v8::TryCatch::new(scope)); + let mut tc = tc.init(); + let Some(stringified) = v8::json::stringify(&tc, value) else { + if tc.has_caught() { + return Err(tc + .exception() + .map(|exception| value_to_error_text(&mut tc, exception)) + .unwrap_or_else(|| "unknown code mode exception".to_string())); + } + return Ok(None); + }; + serde_json::from_str(&stringified.to_rust_string_lossy(&tc)) + .map(Some) + .map_err(|err| format!("failed to serialize JavaScript value: {err}")) +} + +pub(super) fn json_to_v8<'s>( + scope: &mut v8::PinScope<'s, '_>, + value: &JsonValue, +) -> Option> { + let json = serde_json::to_string(value).ok()?; + let json = v8::String::new(scope, &json)?; + v8::json::parse(scope, json) +} + +pub(super) fn 
value_to_error_text( + scope: &mut v8::PinScope<'_, '_>, + value: v8::Local<'_, v8::Value>, +) -> String { + if value.is_object() + && let Ok(object) = v8::Local::::try_from(value) + && let Some(key) = v8::String::new(scope, "stack") + && let Some(stack) = object.get(scope, key.into()) + && stack.is_string() + { + return stack.to_rust_string_lossy(scope); + } + value.to_rust_string_lossy(scope) +} + +pub(super) fn throw_type_error(scope: &mut v8::PinScope<'_, '_>, message: &str) { + if let Some(message) = v8::String::new(scope, message) { + scope.throw_exception(message.into()); + } +} diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs new file mode 100644 index 000000000..260b891d3 --- /dev/null +++ b/codex-rs/code-mode/src/service.rs @@ -0,0 +1,673 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use async_trait::async_trait; +use serde_json::Value as JsonValue; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +use crate::FunctionCallOutputContentItem; +use crate::runtime::DEFAULT_EXEC_YIELD_TIME_MS; +use crate::runtime::ExecuteRequest; +use crate::runtime::RuntimeCommand; +use crate::runtime::RuntimeEvent; +use crate::runtime::RuntimeResponse; +use crate::runtime::TurnMessage; +use crate::runtime::WaitRequest; +use crate::runtime::spawn_runtime; + +#[async_trait] +pub trait CodeModeTurnHost: Send + Sync { + async fn invoke_tool( + &self, + tool_name: String, + input: Option, + cancellation_token: CancellationToken, + ) -> Result; + + async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String>; +} + +#[derive(Clone)] +struct SessionHandle { + control_tx: mpsc::UnboundedSender, + runtime_tx: std::sync::mpsc::Sender, +} + +struct Inner { + stored_values: Mutex>, + sessions: Mutex>, + turn_message_tx: 
mpsc::UnboundedSender, + turn_message_rx: Arc>>, + next_cell_id: AtomicU64, +} + +pub struct CodeModeService { + inner: Arc, +} + +impl CodeModeService { + pub fn new() -> Self { + let (turn_message_tx, turn_message_rx) = mpsc::unbounded_channel(); + + Self { + inner: Arc::new(Inner { + stored_values: Mutex::new(HashMap::new()), + sessions: Mutex::new(HashMap::new()), + turn_message_tx, + turn_message_rx: Arc::new(Mutex::new(turn_message_rx)), + next_cell_id: AtomicU64::new(1), + }), + } + } + + pub async fn stored_values(&self) -> HashMap { + self.inner.stored_values.lock().await.clone() + } + + pub async fn replace_stored_values(&self, values: HashMap) { + *self.inner.stored_values.lock().await = values; + } + + pub async fn execute(&self, request: ExecuteRequest) -> Result { + let cell_id = self + .inner + .next_cell_id + .fetch_add(1, Ordering::Relaxed) + .to_string(); + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (runtime_tx, runtime_terminate_handle) = spawn_runtime(request.clone(), event_tx)?; + let (control_tx, control_rx) = mpsc::unbounded_channel(); + let (response_tx, response_rx) = oneshot::channel(); + + self.inner.sessions.lock().await.insert( + cell_id.clone(), + SessionHandle { + control_tx: control_tx.clone(), + runtime_tx: runtime_tx.clone(), + }, + ); + + tokio::spawn(run_session_control( + Arc::clone(&self.inner), + SessionControlContext { + cell_id: cell_id.clone(), + runtime_tx, + runtime_terminate_handle, + }, + event_rx, + control_rx, + response_tx, + request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS), + )); + + response_rx + .await + .map_err(|_| "exec runtime ended unexpectedly".to_string()) + } + + pub async fn wait(&self, request: WaitRequest) -> Result { + let cell_id = request.cell_id.clone(); + let handle = self + .inner + .sessions + .lock() + .await + .get(&request.cell_id) + .cloned(); + let Some(handle) = handle else { + return Ok(missing_cell_response(cell_id)); + }; + let (response_tx, response_rx) = 
oneshot::channel(); + let control_message = if request.terminate { + SessionControlCommand::Terminate { response_tx } + } else { + SessionControlCommand::Poll { + yield_time_ms: request.yield_time_ms, + response_tx, + } + }; + if handle.control_tx.send(control_message).is_err() { + return Ok(missing_cell_response(cell_id)); + } + match response_rx.await { + Ok(response) => Ok(response), + Err(_) => Ok(missing_cell_response(request.cell_id)), + } + } + + pub fn start_turn_worker(&self, host: Arc) -> CodeModeTurnWorker { + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let inner = Arc::clone(&self.inner); + let turn_message_rx = Arc::clone(&self.inner.turn_message_rx); + + tokio::spawn(async move { + loop { + let next_message = tokio::select! { + _ = &mut shutdown_rx => break, + message = async { + let mut turn_message_rx = turn_message_rx.lock().await; + turn_message_rx.recv().await + } => message, + }; + let Some(next_message) = next_message else { + break; + }; + match next_message { + TurnMessage::Notify { + cell_id, + call_id, + text, + } => { + if let Err(err) = host.notify(call_id, cell_id.clone(), text).await { + warn!( + "failed to deliver code mode notification for cell {cell_id}: {err}" + ); + } + } + TurnMessage::ToolCall { + cell_id, + id, + name, + input, + } => { + let host = Arc::clone(&host); + let inner = Arc::clone(&inner); + tokio::spawn(async move { + let response = host + .invoke_tool(name, input, CancellationToken::new()) + .await; + let runtime_tx = inner + .sessions + .lock() + .await + .get(&cell_id) + .map(|handle| handle.runtime_tx.clone()); + let Some(runtime_tx) = runtime_tx else { + return; + }; + let command = match response { + Ok(result) => RuntimeCommand::ToolResponse { id, result }, + Err(error_text) => RuntimeCommand::ToolError { id, error_text }, + }; + let _ = runtime_tx.send(command); + }); + } + } + } + }); + + CodeModeTurnWorker { + shutdown_tx: Some(shutdown_tx), + } + } +} + +impl Default for CodeModeService { + 
fn default() -> Self { + Self::new() + } +} + +pub struct CodeModeTurnWorker { + shutdown_tx: Option>, +} + +impl Drop for CodeModeTurnWorker { + fn drop(&mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + } +} + +enum SessionControlCommand { + Poll { + yield_time_ms: u64, + response_tx: oneshot::Sender, + }, + Terminate { + response_tx: oneshot::Sender, + }, +} + +struct PendingResult { + content_items: Vec, + stored_values: HashMap, + error_text: Option, +} + +struct SessionControlContext { + cell_id: String, + runtime_tx: std::sync::mpsc::Sender, + runtime_terminate_handle: v8::IsolateHandle, +} + +fn missing_cell_response(cell_id: String) -> RuntimeResponse { + RuntimeResponse::Result { + error_text: Some(format!("exec cell {cell_id} not found")), + cell_id, + content_items: Vec::new(), + stored_values: HashMap::new(), + } +} + +fn pending_result_response(cell_id: &str, result: PendingResult) -> RuntimeResponse { + RuntimeResponse::Result { + cell_id: cell_id.to_string(), + content_items: result.content_items, + stored_values: result.stored_values, + error_text: result.error_text, + } +} + +fn send_or_buffer_result( + cell_id: &str, + result: PendingResult, + response_tx: &mut Option>, + pending_result: &mut Option, +) -> bool { + if let Some(response_tx) = response_tx.take() { + let _ = response_tx.send(pending_result_response(cell_id, result)); + return true; + } + + *pending_result = Some(result); + false +} + +async fn run_session_control( + inner: Arc, + context: SessionControlContext, + mut event_rx: mpsc::UnboundedReceiver, + mut control_rx: mpsc::UnboundedReceiver, + initial_response_tx: oneshot::Sender, + initial_yield_time_ms: u64, +) { + let SessionControlContext { + cell_id, + runtime_tx, + runtime_terminate_handle, + } = context; + let mut content_items = Vec::new(); + let mut pending_result: Option = None; + let mut response_tx = Some(initial_response_tx); + let mut termination_requested = 
false; + let mut runtime_closed = false; + let mut yield_timer: Option>> = None; + + loop { + tokio::select! { + maybe_event = async { + if runtime_closed { + std::future::pending::>().await + } else { + event_rx.recv().await + } + } => { + let Some(event) = maybe_event else { + runtime_closed = true; + if termination_requested { + if let Some(response_tx) = response_tx.take() { + let _ = response_tx.send(RuntimeResponse::Terminated { + cell_id: cell_id.clone(), + content_items: std::mem::take(&mut content_items), + }); + } + break; + } + if pending_result.is_none() { + let result = PendingResult { + content_items: std::mem::take(&mut content_items), + stored_values: HashMap::new(), + error_text: Some("exec runtime ended unexpectedly".to_string()), + }; + if send_or_buffer_result( + &cell_id, + result, + &mut response_tx, + &mut pending_result, + ) { + break; + } + } + continue; + }; + match event { + RuntimeEvent::Started => { + yield_timer = Some(Box::pin(tokio::time::sleep(Duration::from_millis(initial_yield_time_ms)))); + } + RuntimeEvent::ContentItem(item) => { + content_items.push(item); + } + RuntimeEvent::YieldRequested => { + yield_timer = None; + if let Some(response_tx) = response_tx.take() { + let _ = response_tx.send(RuntimeResponse::Yielded { + cell_id: cell_id.clone(), + content_items: std::mem::take(&mut content_items), + }); + } + } + RuntimeEvent::Notify { call_id, text } => { + let _ = inner.turn_message_tx.send(TurnMessage::Notify { + cell_id: cell_id.clone(), + call_id, + text, + }); + } + RuntimeEvent::ToolCall { id, name, input } => { + let _ = inner.turn_message_tx.send(TurnMessage::ToolCall { + cell_id: cell_id.clone(), + id, + name, + input, + }); + } + RuntimeEvent::Result { + stored_values, + error_text, + } => { + yield_timer = None; + if termination_requested { + if let Some(response_tx) = response_tx.take() { + let _ = response_tx.send(RuntimeResponse::Terminated { + cell_id: cell_id.clone(), + content_items: std::mem::take(&mut 
content_items), + }); + } + break; + } + let result = PendingResult { + content_items: std::mem::take(&mut content_items), + stored_values, + error_text, + }; + if send_or_buffer_result( + &cell_id, + result, + &mut response_tx, + &mut pending_result, + ) { + break; + } + } + } + } + maybe_command = control_rx.recv() => { + let Some(command) = maybe_command else { + break; + }; + match command { + SessionControlCommand::Poll { + yield_time_ms, + response_tx: next_response_tx, + } => { + if let Some(result) = pending_result.take() { + let _ = next_response_tx.send(pending_result_response(&cell_id, result)); + break; + } + response_tx = Some(next_response_tx); + yield_timer = Some(Box::pin(tokio::time::sleep(Duration::from_millis(yield_time_ms)))); + } + SessionControlCommand::Terminate { response_tx: next_response_tx } => { + if let Some(result) = pending_result.take() { + let _ = next_response_tx.send(pending_result_response(&cell_id, result)); + break; + } + + response_tx = Some(next_response_tx); + termination_requested = true; + yield_timer = None; + let _ = runtime_tx.send(RuntimeCommand::Terminate); + let _ = runtime_terminate_handle.terminate_execution(); + if runtime_closed { + if let Some(response_tx) = response_tx.take() { + let _ = response_tx.send(RuntimeResponse::Terminated { + cell_id: cell_id.clone(), + content_items: std::mem::take(&mut content_items), + }); + } + break; + } else { + continue; + } + } + } + } + _ = async { + if let Some(yield_timer) = yield_timer.as_mut() { + yield_timer.await; + } else { + std::future::pending::<()>().await; + } + } => { + yield_timer = None; + if let Some(response_tx) = response_tx.take() { + let _ = response_tx.send(RuntimeResponse::Yielded { + cell_id: cell_id.clone(), + content_items: std::mem::take(&mut content_items), + }); + } + } + } + } + + let _ = runtime_tx.send(RuntimeCommand::Terminate); + inner.sessions.lock().await.remove(&cell_id); +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + 
use std::sync::Arc; + use std::sync::atomic::AtomicU64; + use std::time::Duration; + + use pretty_assertions::assert_eq; + use tokio::sync::Mutex; + use tokio::sync::mpsc; + use tokio::sync::oneshot; + + use super::CodeModeService; + use super::Inner; + use super::RuntimeCommand; + use super::RuntimeResponse; + use super::SessionControlCommand; + use super::SessionControlContext; + use super::run_session_control; + use crate::FunctionCallOutputContentItem; + use crate::runtime::ExecuteRequest; + use crate::runtime::RuntimeEvent; + use crate::runtime::spawn_runtime; + + fn execute_request(source: &str) -> ExecuteRequest { + ExecuteRequest { + tool_call_id: "call_1".to_string(), + enabled_tools: Vec::new(), + source: source.to_string(), + stored_values: HashMap::new(), + yield_time_ms: Some(1), + max_output_tokens: None, + } + } + + fn test_inner() -> Arc { + let (turn_message_tx, turn_message_rx) = mpsc::unbounded_channel(); + Arc::new(Inner { + stored_values: Mutex::new(HashMap::new()), + sessions: Mutex::new(HashMap::new()), + turn_message_tx, + turn_message_rx: Arc::new(Mutex::new(turn_message_rx)), + next_cell_id: AtomicU64::new(1), + }) + } + + #[tokio::test] + async fn synchronous_exit_returns_successfully() { + let service = CodeModeService::new(); + + let response = service + .execute(ExecuteRequest { + source: r#"text("before"); exit(); text("after");"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "before".to_string(), + }], + stored_values: HashMap::new(), + error_text: None, + } + ); + } + + #[tokio::test] + async fn v8_console_is_not_exposed_on_global_this() { + let service = CodeModeService::new(); + + let response = service + .execute(ExecuteRequest { + source: r#"text(String(Object.hasOwn(globalThis, "console")));"#.to_string(), + yield_time_ms: None, + 
..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "false".to_string(), + }], + stored_values: HashMap::new(), + error_text: None, + } + ); + } + + #[tokio::test] + async fn output_helpers_return_undefined() { + let service = CodeModeService::new(); + + let response = service + .execute(ExecuteRequest { + source: r#" +const returnsUndefined = [ + text("first"), + image("https://example.com/image.jpg"), + notify("ping"), +].map((value) => value === undefined); +text(JSON.stringify(returnsUndefined)); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: vec![ + FunctionCallOutputContentItem::InputText { + text: "first".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "https://example.com/image.jpg".to_string(), + detail: None, + }, + FunctionCallOutputContentItem::InputText { + text: "[true,true,true]".to_string(), + }, + ], + stored_values: HashMap::new(), + error_text: None, + } + ); + } + + #[tokio::test] + async fn terminate_waits_for_runtime_shutdown_before_responding() { + let inner = test_inner(); + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (control_tx, control_rx) = mpsc::unbounded_channel(); + let (initial_response_tx, initial_response_rx) = oneshot::channel(); + let (runtime_event_tx, _runtime_event_rx) = mpsc::unbounded_channel(); + let (runtime_tx, runtime_terminate_handle) = spawn_runtime( + ExecuteRequest { + source: "await new Promise(() => {})".to_string(), + yield_time_ms: None, + ..execute_request("") + }, + runtime_event_tx, + ) + .unwrap(); + + tokio::spawn(run_session_control( + inner, + SessionControlContext { + cell_id: "cell-1".to_string(), + runtime_tx: runtime_tx.clone(), + runtime_terminate_handle, + }, 
+ event_rx, + control_rx, + initial_response_tx, + 60_000, + )); + + event_tx.send(RuntimeEvent::Started).unwrap(); + event_tx.send(RuntimeEvent::YieldRequested).unwrap(); + assert_eq!( + initial_response_rx.await.unwrap(), + RuntimeResponse::Yielded { + cell_id: "cell-1".to_string(), + content_items: Vec::new(), + } + ); + + let (terminate_response_tx, terminate_response_rx) = oneshot::channel(); + control_tx + .send(SessionControlCommand::Terminate { + response_tx: terminate_response_tx, + }) + .unwrap(); + let terminate_response = async { terminate_response_rx.await.unwrap() }; + tokio::pin!(terminate_response); + assert!( + tokio::time::timeout(Duration::from_millis(100), terminate_response.as_mut()) + .await + .is_err() + ); + + drop(event_tx); + + assert_eq!( + terminate_response.await, + RuntimeResponse::Terminated { + cell_id: "cell-1".to_string(), + content_items: Vec::new(), + } + ); + + let _ = runtime_tx.send(RuntimeCommand::Terminate); + } +} diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index d648655b2..6386c8d42 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -31,6 +31,7 @@ codex-api = { workspace = true } codex-app-server-protocol = { workspace = true } codex-apply-patch = { workspace = true } codex-async-utils = { workspace = true } +codex-code-mode = { workspace = true } codex-connectors = { workspace = true } codex-config = { workspace = true } codex-exec-server = { workspace = true } diff --git a/codex-rs/core/src/tools/code_mode/bridge.js b/codex-rs/core/src/tools/code_mode/bridge.js deleted file mode 100644 index 0c61a9db1..000000000 --- a/codex-rs/core/src/tools/code_mode/bridge.js +++ /dev/null @@ -1,51 +0,0 @@ -const __codexContentItems = Array.isArray(globalThis.__codexContentItems) - ? 
globalThis.__codexContentItems - : []; -const __codexRuntime = globalThis.__codexRuntime; - -delete globalThis.__codexRuntime; - -Object.defineProperty(globalThis, '__codexContentItems', { - value: __codexContentItems, - configurable: true, - enumerable: false, - writable: false, -}); - -(() => { - if (!__codexRuntime || typeof __codexRuntime !== 'object') { - throw new Error('code mode runtime is unavailable'); - } - - function defineGlobal(name, value) { - Object.defineProperty(globalThis, name, { - value, - configurable: true, - enumerable: true, - writable: false, - }); - } - - defineGlobal('ALL_TOOLS', __codexRuntime.ALL_TOOLS); - defineGlobal('exit', __codexRuntime.exit); - defineGlobal('image', __codexRuntime.image); - defineGlobal('load', __codexRuntime.load); - defineGlobal('notify', __codexRuntime.notify); - defineGlobal('store', __codexRuntime.store); - defineGlobal('text', __codexRuntime.text); - defineGlobal('tools', __codexRuntime.tools); - defineGlobal('yield_control', __codexRuntime.yield_control); - - defineGlobal( - 'console', - Object.freeze({ - log() {}, - info() {}, - warn() {}, - error() {}, - debug() {}, - }) - ); -})(); - -__CODE_MODE_USER_CODE_PLACEHOLDER__ diff --git a/codex-rs/core/src/tools/code_mode/description.md b/codex-rs/core/src/tools/code_mode/description.md deleted file mode 100644 index e0a124c65..000000000 --- a/codex-rs/core/src/tools/code_mode/description.md +++ /dev/null @@ -1,19 +0,0 @@ -## exec -- Runs raw JavaScript in an isolated context (no Node, no file system, or network access, no console). -- Send raw JavaScript source text, not JSON, quoted strings, or markdown code fences. -- You may optionally start the tool input with a first-line pragma like `// @exec: {"yield_time_ms": 10000, "max_output_tokens": 1000}`. -- `yield_time_ms` asks `exec` to yield early after that many milliseconds if the script is still running. -- `max_output_tokens` sets the token budget for direct `exec` results. 
By default the result is truncated to 10000 tokens. -- All nested tools are available on the global `tools` object, for example `await tools.exec_command(...)`. Tool names are exposed as normalized JavaScript identifiers, for example `await tools.mcp__ologs__get_profile(...)`. -- Tool methods take either string or object as parameter. -- They return either a structured value or a string based on the description above. - -- Global helpers: -- `exit()`: Immediately ends the current script successfully (like an early return from the top level). -- `text(value: string | number | boolean | undefined | null)`: Appends a text item and returns it. Non-string values are stringified with `JSON.stringify(...)` when possible. -- `image(imageUrlOrItem: string | { image_url: string; detail?: "auto" | "low" | "high" | "original" | null })`: Appends an image item and returns it. `image_url` can be an HTTPS URL or a base64-encoded `data:` URL. -- `store(key: string, value: any)`: stores a serializable value under a string key for later `exec` calls in the same session. -- `load(key: string)`: returns the stored value for a string key, or `undefined` if it is missing. -- `notify(value: string | number | boolean | undefined | null)`: immediately injects an extra `custom_tool_call_output` for the current `exec` call. Values are stringified like `text(...)`. -- `ALL_TOOLS`: metadata for the enabled nested tools as `{ name, description }` entries. -- `yield_control()`: yields the accumulated output to the model immediately while the script keeps running. 
diff --git a/codex-rs/core/src/tools/code_mode/execute_handler.rs b/codex-rs/core/src/tools/code_mode/execute_handler.rs index 9eba126dd..3f77216c1 100644 --- a/codex-rs/core/src/tools/code_mode/execute_handler.rs +++ b/codex-rs/core/src/tools/code_mode/execute_handler.rs @@ -1,8 +1,5 @@ use async_trait::async_trait; -use serde::Deserialize; -use crate::codex::Session; -use crate::codex::TurnContext; use crate::function_tool::FunctionCallError; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; @@ -10,180 +7,52 @@ use crate::tools::context::ToolPayload; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use super::CODE_MODE_PRAGMA_PREFIX; -use super::CodeModeSessionProgress; use super::ExecContext; use super::PUBLIC_TOOL_NAME; use super::build_enabled_tools; -use super::handle_node_message; -use super::protocol::HostToNodeMessage; -use super::protocol::build_source; +use super::handle_runtime_response; pub struct CodeModeExecuteHandler; -const MAX_JS_SAFE_INTEGER: u64 = (1_u64 << 53) - 1; - -#[derive(Debug, Default, Deserialize, PartialEq, Eq)] -#[serde(deny_unknown_fields)] -struct CodeModeExecPragma { - #[serde(default)] - yield_time_ms: Option, - #[serde(default)] - max_output_tokens: Option, -} - -#[derive(Debug, PartialEq, Eq)] -struct CodeModeExecArgs { - code: String, - yield_time_ms: Option, - max_output_tokens: Option, -} impl CodeModeExecuteHandler { async fn execute( &self, - session: std::sync::Arc, - turn: std::sync::Arc, + session: std::sync::Arc, + turn: std::sync::Arc, call_id: String, code: String, ) -> Result { - let args = parse_freeform_args(&code)?; + let args = + codex_code_mode::parse_exec_source(&code).map_err(FunctionCallError::RespondToModel)?; let exec = ExecContext { session, turn }; let enabled_tools = build_enabled_tools(&exec).await; - let service = &exec.session.services.code_mode_service; - let stored_values = service.stored_values().await; - let source = - 
build_source(&args.code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; - let cell_id = service.allocate_cell_id().await; - let request_id = service.allocate_request_id().await; - let process_slot = service - .ensure_started() - .await - .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; + let stored_values = exec + .session + .services + .code_mode_service + .stored_values() + .await; let started_at = std::time::Instant::now(); - let message = HostToNodeMessage::Start { - request_id: request_id.clone(), - cell_id: cell_id.clone(), - tool_call_id: call_id, - default_yield_time_ms: super::DEFAULT_EXEC_YIELD_TIME_MS, - enabled_tools, - stored_values, - source, - yield_time_ms: args.yield_time_ms, - max_output_tokens: args.max_output_tokens, - }; - let result = { - let mut process_slot = process_slot; - let Some(process) = process_slot.as_mut() else { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner failed to start" - ))); - }; - let message = process - .send(&request_id, &message) - .await - .map_err(|err| err.to_string()); - let message = match message { - Ok(message) => message, - Err(error) => return Err(FunctionCallError::RespondToModel(error)), - }; - handle_node_message( - &exec, cell_id, message, /*poll_max_output_tokens*/ None, started_at, - ) + let response = exec + .session + .services + .code_mode_service + .execute(codex_code_mode::ExecuteRequest { + tool_call_id: call_id, + enabled_tools, + source: args.code, + stored_values, + yield_time_ms: args.yield_time_ms, + max_output_tokens: args.max_output_tokens, + }) .await - }; - match result { - Ok(CodeModeSessionProgress::Finished(output)) - | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), - Err(error) => Err(FunctionCallError::RespondToModel(error)), - } + .map_err(FunctionCallError::RespondToModel)?; + handle_runtime_response(&exec, response, args.max_output_tokens, started_at) + .await + 
.map_err(FunctionCallError::RespondToModel) } } -fn parse_freeform_args(input: &str) -> Result { - if input.trim().is_empty() { - return Err(FunctionCallError::RespondToModel( - "exec expects raw JavaScript source text (non-empty). Provide JS only, optionally with first-line `// @exec: {\"yield_time_ms\": 10000, \"max_output_tokens\": 1000}`.".to_string(), - )); - } - - let mut args = CodeModeExecArgs { - code: input.to_string(), - yield_time_ms: None, - max_output_tokens: None, - }; - - let mut lines = input.splitn(2, '\n'); - let first_line = lines.next().unwrap_or_default(); - let rest = lines.next().unwrap_or_default(); - let trimmed = first_line.trim_start(); - let Some(pragma) = trimmed.strip_prefix(CODE_MODE_PRAGMA_PREFIX) else { - return Ok(args); - }; - - if rest.trim().is_empty() { - return Err(FunctionCallError::RespondToModel( - "exec pragma must be followed by JavaScript source on subsequent lines".to_string(), - )); - } - - let directive = pragma.trim(); - if directive.is_empty() { - return Err(FunctionCallError::RespondToModel( - "exec pragma must be a JSON object with supported fields `yield_time_ms` and `max_output_tokens`" - .to_string(), - )); - } - - let value: serde_json::Value = serde_json::from_str(directive).map_err(|err| { - FunctionCallError::RespondToModel(format!( - "exec pragma must be valid JSON with supported fields `yield_time_ms` and `max_output_tokens`: {err}" - )) - })?; - let object = value.as_object().ok_or_else(|| { - FunctionCallError::RespondToModel( - "exec pragma must be a JSON object with supported fields `yield_time_ms` and `max_output_tokens`" - .to_string(), - ) - })?; - for key in object.keys() { - match key.as_str() { - "yield_time_ms" | "max_output_tokens" => {} - _ => { - return Err(FunctionCallError::RespondToModel(format!( - "exec pragma only supports `yield_time_ms` and `max_output_tokens`; got `{key}`" - ))); - } - } - } - - let pragma: CodeModeExecPragma = serde_json::from_value(value).map_err(|err| { - 
FunctionCallError::RespondToModel(format!( - "exec pragma fields `yield_time_ms` and `max_output_tokens` must be non-negative safe integers: {err}" - )) - })?; - if pragma - .yield_time_ms - .is_some_and(|yield_time_ms| yield_time_ms > MAX_JS_SAFE_INTEGER) - { - return Err(FunctionCallError::RespondToModel( - "exec pragma field `yield_time_ms` must be a non-negative safe integer".to_string(), - )); - } - if pragma.max_output_tokens.is_some_and(|max_output_tokens| { - u64::try_from(max_output_tokens) - .map(|max_output_tokens| max_output_tokens > MAX_JS_SAFE_INTEGER) - .unwrap_or(true) - }) { - return Err(FunctionCallError::RespondToModel( - "exec pragma field `max_output_tokens` must be a non-negative safe integer".to_string(), - )); - } - args.code = rest.to_string(); - args.yield_time_ms = pragma.yield_time_ms; - args.max_output_tokens = pragma.max_output_tokens; - Ok(args) -} - #[async_trait] impl ToolHandler for CodeModeExecuteHandler { type Output = FunctionToolOutput; @@ -216,7 +85,3 @@ impl ToolHandler for CodeModeExecuteHandler { } } } - -#[cfg(test)] -#[path = "execute_handler_tests.rs"] -mod execute_handler_tests; diff --git a/codex-rs/core/src/tools/code_mode/mod.rs b/codex-rs/core/src/tools/code_mode/mod.rs index c8e1e0c16..a4838d246 100644 --- a/codex-rs/core/src/tools/code_mode/mod.rs +++ b/codex-rs/core/src/tools/code_mode/mod.rs @@ -1,15 +1,18 @@ mod execute_handler; -mod process; -mod protocol; -mod service; +mod response_adapter; mod wait_handler; -mod worker; +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +use codex_code_mode::CodeModeTurnHost; +use codex_code_mode::RuntimeResponse; use codex_protocol::models::FunctionCallOutputContentItem; +use codex_protocol::models::FunctionCallOutputPayload; +use codex_protocol::models::ResponseInputItem; use serde_json::Value as JsonValue; +use tokio_util::sync::CancellationToken; use crate::client_common::tools::ToolSpec; use crate::codex::Session; @@ -17,9 +20,8 @@ use 
crate::codex::TurnContext; use crate::function_tool::FunctionCallError; use crate::tools::ToolRouter; use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; -use crate::tools::code_mode_description::code_mode_tool_reference; -use crate::tools::code_mode_description::normalize_code_mode_identifier; use crate::tools::context::FunctionToolOutput; +use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolPayload; use crate::tools::parallel::ToolCallRuntime; use crate::tools::router::ToolCall; @@ -29,180 +31,202 @@ use crate::truncate::TruncationPolicy; use crate::truncate::formatted_truncate_text_content_items_with_policy; use crate::truncate::truncate_function_output_items_with_policy; use crate::unified_exec::resolve_max_tokens; +use codex_features::Feature; -const CODE_MODE_RUNNER_SOURCE: &str = include_str!("runner.cjs"); -const CODE_MODE_BRIDGE_SOURCE: &str = include_str!("bridge.js"); -const CODE_MODE_DESCRIPTION_TEMPLATE: &str = include_str!("description.md"); -const CODE_MODE_WAIT_DESCRIPTION_TEMPLATE: &str = include_str!("wait_description.md"); -const CODE_MODE_PRAGMA_PREFIX: &str = "// @exec:"; -const CODE_MODE_ONLY_PREFACE: &str = - "Use `exec/wait` tool to run all other tools, do not attempt to use any other tools directly"; +pub(crate) use execute_handler::CodeModeExecuteHandler; +use response_adapter::into_function_call_output_content_items; +pub(crate) use wait_handler::CodeModeWaitHandler; -pub(crate) const PUBLIC_TOOL_NAME: &str = "exec"; -pub(crate) const WAIT_TOOL_NAME: &str = "wait"; - -pub(crate) fn is_code_mode_nested_tool(tool_name: &str) -> bool { - tool_name != PUBLIC_TOOL_NAME && tool_name != WAIT_TOOL_NAME -} -pub(crate) const DEFAULT_EXEC_YIELD_TIME_MS: u64 = 10_000; -pub(crate) const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000; +pub(crate) const PUBLIC_TOOL_NAME: &str = codex_code_mode::PUBLIC_TOOL_NAME; +pub(crate) const WAIT_TOOL_NAME: &str = codex_code_mode::WAIT_TOOL_NAME; +pub(crate) const 
DEFAULT_WAIT_YIELD_TIME_MS: u64 = codex_code_mode::DEFAULT_WAIT_YIELD_TIME_MS; #[derive(Clone)] -pub(super) struct ExecContext { +pub(crate) struct ExecContext { pub(super) session: Arc, pub(super) turn: Arc, } -pub(crate) use execute_handler::CodeModeExecuteHandler; -pub(crate) use service::CodeModeService; -pub(crate) use wait_handler::CodeModeWaitHandler; - -enum CodeModeSessionProgress { - Finished(FunctionToolOutput), - Yielded { output: FunctionToolOutput }, +pub(crate) struct CodeModeService { + inner: codex_code_mode::CodeModeService, } -enum CodeModeExecutionStatus { - Completed, - Failed, - Running(String), - Terminated, -} - -pub(crate) fn tool_description(enabled_tools: &[(String, String)], code_mode_only: bool) -> String { - let description_template = CODE_MODE_DESCRIPTION_TEMPLATE.trim_end(); - if !code_mode_only { - return description_template.to_string(); +impl CodeModeService { + pub(crate) fn new(_js_repl_node_path: Option) -> Self { + Self { + inner: codex_code_mode::CodeModeService::new(), + } } - let mut sections = vec![ - CODE_MODE_ONLY_PREFACE.to_string(), - description_template.to_string(), - ]; + pub(crate) async fn stored_values(&self) -> std::collections::HashMap { + self.inner.stored_values().await + } - if !enabled_tools.is_empty() { - let nested_tool_reference = enabled_tools - .iter() - .map(|(name, nested_description)| { - let global_name = normalize_code_mode_identifier(name); - format!( - "### `{global_name}` (`{name}`)\n{}", - nested_description.trim() - ) + pub(crate) async fn replace_stored_values( + &self, + values: std::collections::HashMap, + ) { + self.inner.replace_stored_values(values).await; + } + + pub(crate) async fn execute( + &self, + request: codex_code_mode::ExecuteRequest, + ) -> Result { + self.inner.execute(request).await + } + + pub(crate) async fn wait( + &self, + request: codex_code_mode::WaitRequest, + ) -> Result { + self.inner.wait(request).await + } + + pub(crate) async fn start_turn_worker( + &self, + 
session: &Arc, + turn: &Arc, + router: Arc, + tracker: SharedTurnDiffTracker, + ) -> Option { + if !turn.features.enabled(Feature::CodeMode) { + return None; + } + + let exec = ExecContext { + session: Arc::clone(session), + turn: Arc::clone(turn), + }; + let tool_runtime = + ToolCallRuntime::new(router, Arc::clone(session), Arc::clone(turn), tracker); + let host = Arc::new(CoreTurnHost { exec, tool_runtime }); + Some(self.inner.start_turn_worker(host)) + } +} + +struct CoreTurnHost { + exec: ExecContext, + tool_runtime: ToolCallRuntime, +} + +#[async_trait::async_trait] +impl CodeModeTurnHost for CoreTurnHost { + async fn invoke_tool( + &self, + tool_name: String, + input: Option, + cancellation_token: CancellationToken, + ) -> Result { + call_nested_tool( + self.exec.clone(), + self.tool_runtime.clone(), + tool_name, + input, + cancellation_token, + ) + .await + .map_err(|error| error.to_string()) + } + + async fn notify(&self, call_id: String, cell_id: String, text: String) -> Result<(), String> { + if text.trim().is_empty() { + return Ok(()); + } + self.exec + .session + .inject_response_items(vec![ResponseInputItem::CustomToolCallOutput { + call_id, + name: Some(PUBLIC_TOOL_NAME.to_string()), + output: FunctionCallOutputPayload::from_text(text), + }]) + .await + .map_err(|_| { + format!("failed to inject exec notify message for cell {cell_id}: no active turn") }) - .collect::>() - .join("\n\n"); - sections.push(nested_tool_reference); } - - sections.join("\n\n") } -pub(crate) fn wait_tool_description() -> &'static str { - CODE_MODE_WAIT_DESCRIPTION_TEMPLATE -} - -async fn handle_node_message( +pub(super) async fn handle_runtime_response( exec: &ExecContext, - cell_id: String, - message: protocol::NodeToHostMessage, - poll_max_output_tokens: Option>, + response: RuntimeResponse, + max_output_tokens: Option, started_at: std::time::Instant, -) -> Result { - match message { - protocol::NodeToHostMessage::ToolCall { .. 
} => Err(protocol::unexpected_tool_call_error()), - protocol::NodeToHostMessage::Notify { .. } => Err(format!( - "unexpected {PUBLIC_TOOL_NAME} notify message in response path" - )), - protocol::NodeToHostMessage::Yielded { content_items, .. } => { - let mut delta_items = output_content_items_from_json_values(content_items)?; - delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); - prepend_script_status( - &mut delta_items, - CodeModeExecutionStatus::Running(cell_id), - started_at.elapsed(), - ); - Ok(CodeModeSessionProgress::Yielded { - output: FunctionToolOutput::from_content(delta_items, Some(true)), - }) +) -> Result { + let script_status = format_script_status(&response); + + match response { + RuntimeResponse::Yielded { content_items, .. } => { + let mut content_items = into_function_call_output_content_items(content_items); + content_items = truncate_code_mode_result(content_items, max_output_tokens); + prepend_script_status(&mut content_items, &script_status, started_at.elapsed()); + Ok(FunctionToolOutput::from_content(content_items, Some(true))) } - protocol::NodeToHostMessage::Terminated { content_items, .. } => { - let mut delta_items = output_content_items_from_json_values(content_items)?; - delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); - prepend_script_status( - &mut delta_items, - CodeModeExecutionStatus::Terminated, - started_at.elapsed(), - ); - Ok(CodeModeSessionProgress::Finished( - FunctionToolOutput::from_content(delta_items, Some(true)), - )) + RuntimeResponse::Terminated { content_items, .. 
} => { + let mut content_items = into_function_call_output_content_items(content_items); + content_items = truncate_code_mode_result(content_items, max_output_tokens); + prepend_script_status(&mut content_items, &script_status, started_at.elapsed()); + Ok(FunctionToolOutput::from_content(content_items, Some(true))) } - protocol::NodeToHostMessage::Result { + RuntimeResponse::Result { content_items, stored_values, error_text, - max_output_tokens_per_exec_call, .. } => { + let mut content_items = into_function_call_output_content_items(content_items); exec.session .services .code_mode_service .replace_stored_values(stored_values) .await; - let mut delta_items = output_content_items_from_json_values(content_items)?; let success = error_text.is_none(); if let Some(error_text) = error_text { - delta_items.push(FunctionCallOutputContentItem::InputText { + content_items.push(FunctionCallOutputContentItem::InputText { text: format!("Script error:\n{error_text}"), }); } - - let mut delta_items = truncate_code_mode_result( - delta_items, - poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), - ); - prepend_script_status( - &mut delta_items, - if success { - CodeModeExecutionStatus::Completed - } else { - CodeModeExecutionStatus::Failed - }, - started_at.elapsed(), - ); - Ok(CodeModeSessionProgress::Finished( - FunctionToolOutput::from_content(delta_items, Some(success)), + content_items = truncate_code_mode_result(content_items, max_output_tokens); + prepend_script_status(&mut content_items, &script_status, started_at.elapsed()); + Ok(FunctionToolOutput::from_content( + content_items, + Some(success), )) } } } +fn format_script_status(response: &RuntimeResponse) -> String { + match response { + RuntimeResponse::Yielded { cell_id, .. } => { + format!("Script running with cell ID {cell_id}") + } + RuntimeResponse::Terminated { .. } => "Script terminated".to_string(), + RuntimeResponse::Result { error_text, .. 
} => { + if error_text.is_none() { + "Script completed".to_string() + } else { + "Script failed".to_string() + } + } + } +} + fn prepend_script_status( content_items: &mut Vec, - status: CodeModeExecutionStatus, + status: &str, wall_time: Duration, ) { let wall_time_seconds = ((wall_time.as_secs_f32()) * 10.0).round() / 10.0; - let header = format!( - "{}\nWall time {wall_time_seconds:.1} seconds\nOutput:\n", - match status { - CodeModeExecutionStatus::Completed => "Script completed".to_string(), - CodeModeExecutionStatus::Failed => "Script failed".to_string(), - CodeModeExecutionStatus::Running(cell_id) => { - format!("Script running with cell ID {cell_id}") - } - CodeModeExecutionStatus::Terminated => "Script terminated".to_string(), - } - ); + let header = format!("{status}\nWall time {wall_time_seconds:.1} seconds\nOutput:\n"); content_items.insert(0, FunctionCallOutputContentItem::InputText { text: header }); } fn truncate_code_mode_result( items: Vec, - max_output_tokens_per_exec_call: Option, + max_output_tokens: Option, ) -> Vec { - let max_output_tokens = resolve_max_tokens(max_output_tokens_per_exec_call); + let max_output_tokens = resolve_max_tokens(max_output_tokens); let policy = TruncationPolicy::Tokens(max_output_tokens); if items .iter() @@ -216,21 +240,9 @@ fn truncate_code_mode_result( truncate_function_output_items_with_policy(&items, policy) } -fn output_content_items_from_json_values( - content_items: Vec, -) -> Result, String> { - content_items - .into_iter() - .enumerate() - .map(|(index, item)| { - serde_json::from_value(item).map_err(|err| { - format!("invalid {PUBLIC_TOOL_NAME} content item at index {index}: {err}") - }) - }) - .collect() -} - -async fn build_enabled_tools(exec: &ExecContext) -> Vec { +pub(super) async fn build_enabled_tools( + exec: &ExecContext, +) -> Vec { let router = build_nested_router(exec).await; let mut out = router .specs() @@ -238,39 +250,37 @@ async fn build_enabled_tools(exec: &ExecContext) -> Vec { 
.map(|spec| augment_tool_spec_for_code_mode(spec, /*code_mode_enabled*/ true)) .filter_map(enabled_tool_from_spec) .collect::>(); - out.sort_by(|left, right| left.tool_name.cmp(&right.tool_name)); - out.dedup_by(|left, right| left.tool_name == right.tool_name); + out.sort_by(|left, right| left.name.cmp(&right.name)); + out.dedup_by(|left, right| left.name == right.name); out } -fn enabled_tool_from_spec(spec: ToolSpec) -> Option { +fn enabled_tool_from_spec(spec: ToolSpec) -> Option { let tool_name = spec.name().to_string(); - if !is_code_mode_nested_tool(&tool_name) { + if !codex_code_mode::is_code_mode_nested_tool(&tool_name) { return None; } - let reference = code_mode_tool_reference(&tool_name); - let global_name = normalize_code_mode_identifier(&tool_name); - let (description, kind) = match spec { - ToolSpec::Function(tool) => (tool.description, protocol::CodeModeToolKind::Function), - ToolSpec::Freeform(tool) => (tool.description, protocol::CodeModeToolKind::Freeform), + match spec { + ToolSpec::Function(tool) => Some(codex_code_mode::ToolDefinition { + name: tool_name, + description: tool.description, + kind: codex_code_mode::CodeModeToolKind::Function, + input_schema: serde_json::to_value(&tool.parameters).ok(), + output_schema: tool.output_schema, + }), + ToolSpec::Freeform(tool) => Some(codex_code_mode::ToolDefinition { + name: tool_name, + description: tool.description, + kind: codex_code_mode::CodeModeToolKind::Freeform, + input_schema: None, + output_schema: None, + }), ToolSpec::LocalShell {} | ToolSpec::ImageGeneration { .. } | ToolSpec::ToolSearch { .. } - | ToolSpec::WebSearch { .. } => { - return None; - } - }; - - Some(protocol::EnabledTool { - tool_name, - global_name, - module_path: reference.module_path, - namespace: reference.namespace, - name: normalize_code_mode_identifier(&reference.tool_key), - description, - kind, - }) + | ToolSpec::WebSearch { .. 
} => None, + } } async fn build_nested_router(exec: &ExecContext) -> ToolRouter { @@ -303,7 +313,7 @@ async fn call_nested_tool( tool_runtime: ToolCallRuntime, tool_name: String, input: Option, - cancellation_token: tokio_util::sync::CancellationToken, + cancellation_token: CancellationToken, ) -> Result { if tool_name == PUBLIC_TOOL_NAME { return Err(FunctionCallError::RespondToModel(format!( @@ -340,18 +350,18 @@ async fn call_nested_tool( Ok(result.code_mode_result()) } -fn tool_kind_for_spec(spec: &ToolSpec) -> protocol::CodeModeToolKind { +fn tool_kind_for_spec(spec: &ToolSpec) -> codex_code_mode::CodeModeToolKind { if matches!(spec, ToolSpec::Freeform(_)) { - protocol::CodeModeToolKind::Freeform + codex_code_mode::CodeModeToolKind::Freeform } else { - protocol::CodeModeToolKind::Function + codex_code_mode::CodeModeToolKind::Function } } fn tool_kind_for_name( spec: Option, tool_name: &str, -) -> Result { +) -> Result { spec.as_ref() .map(tool_kind_for_spec) .ok_or_else(|| format!("tool `{tool_name}` is not enabled in {PUBLIC_TOOL_NAME}")) @@ -364,8 +374,12 @@ fn build_nested_tool_payload( ) -> Result { let actual_kind = tool_kind_for_name(spec, tool_name)?; match actual_kind { - protocol::CodeModeToolKind::Function => build_function_tool_payload(tool_name, input), - protocol::CodeModeToolKind::Freeform => build_freeform_tool_payload(tool_name, input), + codex_code_mode::CodeModeToolKind::Function => { + build_function_tool_payload(tool_name, input) + } + codex_code_mode::CodeModeToolKind::Freeform => { + build_freeform_tool_payload(tool_name, input) + } } } diff --git a/codex-rs/core/src/tools/code_mode/process.rs b/codex-rs/core/src/tools/code_mode/process.rs deleted file mode 100644 index 6dd6cde3a..000000000 --- a/codex-rs/core/src/tools/code_mode/process.rs +++ /dev/null @@ -1,173 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use 
tokio::io::BufReader; -use tokio::sync::Mutex; -use tokio::sync::mpsc; -use tokio::sync::oneshot; -use tokio::task::JoinHandle; -use tracing::warn; - -use super::CODE_MODE_RUNNER_SOURCE; -use super::PUBLIC_TOOL_NAME; -use super::protocol::HostToNodeMessage; -use super::protocol::NodeToHostMessage; -use super::protocol::message_request_id; - -pub(super) struct CodeModeProcess { - pub(super) child: tokio::process::Child, - pub(super) stdin: Arc>, - pub(super) stdout_task: JoinHandle<()>, - pub(super) response_waiters: Arc>>>, - pub(super) message_rx: Arc>>, -} - -impl CodeModeProcess { - pub(super) async fn send( - &mut self, - request_id: &str, - message: &HostToNodeMessage, - ) -> Result { - if self.stdout_task.is_finished() { - return Err(std::io::Error::other(format!( - "{PUBLIC_TOOL_NAME} runner is not available" - ))); - } - - let (tx, rx) = oneshot::channel(); - self.response_waiters - .lock() - .await - .insert(request_id.to_string(), tx); - if let Err(err) = write_message(&self.stdin, message).await { - self.response_waiters.lock().await.remove(request_id); - return Err(err); - } - - match rx.await { - Ok(message) => Ok(message), - Err(_) => Err(std::io::Error::other(format!( - "{PUBLIC_TOOL_NAME} runner is not available" - ))), - } - } - - pub(super) fn has_exited(&mut self) -> Result { - self.child - .try_wait() - .map(|status| status.is_some()) - .map_err(std::io::Error::other) - } -} - -pub(super) async fn spawn_code_mode_process( - node_path: &std::path::Path, -) -> Result { - let mut cmd = tokio::process::Command::new(node_path); - cmd.arg("--experimental-vm-modules"); - cmd.arg("--eval"); - cmd.arg(CODE_MODE_RUNNER_SOURCE); - cmd.stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - - let mut child = cmd.spawn().map_err(std::io::Error::other)?; - let stdout = child.stdout.take().ok_or_else(|| { - std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing 
stdout")) - })?; - let stderr = child.stderr.take().ok_or_else(|| { - std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stderr")) - })?; - let stdin = child - .stdin - .take() - .ok_or_else(|| std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdin")))?; - let stdin = Arc::new(Mutex::new(stdin)); - let response_waiters = Arc::new(Mutex::new(HashMap::< - String, - oneshot::Sender, - >::new())); - let (message_tx, message_rx) = mpsc::unbounded_channel(); - - tokio::spawn(async move { - let mut reader = BufReader::new(stderr); - let mut buf = Vec::new(); - match reader.read_to_end(&mut buf).await { - Ok(_) => { - let stderr = String::from_utf8_lossy(&buf).trim().to_string(); - if !stderr.is_empty() { - warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}"); - } - } - Err(err) => { - warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}"); - } - } - }); - let stdout_task = tokio::spawn({ - let response_waiters = Arc::clone(&response_waiters); - async move { - let mut stdout_lines = BufReader::new(stdout).lines(); - loop { - let line = match stdout_lines.next_line().await { - Ok(line) => line, - Err(err) => { - warn!("failed to read {PUBLIC_TOOL_NAME} stdout: {err}"); - break; - } - }; - let Some(line) = line else { - break; - }; - if line.trim().is_empty() { - continue; - } - let message: NodeToHostMessage = match serde_json::from_str(&line) { - Ok(message) => message, - Err(err) => { - warn!("failed to parse {PUBLIC_TOOL_NAME} stdout message: {err}"); - break; - } - }; - match message { - message @ (NodeToHostMessage::ToolCall { .. } - | NodeToHostMessage::Notify { .. 
}) => { - let _ = message_tx.send(message); - } - message => { - if let Some(request_id) = message_request_id(&message) - && let Some(waiter) = response_waiters.lock().await.remove(request_id) - { - let _ = waiter.send(message); - } - } - } - } - response_waiters.lock().await.clear(); - } - }); - - Ok(CodeModeProcess { - child, - stdin, - stdout_task, - response_waiters, - message_rx: Arc::new(Mutex::new(message_rx)), - }) -} - -pub(super) async fn write_message( - stdin: &Arc>, - message: &HostToNodeMessage, -) -> Result<(), std::io::Error> { - let line = serde_json::to_string(message).map_err(std::io::Error::other)?; - let mut stdin = stdin.lock().await; - stdin.write_all(line.as_bytes()).await?; - stdin.write_all(b"\n").await?; - stdin.flush().await?; - Ok(()) -} diff --git a/codex-rs/core/src/tools/code_mode/protocol.rs b/codex-rs/core/src/tools/code_mode/protocol.rs deleted file mode 100644 index 2e72e1229..000000000 --- a/codex-rs/core/src/tools/code_mode/protocol.rs +++ /dev/null @@ -1,169 +0,0 @@ -use std::collections::HashMap; - -use serde::Deserialize; -use serde::Serialize; -use serde_json::Value as JsonValue; - -use super::CODE_MODE_BRIDGE_SOURCE; -use super::PUBLIC_TOOL_NAME; - -#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] -#[serde(rename_all = "snake_case")] -pub(super) enum CodeModeToolKind { - Function, - Freeform, -} - -#[derive(Clone, Debug, Serialize)] -pub(super) struct EnabledTool { - pub(super) tool_name: String, - pub(super) global_name: String, - #[serde(rename = "module")] - pub(super) module_path: String, - pub(super) namespace: Vec, - pub(super) name: String, - pub(super) description: String, - pub(super) kind: CodeModeToolKind, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "snake_case")] -pub(super) struct CodeModeToolCall { - pub(super) request_id: String, - pub(super) id: String, - pub(super) name: String, - #[serde(default)] - pub(super) input: Option, -} - -#[derive(Clone, Debug, Deserialize)] 
-pub(super) struct CodeModeNotify { - pub(super) cell_id: String, - pub(super) call_id: String, - pub(super) text: String, -} - -#[derive(Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub(super) enum HostToNodeMessage { - Start { - request_id: String, - cell_id: String, - tool_call_id: String, - default_yield_time_ms: u64, - enabled_tools: Vec, - stored_values: HashMap, - source: String, - yield_time_ms: Option, - max_output_tokens: Option, - }, - Poll { - request_id: String, - cell_id: String, - yield_time_ms: u64, - }, - Terminate { - request_id: String, - cell_id: String, - }, - Response { - request_id: String, - id: String, - code_mode_result: JsonValue, - #[serde(default)] - error_text: Option, - }, -} - -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub(super) enum NodeToHostMessage { - ToolCall { - #[serde(flatten)] - tool_call: CodeModeToolCall, - }, - Yielded { - request_id: String, - content_items: Vec, - }, - Terminated { - request_id: String, - content_items: Vec, - }, - Notify { - #[serde(flatten)] - notify: CodeModeNotify, - }, - Result { - request_id: String, - content_items: Vec, - stored_values: HashMap, - #[serde(default)] - error_text: Option, - #[serde(default)] - max_output_tokens_per_exec_call: Option, - }, -} - -pub(super) fn build_source( - user_code: &str, - enabled_tools: &[EnabledTool], -) -> Result { - let enabled_tools_json = serde_json::to_string(enabled_tools) - .map_err(|err| format!("failed to serialize enabled tools: {err}"))?; - Ok(CODE_MODE_BRIDGE_SOURCE - .replace( - "__CODE_MODE_ENABLED_TOOLS_PLACEHOLDER__", - &enabled_tools_json, - ) - .replace("__CODE_MODE_USER_CODE_PLACEHOLDER__", user_code)) -} - -pub(super) fn message_request_id(message: &NodeToHostMessage) -> Option<&str> { - match message { - NodeToHostMessage::ToolCall { .. } => None, - NodeToHostMessage::Yielded { request_id, .. } - | NodeToHostMessage::Terminated { request_id, .. 
} - | NodeToHostMessage::Result { request_id, .. } => Some(request_id), - NodeToHostMessage::Notify { .. } => None, - } -} - -pub(super) fn unexpected_tool_call_error() -> String { - format!("{PUBLIC_TOOL_NAME} received an unexpected tool call response") -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use super::CodeModeNotify; - use super::NodeToHostMessage; - use super::message_request_id; - - #[test] - fn message_request_id_absent_for_notify() { - let message = NodeToHostMessage::Notify { - notify: CodeModeNotify { - cell_id: "1".to_string(), - call_id: "call-1".to_string(), - text: "hello".to_string(), - }, - }; - - assert_eq!(None, message_request_id(&message)); - } - - #[test] - fn message_request_id_present_for_result() { - let message = NodeToHostMessage::Result { - request_id: "req-1".to_string(), - content_items: Vec::new(), - stored_values: HashMap::new(), - error_text: None, - max_output_tokens_per_exec_call: None, - }; - - assert_eq!(Some("req-1"), message_request_id(&message)); - } -} diff --git a/codex-rs/core/src/tools/code_mode/response_adapter.rs b/codex-rs/core/src/tools/code_mode/response_adapter.rs new file mode 100644 index 000000000..b90448acf --- /dev/null +++ b/codex-rs/core/src/tools/code_mode/response_adapter.rs @@ -0,0 +1,44 @@ +use codex_code_mode::ImageDetail as CodeModeImageDetail; +use codex_protocol::models::FunctionCallOutputContentItem; +use codex_protocol::models::ImageDetail; + +trait IntoProtocol { + fn into_protocol(self) -> T; +} + +pub(super) fn into_function_call_output_content_items( + items: Vec, +) -> Vec { + items.into_iter().map(IntoProtocol::into_protocol).collect() +} + +impl IntoProtocol for CodeModeImageDetail { + fn into_protocol(self) -> ImageDetail { + let value = self; + match value { + CodeModeImageDetail::Auto => ImageDetail::Auto, + CodeModeImageDetail::Low => ImageDetail::Low, + CodeModeImageDetail::High => ImageDetail::High, + CodeModeImageDetail::Original => ImageDetail::Original, + } 
+ } +} + +impl IntoProtocol + for codex_code_mode::FunctionCallOutputContentItem +{ + fn into_protocol(self) -> FunctionCallOutputContentItem { + let value = self; + match value { + codex_code_mode::FunctionCallOutputContentItem::InputText { text } => { + FunctionCallOutputContentItem::InputText { text } + } + codex_code_mode::FunctionCallOutputContentItem::InputImage { image_url, detail } => { + FunctionCallOutputContentItem::InputImage { + image_url, + detail: detail.map(IntoProtocol::into_protocol), + } + } + } + } +} diff --git a/codex-rs/core/src/tools/code_mode/runner.cjs b/codex-rs/core/src/tools/code_mode/runner.cjs deleted file mode 100644 index 8b4b322eb..000000000 --- a/codex-rs/core/src/tools/code_mode/runner.cjs +++ /dev/null @@ -1,938 +0,0 @@ -'use strict'; - -const readline = require('node:readline'); -const { Worker } = require('node:worker_threads'); - -const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL = 10000; - -function normalizeMaxOutputTokensPerExecCall(value) { - if (!Number.isSafeInteger(value) || value < 0) { - throw new TypeError('max_output_tokens_per_exec_call must be a non-negative safe integer'); - } - return value; -} - -function normalizeYieldTime(value) { - if (!Number.isSafeInteger(value) || value < 0) { - throw new TypeError('yield_time must be a non-negative safe integer'); - } - return value; -} - -function formatErrorText(error) { - return String(error && error.stack ? error.stack : error); -} - -function cloneJsonValue(value) { - return JSON.parse(JSON.stringify(value)); -} - -function clearTimer(timer) { - if (timer !== null) { - clearTimeout(timer); - } - return null; -} - -function takeContentItems(session) { - const clonedContentItems = cloneJsonValue(session.content_items); - session.content_items.splice(0, session.content_items.length); - return Array.isArray(clonedContentItems) ? 
clonedContentItems : []; -} - -function codeModeWorkerMain() { - 'use strict'; - - const { parentPort, workerData } = require('node:worker_threads'); - const vm = require('node:vm'); - const { SourceTextModule, SyntheticModule } = vm; - - function formatErrorText(error) { - return String(error && error.stack ? error.stack : error); - } - - function cloneJsonValue(value) { - return JSON.parse(JSON.stringify(value)); - } - - class CodeModeExitSignal extends Error { - constructor() { - super('code mode exit'); - this.name = 'CodeModeExitSignal'; - } - } - - function isCodeModeExitSignal(error) { - return error instanceof CodeModeExitSignal; - } - - function createToolCaller() { - let nextId = 0; - const pending = new Map(); - - parentPort.on('message', (message) => { - if (message.type === 'tool_response') { - const entry = pending.get(message.id); - if (!entry) { - return; - } - pending.delete(message.id); - entry.resolve(message.result ?? ''); - return; - } - - if (message.type === 'tool_response_error') { - const entry = pending.get(message.id); - if (!entry) { - return; - } - pending.delete(message.id); - entry.reject(new Error(message.error_text ?? 
'tool call failed')); - return; - } - }); - - return (name, input) => { - const id = 'msg-' + ++nextId; - return new Promise((resolve, reject) => { - pending.set(id, { resolve, reject }); - parentPort.postMessage({ - type: 'tool_call', - id, - name: String(name), - input, - }); - }); - }; - } - - function createContentItems() { - const contentItems = []; - const push = contentItems.push.bind(contentItems); - contentItems.push = (...items) => { - for (const item of items) { - parentPort.postMessage({ - type: 'content_item', - item: cloneJsonValue(item), - }); - } - return push(...items); - }; - parentPort.on('message', (message) => { - if (message.type === 'clear_content') { - contentItems.splice(0, contentItems.length); - } - }); - return contentItems; - } - - function createGlobalToolsNamespace(callTool, enabledTools) { - const tools = Object.create(null); - - for (const { tool_name, global_name } of enabledTools) { - Object.defineProperty(tools, global_name, { - value: async (args) => callTool(tool_name, args), - configurable: false, - enumerable: true, - writable: false, - }); - } - - return Object.freeze(tools); - } - - function createModuleToolsNamespace(callTool, enabledTools) { - const tools = Object.create(null); - - for (const { tool_name, global_name } of enabledTools) { - Object.defineProperty(tools, global_name, { - value: async (args) => callTool(tool_name, args), - configurable: false, - enumerable: true, - writable: false, - }); - } - - return Object.freeze(tools); - } - - function createAllToolsMetadata(enabledTools) { - return Object.freeze( - enabledTools.map(({ global_name, description }) => - Object.freeze({ - name: global_name, - description, - }) - ) - ); - } - - function createToolsModule(context, callTool, enabledTools) { - const tools = createModuleToolsNamespace(callTool, enabledTools); - const allTools = createAllToolsMetadata(enabledTools); - const exportNames = ['ALL_TOOLS']; - - for (const { global_name } of enabledTools) { - if 
(global_name !== 'ALL_TOOLS') { - exportNames.push(global_name); - } - } - - const uniqueExportNames = [...new Set(exportNames)]; - - return new SyntheticModule( - uniqueExportNames, - function initToolsModule() { - this.setExport('ALL_TOOLS', allTools); - for (const exportName of uniqueExportNames) { - if (exportName !== 'ALL_TOOLS') { - this.setExport(exportName, tools[exportName]); - } - } - }, - { context } - ); - } - - function ensureContentItems(context) { - if (!Array.isArray(context.__codexContentItems)) { - context.__codexContentItems = []; - } - return context.__codexContentItems; - } - - function serializeOutputText(value) { - if (typeof value === 'string') { - return value; - } - if ( - typeof value === 'undefined' || - value === null || - typeof value === 'boolean' || - typeof value === 'number' || - typeof value === 'bigint' - ) { - return String(value); - } - - const serialized = JSON.stringify(value); - if (typeof serialized === 'string') { - return serialized; - } - - return String(value); - } - - function normalizeOutputImage(value) { - let imageUrl; - let detail; - if (typeof value === 'string') { - imageUrl = value; - } else if ( - value && - typeof value === 'object' && - !Array.isArray(value) - ) { - if (typeof value.image_url === 'string') { - imageUrl = value.image_url; - } - if (typeof value.detail === 'string') { - detail = value.detail; - } else if ( - Object.prototype.hasOwnProperty.call(value, 'detail') && - value.detail !== null && - typeof value.detail !== 'undefined' - ) { - throw new TypeError('image detail must be a string when provided'); - } - } - - if (typeof imageUrl !== 'string' || !imageUrl) { - throw new TypeError( - 'image expects a non-empty image URL string or an object with image_url and optional detail' - ); - } - if (!/^(?:https?:\/\/|data:)/i.test(imageUrl)) { - throw new TypeError('image expects an http(s) or data URL'); - } - - if (typeof detail !== 'undefined' && !/^(?:auto|low|high|original)$/i.test(detail)) { - 
throw new TypeError('image detail must be one of: auto, low, high, original'); - } - - const normalized = { image_url: imageUrl }; - if (typeof detail === 'string') { - normalized.detail = detail.toLowerCase(); - } - return normalized; - } - - function createCodeModeHelpers(context, state, toolCallId) { - const load = (key) => { - if (typeof key !== 'string') { - throw new TypeError('load key must be a string'); - } - if (!Object.prototype.hasOwnProperty.call(state.storedValues, key)) { - return undefined; - } - return cloneJsonValue(state.storedValues[key]); - }; - const store = (key, value) => { - if (typeof key !== 'string') { - throw new TypeError('store key must be a string'); - } - state.storedValues[key] = cloneJsonValue(value); - }; - const text = (value) => { - const item = { - type: 'input_text', - text: serializeOutputText(value), - }; - ensureContentItems(context).push(item); - return item; - }; - const image = (value) => { - const item = Object.assign({ type: 'input_image' }, normalizeOutputImage(value)); - ensureContentItems(context).push(item); - return item; - }; - const yieldControl = () => { - parentPort.postMessage({ type: 'yield' }); - }; - const notify = (value) => { - const text = serializeOutputText(value); - if (text.trim().length === 0) { - throw new TypeError('notify expects non-empty text'); - } - if (typeof toolCallId !== 'string' || toolCallId.length === 0) { - throw new TypeError('notify requires a valid tool call id'); - } - parentPort.postMessage({ - type: 'notify', - call_id: toolCallId, - text, - }); - return text; - }; - const exit = () => { - throw new CodeModeExitSignal(); - }; - - return Object.freeze({ - exit, - image, - load, - notify, - output_image: image, - output_text: text, - store, - text, - yield_control: yieldControl, - }); - } - - function createCodeModeModule(context, helpers) { - return new SyntheticModule( - [ - 'exit', - 'image', - 'load', - 'notify', - 'output_text', - 'output_image', - 'store', - 'text', - 
'yield_control', - ], - function initCodeModeModule() { - this.setExport('exit', helpers.exit); - this.setExport('image', helpers.image); - this.setExport('load', helpers.load); - this.setExport('notify', helpers.notify); - this.setExport('output_text', helpers.output_text); - this.setExport('output_image', helpers.output_image); - this.setExport('store', helpers.store); - this.setExport('text', helpers.text); - this.setExport('yield_control', helpers.yield_control); - }, - { context } - ); - } - - function createBridgeRuntime(callTool, enabledTools, helpers) { - return Object.freeze({ - ALL_TOOLS: createAllToolsMetadata(enabledTools), - exit: helpers.exit, - image: helpers.image, - load: helpers.load, - notify: helpers.notify, - store: helpers.store, - text: helpers.text, - tools: createGlobalToolsNamespace(callTool, enabledTools), - yield_control: helpers.yield_control, - }); - } - - function namespacesMatch(left, right) { - if (left.length !== right.length) { - return false; - } - return left.every((segment, index) => segment === right[index]); - } - - function createNamespacedToolsNamespace(callTool, enabledTools, namespace) { - const tools = Object.create(null); - - for (const tool of enabledTools) { - const toolNamespace = Array.isArray(tool.namespace) ? 
tool.namespace : []; - if (!namespacesMatch(toolNamespace, namespace)) { - continue; - } - - Object.defineProperty(tools, tool.name, { - value: async (args) => callTool(tool.tool_name, args), - configurable: false, - enumerable: true, - writable: false, - }); - } - - return Object.freeze(tools); - } - - function createNamespacedToolsModule(context, callTool, enabledTools, namespace) { - const tools = createNamespacedToolsNamespace(callTool, enabledTools, namespace); - const exportNames = []; - - for (const exportName of Object.keys(tools)) { - if (exportName !== 'ALL_TOOLS') { - exportNames.push(exportName); - } - } - - const uniqueExportNames = [...new Set(exportNames)]; - - return new SyntheticModule( - uniqueExportNames, - function initNamespacedToolsModule() { - for (const exportName of uniqueExportNames) { - this.setExport(exportName, tools[exportName]); - } - }, - { context } - ); - } - - function createModuleResolver(context, callTool, enabledTools, helpers) { - let toolsModule; - let codeModeModule; - const namespacedModules = new Map(); - - return function resolveModule(specifier) { - if (specifier === 'tools.js') { - toolsModule ??= createToolsModule(context, callTool, enabledTools); - return toolsModule; - } - if (specifier === '@openai/code_mode' || specifier === 'openai/code_mode') { - codeModeModule ??= createCodeModeModule(context, helpers); - return codeModeModule; - } - const namespacedMatch = /^tools\/(.+)\.js$/.exec(specifier); - if (!namespacedMatch) { - throw new Error('Unsupported import in exec: ' + specifier); - } - - const namespace = namespacedMatch[1] - .split('/') - .filter((segment) => segment.length > 0); - if (namespace.length === 0) { - throw new Error('Unsupported import in exec: ' + specifier); - } - - const cacheKey = namespace.join('/'); - if (!namespacedModules.has(cacheKey)) { - namespacedModules.set( - cacheKey, - createNamespacedToolsModule(context, callTool, enabledTools, namespace) - ); - } - return 
namespacedModules.get(cacheKey); - }; - } - - async function resolveDynamicModule(specifier, resolveModule) { - const module = resolveModule(specifier); - - if (module.status === 'unlinked') { - await module.link(resolveModule); - } - - if (module.status === 'linked' || module.status === 'evaluating') { - await module.evaluate(); - } - - if (module.status === 'errored') { - throw module.error; - } - - return module; - } - - async function runModule(context, start, callTool, helpers) { - const resolveModule = createModuleResolver( - context, - callTool, - start.enabled_tools ?? [], - helpers - ); - const mainModule = new SourceTextModule(start.source, { - context, - identifier: 'exec_main.mjs', - importModuleDynamically: async (specifier) => - resolveDynamicModule(specifier, resolveModule), - }); - - await mainModule.link(resolveModule); - await mainModule.evaluate(); - } - - async function main() { - const start = workerData ?? {}; - const toolCallId = start.tool_call_id; - const state = { - storedValues: cloneJsonValue(start.stored_values ?? {}), - }; - const callTool = createToolCaller(); - const enabledTools = start.enabled_tools ?? 
[]; - const contentItems = createContentItems(); - const context = vm.createContext({ - __codexContentItems: contentItems, - }); - const helpers = createCodeModeHelpers(context, state, toolCallId); - Object.defineProperty(context, '__codexRuntime', { - value: createBridgeRuntime(callTool, enabledTools, helpers), - configurable: true, - enumerable: false, - writable: false, - }); - - parentPort.postMessage({ type: 'started' }); - try { - await runModule(context, start, callTool, helpers); - parentPort.postMessage({ - type: 'result', - stored_values: state.storedValues, - }); - } catch (error) { - if (isCodeModeExitSignal(error)) { - parentPort.postMessage({ - type: 'result', - stored_values: state.storedValues, - }); - return; - } - parentPort.postMessage({ - type: 'result', - stored_values: state.storedValues, - error_text: formatErrorText(error), - }); - } - } - - void main().catch((error) => { - parentPort.postMessage({ - type: 'result', - stored_values: {}, - error_text: formatErrorText(error), - }); - }); -} - -function createProtocol() { - const rl = readline.createInterface({ - input: process.stdin, - crlfDelay: Infinity, - }); - - let nextId = 0; - const pending = new Map(); - const sessions = new Map(); - let closedResolve; - const closed = new Promise((resolve) => { - closedResolve = resolve; - }); - - rl.on('line', (line) => { - if (!line.trim()) { - return; - } - - let message; - try { - message = JSON.parse(line); - } catch (error) { - process.stderr.write(formatErrorText(error) + '\n'); - return; - } - - if (message.type === 'start') { - startSession(protocol, sessions, message); - return; - } - - if (message.type === 'poll') { - const session = sessions.get(message.cell_id); - if (session) { - session.request_id = String(message.request_id); - if (session.pending_result) { - void completeSession(protocol, sessions, session, session.pending_result); - } else { - schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 
0)); - } - } else { - void protocol.send({ - type: 'result', - request_id: message.request_id, - content_items: [], - stored_values: {}, - error_text: `exec cell ${message.cell_id} not found`, - max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, - }); - } - return; - } - - if (message.type === 'terminate') { - const session = sessions.get(message.cell_id); - if (session) { - session.request_id = String(message.request_id); - void terminateSession(protocol, sessions, session); - } else { - void protocol.send({ - type: 'result', - request_id: message.request_id, - content_items: [], - stored_values: {}, - error_text: `exec cell ${message.cell_id} not found`, - max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, - }); - } - return; - } - - if (message.type === 'response') { - const entry = pending.get(message.request_id + ':' + message.id); - if (!entry) { - return; - } - pending.delete(message.request_id + ':' + message.id); - if (typeof message.error_text === 'string') { - entry.reject(new Error(message.error_text)); - return; - } - entry.resolve(message.code_mode_result ?? 
''); - return; - } - - process.stderr.write('Unknown protocol message type: ' + message.type + '\n'); - }); - - rl.on('close', () => { - const error = new Error('stdin closed'); - for (const entry of pending.values()) { - entry.reject(error); - } - pending.clear(); - for (const session of sessions.values()) { - session.initial_yield_timer = clearTimer(session.initial_yield_timer); - session.poll_yield_timer = clearTimer(session.poll_yield_timer); - void session.worker.terminate().catch(() => {}); - } - sessions.clear(); - closedResolve(); - }); - - function send(message) { - return new Promise((resolve, reject) => { - process.stdout.write(JSON.stringify(message) + '\n', (error) => { - if (error) { - reject(error); - } else { - resolve(); - } - }); - }); - } - - function request(type, payload) { - const requestId = 'req-' + ++nextId; - const id = 'msg-' + ++nextId; - const pendingKey = requestId + ':' + id; - return new Promise((resolve, reject) => { - pending.set(pendingKey, { resolve, reject }); - void send({ type, request_id: requestId, id, ...payload }).catch((error) => { - pending.delete(pendingKey); - reject(error); - }); - }); - } - - const protocol = { closed, request, send }; - return protocol; -} - -function sessionWorkerSource() { - return '(' + codeModeWorkerMain.toString() + ')();'; -} - -function startSession(protocol, sessions, start) { - if (typeof start.tool_call_id !== 'string' || start.tool_call_id.length === 0) { - throw new TypeError('start requires a valid tool_call_id'); - } - const maxOutputTokensPerExecCall = - start.max_output_tokens == null - ? DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL - : normalizeMaxOutputTokensPerExecCall(start.max_output_tokens); - const session = { - completed: false, - content_items: [], - default_yield_time_ms: normalizeYieldTime(start.default_yield_time_ms), - id: start.cell_id, - initial_yield_time_ms: - start.yield_time_ms == null - ? 
normalizeYieldTime(start.default_yield_time_ms) - : normalizeYieldTime(start.yield_time_ms), - initial_yield_timer: null, - initial_yield_triggered: false, - max_output_tokens_per_exec_call: maxOutputTokensPerExecCall, - pending_result: null, - poll_yield_timer: null, - request_id: String(start.request_id), - worker: new Worker(sessionWorkerSource(), { - eval: true, - workerData: start, - }), - }; - sessions.set(session.id, session); - - session.worker.on('message', (message) => { - void handleWorkerMessage(protocol, sessions, session, message).catch((error) => { - void completeSession(protocol, sessions, session, { - type: 'result', - stored_values: {}, - error_text: formatErrorText(error), - }); - }); - }); - session.worker.on('error', (error) => { - void completeSession(protocol, sessions, session, { - type: 'result', - stored_values: {}, - error_text: formatErrorText(error), - }); - }); - session.worker.on('exit', (code) => { - if (code !== 0 && !session.completed) { - void completeSession(protocol, sessions, session, { - type: 'result', - stored_values: {}, - error_text: 'exec worker exited with code ' + code, - }); - } - }); -} - -async function handleWorkerMessage(protocol, sessions, session, message) { - if (session.completed) { - return; - } - - if (message.type === 'content_item') { - session.content_items.push(cloneJsonValue(message.item)); - return; - } - - if (message.type === 'started') { - scheduleInitialYield(protocol, session, session.initial_yield_time_ms); - return; - } - - if (message.type === 'yield') { - void sendYielded(protocol, session); - return; - } - - if (message.type === 'notify') { - if (typeof message.text !== 'string' || message.text.trim().length === 0) { - throw new TypeError('notify requires non-empty text'); - } - if (typeof message.call_id !== 'string' || message.call_id.length === 0) { - throw new TypeError('notify requires a valid call id'); - } - await protocol.send({ - type: 'notify', - cell_id: session.id, - call_id: 
message.call_id, - text: message.text, - }); - return; - } - - if (message.type === 'tool_call') { - void forwardToolCall(protocol, session, message); - return; - } - - if (message.type === 'result') { - const result = { - type: 'result', - stored_values: cloneJsonValue(message.stored_values ?? {}), - error_text: - typeof message.error_text === 'string' ? message.error_text : undefined, - }; - if (session.request_id === null) { - session.pending_result = result; - session.initial_yield_timer = clearTimer(session.initial_yield_timer); - session.poll_yield_timer = clearTimer(session.poll_yield_timer); - return; - } - await completeSession(protocol, sessions, session, result); - return; - } - - process.stderr.write('Unknown worker message type: ' + message.type + '\n'); -} - -async function forwardToolCall(protocol, session, message) { - try { - const result = await protocol.request('tool_call', { - name: String(message.name), - input: message.input, - }); - if (session.completed) { - return; - } - try { - session.worker.postMessage({ - type: 'tool_response', - id: message.id, - result, - }); - } catch {} - } catch (error) { - if (session.completed) { - return; - } - try { - session.worker.postMessage({ - type: 'tool_response_error', - id: message.id, - error_text: formatErrorText(error), - }); - } catch {} - } -} - -async function sendYielded(protocol, session) { - if (session.completed || session.request_id === null) { - return; - } - session.initial_yield_timer = clearTimer(session.initial_yield_timer); - session.initial_yield_triggered = true; - session.poll_yield_timer = clearTimer(session.poll_yield_timer); - const contentItems = takeContentItems(session); - const requestId = session.request_id; - try { - session.worker.postMessage({ type: 'clear_content' }); - } catch {} - await protocol.send({ - type: 'yielded', - request_id: requestId, - content_items: contentItems, - }); - session.request_id = null; -} - -function scheduleInitialYield(protocol, session, 
yieldTime) { - if (session.completed || session.initial_yield_triggered) { - return yieldTime; - } - session.initial_yield_timer = clearTimer(session.initial_yield_timer); - session.initial_yield_timer = setTimeout(() => { - session.initial_yield_timer = null; - session.initial_yield_triggered = true; - void sendYielded(protocol, session); - }, yieldTime); - return yieldTime; -} - -function schedulePollYield(protocol, session, yieldTime) { - if (session.completed) { - return; - } - session.poll_yield_timer = clearTimer(session.poll_yield_timer); - session.poll_yield_timer = setTimeout(() => { - session.poll_yield_timer = null; - void sendYielded(protocol, session); - }, yieldTime); -} - -async function completeSession(protocol, sessions, session, message) { - if (session.completed) { - return; - } - if (session.request_id === null) { - session.pending_result = message; - session.initial_yield_timer = clearTimer(session.initial_yield_timer); - session.poll_yield_timer = clearTimer(session.poll_yield_timer); - return; - } - const requestId = session.request_id; - session.completed = true; - session.initial_yield_timer = clearTimer(session.initial_yield_timer); - session.poll_yield_timer = clearTimer(session.poll_yield_timer); - sessions.delete(session.id); - const contentItems = takeContentItems(session); - session.pending_result = null; - try { - session.worker.postMessage({ type: 'clear_content' }); - } catch {} - await protocol.send({ - ...message, - request_id: requestId, - content_items: contentItems, - max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, - }); -} - -async function terminateSession(protocol, sessions, session) { - if (session.completed) { - return; - } - session.completed = true; - session.initial_yield_timer = clearTimer(session.initial_yield_timer); - session.poll_yield_timer = clearTimer(session.poll_yield_timer); - sessions.delete(session.id); - const contentItems = takeContentItems(session); - try { - await 
session.worker.terminate(); - } catch {} - await protocol.send({ - type: 'terminated', - request_id: session.request_id, - content_items: contentItems, - }); -} - -async function main() { - const protocol = createProtocol(); - await protocol.closed; -} - -void main().catch(async (error) => { - try { - process.stderr.write(formatErrorText(error) + '\n'); - } finally { - process.exitCode = 1; - } -}); diff --git a/codex-rs/core/src/tools/code_mode/service.rs b/codex-rs/core/src/tools/code_mode/service.rs deleted file mode 100644 index a9fadedb8..000000000 --- a/codex-rs/core/src/tools/code_mode/service.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::Arc; - -use serde_json::Value as JsonValue; -use tokio::sync::Mutex; -use tracing::warn; - -use crate::codex::Session; -use crate::codex::TurnContext; -use crate::tools::ToolRouter; -use crate::tools::context::SharedTurnDiffTracker; -use crate::tools::js_repl::resolve_compatible_node; -use crate::tools::parallel::ToolCallRuntime; -use codex_features::Feature; - -use super::ExecContext; -use super::PUBLIC_TOOL_NAME; -use super::process::CodeModeProcess; -use super::process::spawn_code_mode_process; -use super::worker::CodeModeWorker; - -pub(crate) struct CodeModeService { - js_repl_node_path: Option, - stored_values: Mutex>, - process: Arc>>, - next_cell_id: Mutex, -} - -impl CodeModeService { - pub(crate) fn new(js_repl_node_path: Option) -> Self { - Self { - js_repl_node_path, - stored_values: Mutex::new(HashMap::new()), - process: Arc::new(Mutex::new(None)), - next_cell_id: Mutex::new(1), - } - } - - pub(crate) async fn stored_values(&self) -> HashMap { - self.stored_values.lock().await.clone() - } - - pub(crate) async fn replace_stored_values(&self, values: HashMap) { - *self.stored_values.lock().await = values; - } - - pub(super) async fn ensure_started( - &self, - ) -> Result>, std::io::Error> { - let mut process_slot = self.process.lock().await; - let 
needs_spawn = match process_slot.as_mut() { - Some(process) => !matches!(process.has_exited(), Ok(false)), - None => true, - }; - if needs_spawn { - let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()) - .await - .map_err(std::io::Error::other)?; - *process_slot = Some(spawn_code_mode_process(&node_path).await?); - } - drop(process_slot); - Ok(self.process.clone().lock_owned().await) - } - - pub(crate) async fn start_turn_worker( - &self, - session: &Arc, - turn: &Arc, - router: Arc, - tracker: SharedTurnDiffTracker, - ) -> Option { - if !turn.features.enabled(Feature::CodeMode) { - return None; - } - let exec = ExecContext { - session: Arc::clone(session), - turn: Arc::clone(turn), - }; - let tool_runtime = - ToolCallRuntime::new(router, Arc::clone(session), Arc::clone(turn), tracker); - let mut process_slot = match self.ensure_started().await { - Ok(process_slot) => process_slot, - Err(err) => { - warn!("failed to start {PUBLIC_TOOL_NAME} worker for turn: {err}"); - return None; - } - }; - let Some(process) = process_slot.as_mut() else { - warn!( - "failed to start {PUBLIC_TOOL_NAME} worker for turn: {PUBLIC_TOOL_NAME} runner failed to start" - ); - return None; - }; - Some(process.worker(exec, tool_runtime)) - } - - pub(crate) async fn allocate_cell_id(&self) -> String { - let mut next_cell_id = self.next_cell_id.lock().await; - let cell_id = *next_cell_id; - *next_cell_id = next_cell_id.saturating_add(1); - cell_id.to_string() - } - - pub(crate) async fn allocate_request_id(&self) -> String { - uuid::Uuid::new_v4().to_string() - } -} diff --git a/codex-rs/core/src/tools/code_mode/wait_description.md b/codex-rs/core/src/tools/code_mode/wait_description.md deleted file mode 100644 index 41b928f51..000000000 --- a/codex-rs/core/src/tools/code_mode/wait_description.md +++ /dev/null @@ -1,8 +0,0 @@ -- Use `wait` only after `exec` returns `Script running with cell ID ...`. -- `cell_id` identifies the running `exec` cell to resume. 
-- `yield_time_ms` controls how long to wait for more output before yielding again. If omitted, `wait` uses its default wait timeout. -- `max_tokens` limits how much new output this wait call returns. -- `terminate: true` stops the running cell instead of waiting for more output. -- `wait` returns only the new output since the last yield, or the final completion or termination result for that cell. -- If the cell is still running, `wait` may yield again with the same `cell_id`. -- If the cell has already finished, `wait` returns the completed result and closes the cell. diff --git a/codex-rs/core/src/tools/code_mode/wait_handler.rs b/codex-rs/core/src/tools/code_mode/wait_handler.rs index caaf8c8c4..f319985a8 100644 --- a/codex-rs/core/src/tools/code_mode/wait_handler.rs +++ b/codex-rs/core/src/tools/code_mode/wait_handler.rs @@ -8,13 +8,10 @@ use crate::tools::context::ToolPayload; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use super::CodeModeSessionProgress; use super::DEFAULT_WAIT_YIELD_TIME_MS; use super::ExecContext; -use super::PUBLIC_TOOL_NAME; use super::WAIT_TOOL_NAME; -use super::handle_node_message; -use super::protocol::HostToNodeMessage; +use super::handle_runtime_response; pub struct CodeModeWaitHandler; @@ -63,66 +60,21 @@ impl ToolHandler for CodeModeWaitHandler { ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => { let args: ExecWaitArgs = parse_arguments(&arguments)?; let exec = ExecContext { session, turn }; - let request_id = exec - .session - .services - .code_mode_service - .allocate_request_id() - .await; let started_at = std::time::Instant::now(); - let message = if args.terminate { - HostToNodeMessage::Terminate { - request_id: request_id.clone(), - cell_id: args.cell_id.clone(), - } - } else { - HostToNodeMessage::Poll { - request_id: request_id.clone(), - cell_id: args.cell_id.clone(), - yield_time_ms: args.yield_time_ms, - } - }; - let process_slot = exec + let response = exec 
.session .services .code_mode_service - .ensure_started() + .wait(codex_code_mode::WaitRequest { + cell_id: args.cell_id, + yield_time_ms: args.yield_time_ms, + terminate: args.terminate, + }) .await - .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; - let result = { - let mut process_slot = process_slot; - let Some(process) = process_slot.as_mut() else { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner failed to start" - ))); - }; - if !matches!(process.has_exited(), Ok(false)) { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner failed to start" - ))); - } - let message = process - .send(&request_id, &message) - .await - .map_err(|err| err.to_string()); - let message = match message { - Ok(message) => message, - Err(error) => return Err(FunctionCallError::RespondToModel(error)), - }; - handle_node_message( - &exec, - args.cell_id, - message, - Some(args.max_tokens), - started_at, - ) + .map_err(FunctionCallError::RespondToModel)?; + handle_runtime_response(&exec, response, args.max_tokens, started_at) .await - }; - match result { - Ok(CodeModeSessionProgress::Finished(output)) - | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), - Err(error) => Err(FunctionCallError::RespondToModel(error)), - } + .map_err(FunctionCallError::RespondToModel) } _ => Err(FunctionCallError::RespondToModel(format!( "{WAIT_TOOL_NAME} expects JSON arguments" diff --git a/codex-rs/core/src/tools/code_mode/worker.rs b/codex-rs/core/src/tools/code_mode/worker.rs deleted file mode 100644 index 5853f3abe..000000000 --- a/codex-rs/core/src/tools/code_mode/worker.rs +++ /dev/null @@ -1,116 +0,0 @@ -use tokio::sync::oneshot; -use tokio_util::sync::CancellationToken; -use tracing::error; -use tracing::warn; - -use codex_protocol::models::FunctionCallOutputPayload; -use codex_protocol::models::ResponseInputItem; - -use super::ExecContext; -use super::PUBLIC_TOOL_NAME; -use 
super::call_nested_tool; -use super::process::CodeModeProcess; -use super::process::write_message; -use super::protocol::HostToNodeMessage; -use super::protocol::NodeToHostMessage; -use crate::tools::parallel::ToolCallRuntime; - -pub(crate) struct CodeModeWorker { - shutdown_tx: Option>, -} - -impl Drop for CodeModeWorker { - fn drop(&mut self) { - if let Some(shutdown_tx) = self.shutdown_tx.take() { - let _ = shutdown_tx.send(()); - } - } -} - -impl CodeModeProcess { - pub(super) fn worker( - &self, - exec: ExecContext, - tool_runtime: ToolCallRuntime, - ) -> CodeModeWorker { - let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - let stdin = self.stdin.clone(); - let message_rx = self.message_rx.clone(); - tokio::spawn(async move { - loop { - let next_message = tokio::select! { - _ = &mut shutdown_rx => break, - message = async { - let mut message_rx = message_rx.lock().await; - message_rx.recv().await - } => message, - }; - let Some(next_message) = next_message else { - break; - }; - match next_message { - NodeToHostMessage::ToolCall { tool_call } => { - let exec = exec.clone(); - let tool_runtime = tool_runtime.clone(); - let stdin = stdin.clone(); - tokio::spawn(async move { - let result = call_nested_tool( - exec, - tool_runtime, - tool_call.name, - tool_call.input, - CancellationToken::new(), - ) - .await; - let (code_mode_result, error_text) = match result { - Ok(code_mode_result) => (code_mode_result, None), - Err(error) => (serde_json::Value::Null, Some(error.to_string())), - }; - let response = HostToNodeMessage::Response { - request_id: tool_call.request_id, - id: tool_call.id, - code_mode_result, - error_text, - }; - if let Err(err) = write_message(&stdin, &response).await { - warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}"); - } - }); - } - NodeToHostMessage::Notify { notify } => { - if notify.text.trim().is_empty() { - continue; - } - if exec - .session - .inject_response_items(vec![ResponseInputItem::CustomToolCallOutput { - 
call_id: notify.call_id.clone(), - name: Some(PUBLIC_TOOL_NAME.to_string()), - output: FunctionCallOutputPayload::from_text(notify.text), - }]) - .await - .is_err() - { - warn!( - "failed to inject {PUBLIC_TOOL_NAME} notify message for cell {}: no active turn", - notify.cell_id - ); - } - } - unexpected_message @ (NodeToHostMessage::Yielded { .. } - | NodeToHostMessage::Terminated { .. } - | NodeToHostMessage::Result { .. }) => { - error!( - "received unexpected {PUBLIC_TOOL_NAME} message in worker loop: {unexpected_message:?}" - ); - break; - } - } - } - }); - - CodeModeWorker { - shutdown_tx: Some(shutdown_tx), - } - } -} diff --git a/codex-rs/core/src/tools/code_mode_description.rs b/codex-rs/core/src/tools/code_mode_description.rs index b7722aeb7..fb4fc1f51 100644 --- a/codex-rs/core/src/tools/code_mode_description.rs +++ b/codex-rs/core/src/tools/code_mode_description.rs @@ -1,30 +1,11 @@ use crate::client_common::tools::ToolSpec; -use crate::mcp::split_qualified_tool_name; -use crate::tools::code_mode::PUBLIC_TOOL_NAME; -use serde_json::Value as JsonValue; -pub(crate) struct CodeModeToolReference { - pub(crate) module_path: String, - pub(crate) namespace: Vec, - pub(crate) tool_key: String, -} - -pub(crate) fn code_mode_tool_reference(tool_name: &str) -> CodeModeToolReference { - if let Some((server_name, tool_key)) = split_qualified_tool_name(tool_name) { - let namespace = vec!["mcp".to_string(), server_name]; - return CodeModeToolReference { - module_path: format!("tools/{}.js", namespace.join("/")), - namespace, - tool_key, - }; - } - - CodeModeToolReference { - module_path: "tools.js".to_string(), - namespace: Vec::new(), - tool_key: tool_name.to_string(), - } -} +#[allow(unused_imports)] +#[cfg(test)] +pub(crate) use codex_code_mode::append_code_mode_sample; +#[allow(unused_imports)] +#[cfg(test)] +pub(crate) use codex_code_mode::render_json_schema_to_typescript; pub(crate) fn augment_tool_spec_for_code_mode(spec: ToolSpec, code_mode_enabled: bool) -> 
ToolSpec { if !code_mode_enabled { @@ -33,27 +14,27 @@ pub(crate) fn augment_tool_spec_for_code_mode(spec: ToolSpec, code_mode_enabled: match spec { ToolSpec::Function(mut tool) => { - if tool.name != PUBLIC_TOOL_NAME { - tool.description = append_code_mode_sample( - &tool.description, - &tool.name, - "args", - serde_json::to_value(&tool.parameters) - .ok() - .as_ref() - .map(render_json_schema_to_typescript) - .unwrap_or_else(|| "unknown".to_string()), - tool.output_schema - .as_ref() - .map(render_json_schema_to_typescript) - .unwrap_or_else(|| "unknown".to_string()), - ); - } + let input_type = serde_json::to_value(&tool.parameters) + .ok() + .map(|schema| codex_code_mode::render_json_schema_to_typescript(&schema)) + .unwrap_or_else(|| "unknown".to_string()); + let output_type = tool + .output_schema + .as_ref() + .map(codex_code_mode::render_json_schema_to_typescript) + .unwrap_or_else(|| "unknown".to_string()); + tool.description = codex_code_mode::append_code_mode_sample( + &tool.description, + &tool.name, + "args", + input_type, + output_type, + ); ToolSpec::Function(tool) } ToolSpec::Freeform(mut tool) => { - if tool.name != PUBLIC_TOOL_NAME { - tool.description = append_code_mode_sample( + if tool.name != codex_code_mode::PUBLIC_TOOL_NAME { + tool.description = codex_code_mode::append_code_mode_sample( &tool.description, &tool.name, "input", @@ -66,234 +47,3 @@ pub(crate) fn augment_tool_spec_for_code_mode(spec: ToolSpec, code_mode_enabled: other => other, } } - -fn append_code_mode_sample( - description: &str, - tool_name: &str, - input_name: &str, - input_type: String, - output_type: String, -) -> String { - let declaration = format!( - "declare const tools: {{ {} }};", - render_code_mode_tool_declaration(tool_name, input_name, input_type, output_type) - ); - format!("{description}\n\nexec tool declaration:\n```ts\n{declaration}\n```") -} - -fn render_code_mode_tool_declaration( - tool_name: &str, - input_name: &str, - input_type: String, - output_type: 
String, -) -> String { - let tool_name = normalize_code_mode_identifier(tool_name); - format!("{tool_name}({input_name}: {input_type}): Promise<{output_type}>;") -} - -pub(crate) fn normalize_code_mode_identifier(tool_key: &str) -> String { - let mut identifier = String::new(); - - for (index, ch) in tool_key.chars().enumerate() { - let is_valid = if index == 0 { - ch == '_' || ch == '$' || ch.is_ascii_alphabetic() - } else { - ch == '_' || ch == '$' || ch.is_ascii_alphanumeric() - }; - - if is_valid { - identifier.push(ch); - } else { - identifier.push('_'); - } - } - - if identifier.is_empty() { - "_".to_string() - } else { - identifier - } -} - -fn render_json_schema_to_typescript(schema: &JsonValue) -> String { - render_json_schema_to_typescript_inner(schema) -} - -fn render_json_schema_to_typescript_inner(schema: &JsonValue) -> String { - match schema { - JsonValue::Bool(true) => "unknown".to_string(), - JsonValue::Bool(false) => "never".to_string(), - JsonValue::Object(map) => { - if let Some(value) = map.get("const") { - return render_json_schema_literal(value); - } - - if let Some(values) = map.get("enum").and_then(serde_json::Value::as_array) { - let rendered = values - .iter() - .map(render_json_schema_literal) - .collect::>(); - if !rendered.is_empty() { - return rendered.join(" | "); - } - } - - for key in ["anyOf", "oneOf"] { - if let Some(variants) = map.get(key).and_then(serde_json::Value::as_array) { - let rendered = variants - .iter() - .map(render_json_schema_to_typescript_inner) - .collect::>(); - if !rendered.is_empty() { - return rendered.join(" | "); - } - } - } - - if let Some(variants) = map.get("allOf").and_then(serde_json::Value::as_array) { - let rendered = variants - .iter() - .map(render_json_schema_to_typescript_inner) - .collect::>(); - if !rendered.is_empty() { - return rendered.join(" & "); - } - } - - if let Some(schema_type) = map.get("type") { - if let Some(types) = schema_type.as_array() { - let rendered = types - .iter() - 
.filter_map(serde_json::Value::as_str) - .map(|schema_type| render_json_schema_type_keyword(map, schema_type)) - .collect::>(); - if !rendered.is_empty() { - return rendered.join(" | "); - } - } - - if let Some(schema_type) = schema_type.as_str() { - return render_json_schema_type_keyword(map, schema_type); - } - } - - if map.contains_key("properties") - || map.contains_key("additionalProperties") - || map.contains_key("required") - { - return render_json_schema_object(map); - } - - if map.contains_key("items") || map.contains_key("prefixItems") { - return render_json_schema_array(map); - } - - "unknown".to_string() - } - _ => "unknown".to_string(), - } -} - -fn render_json_schema_type_keyword( - map: &serde_json::Map, - schema_type: &str, -) -> String { - match schema_type { - "string" => "string".to_string(), - "number" | "integer" => "number".to_string(), - "boolean" => "boolean".to_string(), - "null" => "null".to_string(), - "array" => render_json_schema_array(map), - "object" => render_json_schema_object(map), - _ => "unknown".to_string(), - } -} - -fn render_json_schema_array(map: &serde_json::Map) -> String { - if let Some(items) = map.get("items") { - let item_type = render_json_schema_to_typescript_inner(items); - return format!("Array<{item_type}>"); - } - - if let Some(items) = map.get("prefixItems").and_then(serde_json::Value::as_array) { - let item_types = items - .iter() - .map(render_json_schema_to_typescript_inner) - .collect::>(); - if !item_types.is_empty() { - return format!("[{}]", item_types.join(", ")); - } - } - - "unknown[]".to_string() -} - -fn render_json_schema_object(map: &serde_json::Map) -> String { - let required = map - .get("required") - .and_then(serde_json::Value::as_array) - .map(|items| { - items - .iter() - .filter_map(serde_json::Value::as_str) - .collect::>() - }) - .unwrap_or_default(); - let properties = map - .get("properties") - .and_then(serde_json::Value::as_object) - .cloned() - .unwrap_or_default(); - - let mut 
sorted_properties = properties.iter().collect::>(); - sorted_properties.sort_unstable_by(|(name_a, _), (name_b, _)| name_a.cmp(name_b)); - let mut lines = sorted_properties - .into_iter() - .map(|(name, value)| { - let optional = if required.iter().any(|required_name| required_name == name) { - "" - } else { - "?" - }; - let property_name = render_json_schema_property_name(name); - let property_type = render_json_schema_to_typescript_inner(value); - format!("{property_name}{optional}: {property_type};") - }) - .collect::>(); - - if let Some(additional_properties) = map.get("additionalProperties") { - let additional_type = match additional_properties { - JsonValue::Bool(true) => Some("unknown".to_string()), - JsonValue::Bool(false) => None, - value => Some(render_json_schema_to_typescript_inner(value)), - }; - - if let Some(additional_type) = additional_type { - lines.push(format!("[key: string]: {additional_type};")); - } - } else if properties.is_empty() { - lines.push("[key: string]: unknown;".to_string()); - } - - if lines.is_empty() { - return "{}".to_string(); - } - - format!("{{ {} }}", lines.join(" ")) -} - -fn render_json_schema_property_name(name: &str) -> String { - if normalize_code_mode_identifier(name) == name { - name.to_string() - } else { - serde_json::to_string(name).unwrap_or_else(|_| format!("\"{}\"", name.replace('"', "\\\""))) - } -} - -fn render_json_schema_literal(value: &JsonValue) -> String { - serde_json::to_string(value).unwrap_or_else(|_| "unknown".to_string()) -} - -#[cfg(test)] -#[path = "code_mode_description_tests.rs"] -mod tests; diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 8544eb404..345f7ce06 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -4,7 +4,6 @@ use crate::codex::TurnContext; use crate::function_tool::FunctionCallError; use crate::mcp_connection_manager::ToolInfo; use crate::sandboxing::SandboxPermissions; -use 
crate::tools::code_mode::is_code_mode_nested_tool; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; @@ -67,7 +66,7 @@ impl ToolRouter { specs .iter() .filter_map(|configured_tool| { - if !is_code_mode_nested_tool(configured_tool.spec.name()) { + if !codex_code_mode::is_code_mode_nested_tool(configured_tool.spec.name()) { Some(configured_tool.spec.clone()) } else { None diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 5ae2f333d..2e0a413a6 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -11,9 +11,6 @@ use crate::shell::Shell; use crate::shell::ShellType; use crate::tools::code_mode::PUBLIC_TOOL_NAME; use crate::tools::code_mode::WAIT_TOOL_NAME; -use crate::tools::code_mode::is_code_mode_nested_tool; -use crate::tools::code_mode::tool_description as code_mode_tool_description; -use crate::tools::code_mode::wait_tool_description as code_mode_wait_tool_description; use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; use crate::tools::discoverable::DiscoverablePluginInfo; use crate::tools::discoverable::DiscoverableTool; @@ -833,7 +830,7 @@ fn create_wait_tool() -> ToolSpec { name: WAIT_TOOL_NAME.to_string(), description: format!( "Waits on a yielded `{PUBLIC_TOOL_NAME}` cell and returns new output or completion.\n{}", - code_mode_wait_tool_description().trim() + codex_code_mode::build_wait_tool_description().trim() ), strict: false, parameters: JsonSchema::Object { @@ -2176,7 +2173,10 @@ SOURCE: /[\s\S]+/ ToolSpec::Freeform(FreeformTool { name: PUBLIC_TOOL_NAME.to_string(), - description: code_mode_tool_description(enabled_tools, code_mode_only_enabled), + description: codex_code_mode::build_exec_tool_description( + enabled_tools, + code_mode_only_enabled, + ), format: FreeformToolFormat { r#type: "grammar".to_string(), syntax: "lark".to_string(), @@ -2647,7 +2647,7 @@ pub(crate) fn 
build_specs_with_discoverable_tools( ToolSpec::Freeform(tool) => (tool.name, tool.description), _ => return None, }; - is_code_mode_nested_tool(&name).then_some((name, description)) + codex_code_mode::is_code_mode_nested_tool(&name).then_some((name, description)) }) .collect::>(); enabled_tools.sort_by(|left, right| left.0.cmp(&right.0)); diff --git a/codex-rs/core/tests/suite/code_mode.rs b/codex-rs/core/tests/suite/code_mode.rs index c74de38e8..b9e4f05b3 100644 --- a/codex-rs/core/tests/suite/code_mode.rs +++ b/codex-rs/core/tests/suite/code_mode.rs @@ -1672,8 +1672,6 @@ async fn code_mode_exit_stops_script_immediately() -> Result<()> { &server, "use exec to stop script early with exit helper", r#" -import { exit, text } from "@openai/code_mode"; - text("before"); exit(); text("after"); @@ -2129,6 +2127,7 @@ text(JSON.stringify(Object.getOwnPropertyNames(globalThis).sort())); "SuppressedError", "Symbol", "SyntaxError", + "Temporal", "TypeError", "URIError", "Uint16Array", @@ -2141,7 +2140,6 @@ text(JSON.stringify(Object.getOwnPropertyNames(globalThis).sort())); "WebAssembly", "__codexContentItems", "add_content", - "console", "decodeURI", "decodeURIComponent", "encodeURI", @@ -2282,10 +2280,8 @@ async fn code_mode_can_call_hidden_dynamic_tools() -> Result<()> { test.session_configured = new_thread.session_configured; let code = r#" -import { ALL_TOOLS, hidden_dynamic_tool } from "tools.js"; - const tool = ALL_TOOLS.find(({ name }) => name === "hidden_dynamic_tool"); -const out = await hidden_dynamic_tool({ city: "Paris" }); +const out = await tools.hidden_dynamic_tool({ city: "Paris" }); text( JSON.stringify({ name: tool?.name ?? 
null, diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 1e6073be0..7252d9a6b 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -159,7 +159,9 @@ async fn unified_exec_intercepts_apply_patch_exec_command() -> Result<()> { let call_id = "uexec-apply-patch"; let args = json!({ "cmd": command, - "yield_time_ms": 250, + // The intercepted apply_patch path spawns a helper process, which can + // take longer than a tiny unified-exec yield deadline on CI. + "yield_time_ms": 5_000, }); let responses = vec![