From dcab40123f5e64ba8af962ae27abe6cbcc205344 Mon Sep 17 00:00:00 2001 From: daveaitel-openai Date: Tue, 24 Feb 2026 16:00:19 -0500 Subject: [PATCH] Agent jobs (spawn_agents_on_csv) + progress UI (#10935) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Add agent job support: spawn a batch of sub-agents from CSV, auto-run, auto-export, and store results in SQLite. - Simplify workflow: remove run/resume/get-status/export tools; spawn is deterministic and completes in one call. - Improve exec UX: stable, single-line progress bar with ETA; suppress sub-agent chatter in exec. ## Why Enables map-reduce style workflows over arbitrarily large repos using the existing Codex orchestrator. This addresses review feedback about overly complex job controls and non-deterministic monitoring. ## Demo (progress bar) ``` ./codex-rs/target/debug/codex exec \ --enable collab \ --enable sqlite \ --full-auto \ --progress-cursor \ -c agents.max_threads=16 \ -C /Users/daveaitel/code/codex \ - <<'PROMPT' Create /tmp/agent_job_progress_demo.csv with columns: path,area and 30 rows: path = item-01..item-30, area = test. Then call spawn_agents_on_csv with: - csv_path: /tmp/agent_job_progress_demo.csv - instruction: "Run `python - <<'PY'` to sleep a random 0.3–1.2s, then output JSON with keys: path, score (int). Set score = 1." - output_csv_path: /tmp/agent_job_progress_demo_out.csv PROMPT ``` ## Review feedback addressed - Auto-start jobs on spawn; removed run/resume/status/export tools. - Auto-export on success. - More descriptive tool spec + clearer prompts. - Avoid deadlocks on spawn failure; pending/running handled safely. - Progress bar no longer scrolls; stable single-line redraw. ## Tests - `cd codex-rs && cargo test -p codex-exec` - `cd codex-rs && cargo build -p codex-cli` --- codex-rs/Cargo.lock | 1 + codex-rs/Cargo.toml | 1 + .../app-server/tests/common/mcp_process.rs | 1 + .../suite/v2/experimental_feature_list.rs | 8 +- codex-rs/core/Cargo.toml | 1 + codex-rs/core/config.schema.json | 20 + codex-rs/core/src/agent/guards.rs | 6 +- codex-rs/core/src/agent/mod.rs | 1 + codex-rs/core/src/codex.rs | 6 + codex-rs/core/src/config/mod.rs | 107 +- codex-rs/core/src/shell_snapshot.rs | 52 +- codex-rs/core/src/state_db.rs | 10 +- codex-rs/core/src/tools/context.rs | 7 + .../core/src/tools/handlers/agent_jobs.rs | 1227 +++++++++++++++++ .../core/src/tools/handlers/apply_patch.rs | 1 + codex-rs/core/src/tools/handlers/mod.rs | 1 + .../core/src/tools/handlers/multi_agents.rs | 21 +- codex-rs/core/src/tools/handlers/shell.rs | 2 + .../core/src/tools/handlers/view_image.rs | 24 +- codex-rs/core/src/tools/router.rs | 9 +- codex-rs/core/src/tools/spec.rs | 258 +++- codex-rs/core/tests/common/test_codex.rs | 8 + codex-rs/core/tests/suite/agent_jobs.rs | 424 ++++++ codex-rs/core/tests/suite/mod.rs | 1 + codex-rs/core/tests/suite/shell_snapshot.rs | 14 +- codex-rs/core/tests/suite/sqlite_state.rs | 6 +- codex-rs/exec/src/cli.rs | 4 + .../src/event_processor_with_human_output.rs | 244 ++++ codex-rs/exec/src/lib.rs | 66 +- codex-rs/state/migrations/0014_agent_jobs.sql | 38 + .../0015_agent_jobs_max_runtime_seconds.sql | 2 + codex-rs/state/src/lib.rs | 10 + codex-rs/state/src/model/agent_job.rs | 256 ++++ codex-rs/state/src/model/mod.rs | 10 + codex-rs/state/src/runtime.rs | 567 ++++++++ docs/config.md | 6 + 36 files changed, 3370 insertions(+), 50 deletions(-) create mode 100644 codex-rs/core/src/tools/handlers/agent_jobs.rs create mode 100644 codex-rs/core/tests/suite/agent_jobs.rs create mode 100644 codex-rs/state/migrations/0014_agent_jobs.sql create mode 100644 codex-rs/state/migrations/0015_agent_jobs_max_runtime_seconds.sql create mode 100644 codex-rs/state/src/model/agent_job.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ca09a75fa..056eae957 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1726,6 +1726,7 @@ dependencies = [ "codex-windows-sandbox", "core-foundation 0.9.4", "core_test_support", + "csv", "ctor 0.6.3", "dirs", "dunce", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index baed07814..b3bb85d5a 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -160,6 +160,7 @@ clap = "4" clap_complete = "4" color-eyre = "0.6.3" crossbeam-channel = "0.5.15" +csv = "1.3.1" crossterm = "0.28.1" ctor = "0.6.3" derive_more = "2" diff --git a/codex-rs/app-server/tests/common/mcp_process.rs b/codex-rs/app-server/tests/common/mcp_process.rs index 4e48f909b..bb35c2c0a 100644 --- a/codex-rs/app-server/tests/common/mcp_process.rs +++ b/codex-rs/app-server/tests/common/mcp_process.rs @@ -105,6 +105,7 @@ impl McpProcess { cmd.stdin(Stdio::piped()); cmd.stdout(Stdio::piped()); cmd.stderr(Stdio::piped()); + cmd.current_dir(codex_home); cmd.env("CODEX_HOME", codex_home); cmd.env("RUST_LOG", "debug"); cmd.env_remove(CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR); diff --git a/codex-rs/app-server/tests/suite/v2/experimental_feature_list.rs b/codex-rs/app-server/tests/suite/v2/experimental_feature_list.rs index fdcbaca5b..58deb5f82 100644 --- a/codex-rs/app-server/tests/suite/v2/experimental_feature_list.rs +++ b/codex-rs/app-server/tests/suite/v2/experimental_feature_list.rs @@ -9,6 +9,7 @@ use codex_app_server_protocol::ExperimentalFeatureListResponse; use codex_app_server_protocol::ExperimentalFeatureStage; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; +use codex_core::config::ConfigBuilder; use codex_core::features::FEATURES; use codex_core::features::Stage; use pretty_assertions::assert_eq; @@ -20,6 +21,11 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); #[tokio::test] async fn experimental_feature_list_returns_feature_metadata_with_stage() -> Result<()> { let codex_home = TempDir::new()?; + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .fallback_cwd(Some(codex_home.path().to_path_buf())) + .build() + .await?; let mut mcp = McpProcess::new(codex_home.path()).await?; timeout(DEFAULT_TIMEOUT, mcp.initialize()).await??; @@ -63,7 +69,7 @@ async fn experimental_feature_list_returns_feature_metadata_with_stage() -> Resu display_name, description, announcement, - enabled: spec.default_enabled, + enabled: config.features.enabled(spec.id), default_enabled: spec.default_enabled, } }) diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 5d132d58b..820f35cb1 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -51,6 +51,7 @@ codex-utils-readiness = { workspace = true } codex-secrets = { workspace = true } codex-utils-string = { workspace = true } codex-windows-sandbox = { package = "codex-windows-sandbox", path = "../windows-sandbox-rs" } +csv = { workspace = true } dirs = { workspace = true } dunce = { workspace = true } encoding_rs = { workspace = true } diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index c6689710b..ca7fb5157 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -29,12 +29,24 @@ "$ref": "#/definitions/AgentRoleToml" }, "properties": { + "job_max_runtime_seconds": { + "description": "Default maximum runtime in seconds for agent job workers.", + "format": "uint64", + "minimum": 1.0, + "type": "integer" + }, "max_depth": { "description": "Maximum nesting depth allowed for spawned agent threads. Root sessions start at depth 0.", "format": "int32", "minimum": 1.0, "type": "integer" }, + "max_spawn_depth": { + "description": "Maximum depth for thread-spawned subagents.", + "format": "uint", + "minimum": 1.0, + "type": "integer" + }, "max_threads": { "description": "Maximum number of agent threads that can be open concurrently. When unset, no limit is enforced.", "format": "uint", @@ -2040,6 +2052,14 @@ ], "description": "User-level skill config entries keyed by SKILL.md path." }, + "sqlite_home": { + "allOf": [ + { + "$ref": "#/definitions/AbsolutePathBuf" + } + ], + "description": "Directory where Codex stores the SQLite state DB. Defaults to `$CODEX_SQLITE_HOME` when set. Otherwise uses a temp dir under WorkspaceWrite sandboxing and `$CODEX_HOME` for other modes." + }, "suppress_unstable_features_warning": { "description": "Suppress warnings about unstable (under development) features.", "type": "boolean" diff --git a/codex-rs/core/src/agent/guards.rs b/codex-rs/core/src/agent/guards.rs index 422e10b65..b8db6e397 100644 --- a/codex-rs/core/src/agent/guards.rs +++ b/codex-rs/core/src/agent/guards.rs @@ -1,3 +1,4 @@ +use crate::config::DEFAULT_AGENT_MAX_SPAWN_DEPTH; use crate::error::CodexErr; use crate::error::Result; use codex_protocol::ThreadId; @@ -30,7 +31,6 @@ struct ActiveAgents { used_agent_nicknames: HashSet, nickname_reset_count: usize, } - fn session_depth(session_source: &SessionSource) -> i32 { match session_source { SessionSource::SubAgent(SubAgentSource::ThreadSpawn { depth, .. }) => *depth, @@ -43,6 +43,10 @@ pub(crate) fn next_thread_spawn_depth(session_source: &SessionSource) -> i32 { session_depth(session_source).saturating_add(1) } +pub(crate) fn max_thread_spawn_depth(max_depth: Option) -> i32 { + let max_depth = max_depth.or(DEFAULT_AGENT_MAX_SPAWN_DEPTH).unwrap_or(1); + i32::try_from(max_depth).unwrap_or(i32::MAX) +} pub(crate) fn exceeds_thread_spawn_depth_limit(depth: i32, max_depth: i32) -> bool { depth > max_depth } diff --git a/codex-rs/core/src/agent/mod.rs b/codex-rs/core/src/agent/mod.rs index 15be909c3..6ae7d9615 100644 --- a/codex-rs/core/src/agent/mod.rs +++ b/codex-rs/core/src/agent/mod.rs @@ -6,5 +6,6 @@ pub(crate) mod status; pub(crate) use codex_protocol::protocol::AgentStatus; pub(crate) use control::AgentControl; pub(crate) use guards::exceeds_thread_spawn_depth_limit; +pub(crate) use guards::max_thread_spawn_depth; pub(crate) use guards::next_thread_spawn_depth; pub(crate) use status::agent_status_from_event; diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 012ae96b3..397318c8e 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -640,6 +640,7 @@ impl TurnContext { model_info: &model_info, features: &features, web_search_mode: self.tools_config.web_search_mode, + session_source: self.session_source.clone(), }) .with_allow_login_shell(self.tools_config.allow_login_shell) .with_agent_roles(config.agent_roles.clone()); @@ -975,6 +976,7 @@ impl Session { model_info: &model_info, features: &per_turn_config.features, web_search_mode: Some(per_turn_config.web_search_mode.value()), + session_source: session_source.clone(), }) .with_allow_login_shell(per_turn_config.permissions.allow_login_shell) .with_agent_roles(per_turn_config.agent_roles.clone()); @@ -4592,6 +4594,7 @@ async fn spawn_review_thread( model_info: &review_model_info, features: &review_features, web_search_mode: Some(review_web_search_mode), + session_source: parent_turn_context.session_source.clone(), }) .with_allow_login_shell(config.permissions.allow_login_shell) .with_agent_roles(config.agent_roles.clone()); @@ -9267,6 +9270,7 @@ mod tests { }) .to_string(), }, + source: ToolCallSource::Direct, }) .await; @@ -9306,6 +9310,7 @@ mod tests { }) .to_string(), }, + source: ToolCallSource::Direct, }) .await; @@ -9365,6 +9370,7 @@ mod tests { }) .to_string(), }, + source: ToolCallSource::Direct, }) .await; diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 1f81eaca4..1b44a7ee2 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -115,8 +115,35 @@ pub use codex_git::GhostSnapshotConfig; /// the context window. pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB pub(crate) const DEFAULT_AGENT_MAX_THREADS: Option = Some(6); +pub(crate) const DEFAULT_AGENT_MAX_SPAWN_DEPTH: Option = Some(2); pub(crate) const DEFAULT_AGENT_MAX_DEPTH: i32 = 1; +pub(crate) const DEFAULT_AGENT_JOB_MAX_RUNTIME_SECONDS: Option = None; +pub const CONFIG_TOML_FILE: &str = "config.toml"; + +fn default_sqlite_home(sandbox_policy: &SandboxPolicy, codex_home: &Path) -> PathBuf { + if matches!(sandbox_policy, SandboxPolicy::WorkspaceWrite { .. }) { + let mut path = std::env::temp_dir(); + path.push("codex-sqlite"); + path + } else { + codex_home.to_path_buf() + } +} + +fn resolve_sqlite_home_env(resolved_cwd: &Path) -> Option { + let raw = std::env::var(codex_state::SQLITE_HOME_ENV).ok()?; + let trimmed = raw.trim(); + if trimmed.is_empty() { + return None; + } + let path = PathBuf::from(trimmed); + if path.is_absolute() { + Some(path) + } else { + Some(resolved_cwd.join(path)) + } +} #[cfg(test)] pub(crate) fn test_config() -> Config { let codex_home = tempdir().expect("create temp dir"); @@ -330,6 +357,10 @@ pub struct Config { /// Maximum number of agent threads that can be open concurrently. pub agent_max_threads: Option, + /// Maximum depth for thread-spawned subagents. + pub agent_max_spawn_depth: Option, + /// Maximum runtime in seconds for agent job workers before they are failed. + pub agent_job_max_runtime_seconds: Option, /// Maximum nesting depth allowed for spawned agent threads. pub agent_max_depth: i32, @@ -344,6 +375,9 @@ pub struct Config { /// overridden by the `CODEX_HOME` environment variable). pub codex_home: PathBuf, + /// Directory where Codex stores the SQLite state DB. + pub sqlite_home: PathBuf, + /// Directory where Codex writes log files (defaults to `$CODEX_HOME/log`). pub log_dir: PathBuf, @@ -1108,6 +1142,11 @@ pub struct ConfigToml { #[serde(default)] pub history: Option, + /// Directory where Codex stores the SQLite state DB. + /// Defaults to `$CODEX_SQLITE_HOME` when set. Otherwise uses a temp dir + /// under WorkspaceWrite sandboxing and `$CODEX_HOME` for other modes. + pub sqlite_home: Option, + /// Directory where Codex writes log files, for example `codex-tui.log`. /// Defaults to `$CODEX_HOME/log`. pub log_dir: Option, @@ -1295,11 +1334,16 @@ pub struct AgentsToml { /// When unset, no limit is enforced. #[schemars(range(min = 1))] pub max_threads: Option, - + /// Maximum depth for thread-spawned subagents. + #[schemars(range(min = 1))] + pub max_spawn_depth: Option, /// Maximum nesting depth allowed for spawned agent threads. /// Root sessions start at depth 0. #[schemars(range(min = 1))] pub max_depth: Option, + /// Default maximum runtime in seconds for agent job workers. + #[schemars(range(min = 1))] + pub job_max_runtime_seconds: Option, /// User-defined role declarations keyed by role name. /// @@ -1813,6 +1857,44 @@ impl Config { }) .transpose()? .unwrap_or_default(); + let agent_max_spawn_depth = cfg + .agents + .as_ref() + .and_then(|agents| agents.max_spawn_depth) + .or(DEFAULT_AGENT_MAX_SPAWN_DEPTH); + if agent_max_spawn_depth == Some(0) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "agents.max_spawn_depth must be at least 1", + )); + } + if let Some(max_spawn_depth) = agent_max_spawn_depth + && max_spawn_depth > i32::MAX as usize + { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "agents.max_spawn_depth must fit within a 32-bit signed integer", + )); + } + let agent_job_max_runtime_seconds = cfg + .agents + .as_ref() + .and_then(|agents| agents.job_max_runtime_seconds) + .or(DEFAULT_AGENT_JOB_MAX_RUNTIME_SECONDS); + if agent_job_max_runtime_seconds == Some(0) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "agents.job_max_runtime_seconds must be at least 1", + )); + } + if let Some(max_runtime_seconds) = agent_job_max_runtime_seconds + && max_runtime_seconds > i64::MAX as u64 + { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "agents.job_max_runtime_seconds must fit within a 64-bit signed integer", + )); + } let background_terminal_max_timeout = cfg .background_terminal_max_timeout .unwrap_or(DEFAULT_MAX_BACKGROUND_TERMINAL_TIMEOUT_MS) @@ -1937,6 +2019,12 @@ impl Config { p.push("log"); p }); + let sqlite_home = cfg + .sqlite_home + .as_ref() + .map(AbsolutePathBuf::to_path_buf) + .or_else(|| resolve_sqlite_home_env(&resolved_cwd)) + .unwrap_or_else(|| default_sqlite_home(&sandbox_policy, &codex_home)); // Ensure that every field of ConfigRequirements is applied to the final // Config. @@ -2053,7 +2141,10 @@ impl Config { agent_max_depth, agent_roles, memories: cfg.memories.unwrap_or_default().into(), + agent_max_spawn_depth, + agent_job_max_runtime_seconds, codex_home, + sqlite_home, log_dir, config_layer_stack, history, @@ -4387,7 +4478,9 @@ model = "gpt-5.1-codex" let cfg = ConfigToml { agents: Some(AgentsToml { max_threads: None, + max_spawn_depth: None, max_depth: None, + job_max_runtime_seconds: None, roles: BTreeMap::from([( "researcher".to_string(), AgentRoleToml { @@ -4661,7 +4754,10 @@ model_verbosity = "high" agent_max_depth: DEFAULT_AGENT_MAX_DEPTH, agent_roles: BTreeMap::new(), memories: MemoriesConfig::default(), + agent_max_spawn_depth: DEFAULT_AGENT_MAX_SPAWN_DEPTH, + agent_job_max_runtime_seconds: DEFAULT_AGENT_JOB_MAX_RUNTIME_SECONDS, codex_home: fixture.codex_home(), + sqlite_home: fixture.codex_home(), log_dir: fixture.codex_home().join("log"), config_layer_stack: Default::default(), startup_warnings: Vec::new(), @@ -4784,7 +4880,10 @@ model_verbosity = "high" agent_max_depth: DEFAULT_AGENT_MAX_DEPTH, agent_roles: BTreeMap::new(), memories: MemoriesConfig::default(), + agent_max_spawn_depth: DEFAULT_AGENT_MAX_SPAWN_DEPTH, + agent_job_max_runtime_seconds: DEFAULT_AGENT_JOB_MAX_RUNTIME_SECONDS, codex_home: fixture.codex_home(), + sqlite_home: fixture.codex_home(), log_dir: fixture.codex_home().join("log"), config_layer_stack: Default::default(), startup_warnings: Vec::new(), @@ -4905,7 +5004,10 @@ model_verbosity = "high" agent_max_depth: DEFAULT_AGENT_MAX_DEPTH, agent_roles: BTreeMap::new(), memories: MemoriesConfig::default(), + agent_max_spawn_depth: DEFAULT_AGENT_MAX_SPAWN_DEPTH, + agent_job_max_runtime_seconds: DEFAULT_AGENT_JOB_MAX_RUNTIME_SECONDS, codex_home: fixture.codex_home(), + sqlite_home: fixture.codex_home(), log_dir: fixture.codex_home().join("log"), config_layer_stack: Default::default(), startup_warnings: Vec::new(), @@ -5012,7 +5114,10 @@ model_verbosity = "high" agent_max_depth: DEFAULT_AGENT_MAX_DEPTH, agent_roles: BTreeMap::new(), memories: MemoriesConfig::default(), + agent_max_spawn_depth: DEFAULT_AGENT_MAX_SPAWN_DEPTH, + agent_job_max_runtime_seconds: DEFAULT_AGENT_JOB_MAX_RUNTIME_SECONDS, codex_home: fixture.codex_home(), + sqlite_home: fixture.codex_home(), log_dir: fixture.codex_home().join("log"), config_layer_stack: Default::default(), startup_warnings: Vec::new(), diff --git a/codex-rs/core/src/shell_snapshot.rs b/codex-rs/core/src/shell_snapshot.rs index 4e5ddd2e4..7dc122e93 100644 --- a/codex-rs/core/src/shell_snapshot.rs +++ b/codex-rs/core/src/shell_snapshot.rs @@ -123,6 +123,13 @@ impl ShellSnapshot { let path = codex_home .join(SNAPSHOT_DIR) .join(format!("{session_id}.{extension}")); + let nonce = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map(|duration| duration.as_nanos()) + .unwrap_or(0); + let temp_path = codex_home + .join(SNAPSHOT_DIR) + .join(format!("{session_id}.tmp-{nonce}")); // Clean the (unlikely) leaked snapshot files. let codex_home = codex_home.to_path_buf(); @@ -134,31 +141,42 @@ impl ShellSnapshot { }); // Make the new snapshot. - let path = match write_shell_snapshot(shell.shell_type.clone(), &path, session_cwd).await { - Ok(path) => { - tracing::info!("Shell snapshot successfully created: {}", path.display()); - path - } - Err(err) => { - tracing::warn!( - "Failed to create shell snapshot for {}: {err:?}", - shell.name() - ); - return Err("write_failed"); - } - }; + let temp_path = + match write_shell_snapshot(shell.shell_type.clone(), &temp_path, session_cwd).await { + Ok(path) => { + tracing::info!("Shell snapshot successfully created: {}", path.display()); + path + } + Err(err) => { + tracing::warn!( + "Failed to create shell snapshot for {}: {err:?}", + shell.name() + ); + return Err("write_failed"); + } + }; - let snapshot = Self { - path, + let temp_snapshot = Self { + path: temp_path.clone(), cwd: session_cwd.to_path_buf(), }; - if let Err(err) = validate_snapshot(shell, &snapshot.path, session_cwd).await { + if let Err(err) = validate_snapshot(shell, &temp_snapshot.path, session_cwd).await { tracing::error!("Shell snapshot validation failed: {err:?}"); + remove_snapshot_file(&temp_snapshot.path).await; return Err("validation_failed"); } - Ok(snapshot) + if let Err(err) = fs::rename(&temp_snapshot.path, &path).await { + tracing::warn!("Failed to finalize shell snapshot: {err:?}"); + remove_snapshot_file(&temp_snapshot.path).await; + return Err("write_failed"); + } + + Ok(Self { + path, + cwd: session_cwd.to_path_buf(), + }) } } diff --git a/codex-rs/core/src/state_db.rs b/codex-rs/core/src/state_db.rs index 0c5a14a51..7f5b62a12 100644 --- a/codex-rs/core/src/state_db.rs +++ b/codex-rs/core/src/state_db.rs @@ -37,7 +37,7 @@ pub(crate) async fn init_if_enabled( return None; } let runtime = match codex_state::StateRuntime::init( - config.codex_home.clone(), + config.sqlite_home.clone(), config.model_provider_id.clone(), otel.cloned(), ) @@ -47,7 +47,7 @@ pub(crate) async fn init_if_enabled( Err(err) => { warn!( "failed to initialize state runtime at {}: {err}", - config.codex_home.display() + config.sqlite_home.display() ); if let Some(otel) = otel { otel.counter("codex.db.init", 1, &[("status", "init_error")]); @@ -79,20 +79,20 @@ pub(crate) async fn init_if_enabled( /// Get the DB if the feature is enabled and the DB exists. pub async fn get_state_db(config: &Config, otel: Option<&OtelManager>) -> Option { - let state_path = codex_state::state_db_path(config.codex_home.as_path()); + let state_path = codex_state::state_db_path(config.sqlite_home.as_path()); if !config.features.enabled(Feature::Sqlite) || !tokio::fs::try_exists(&state_path).await.unwrap_or(false) { return None; } let runtime = codex_state::StateRuntime::init( - config.codex_home.clone(), + config.sqlite_home.clone(), config.model_provider_id.clone(), otel.cloned(), ) .await .ok()?; - require_backfill_complete(runtime, config.codex_home.as_path()).await + require_backfill_complete(runtime, config.sqlite_home.as_path()).await } /// Open the state runtime when the SQLite file exists, without feature gating. diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index e9edd7db4..081d06687 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -16,6 +16,12 @@ use tokio::sync::Mutex; pub type SharedTurnDiffTracker = Arc>; +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ToolCallSource { + Direct, + JsRepl, +} + #[derive(Clone)] pub struct ToolInvocation { pub session: Arc, @@ -24,6 +30,7 @@ pub struct ToolInvocation { pub call_id: String, pub tool_name: String, pub payload: ToolPayload, + pub source: ToolCallSource, } #[derive(Clone, Debug)] diff --git a/codex-rs/core/src/tools/handlers/agent_jobs.rs b/codex-rs/core/src/tools/handlers/agent_jobs.rs new file mode 100644 index 000000000..6b36be05d --- /dev/null +++ b/codex-rs/core/src/tools/handlers/agent_jobs.rs @@ -0,0 +1,1227 @@ +use crate::agent::exceeds_thread_spawn_depth_limit; +use crate::agent::max_thread_spawn_depth; +use crate::agent::next_thread_spawn_depth; +use crate::agent::status::is_final; +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::config::Config; +use crate::error::CodexErr; +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::multi_agents::build_agent_spawn_config; +use crate::tools::handlers::parse_arguments; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use async_trait::async_trait; +use codex_protocol::ThreadId; +use codex_protocol::models::FunctionCallOutputBody; +use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; +use codex_protocol::user_input::UserInput; +use serde::Deserialize; +use serde::Serialize; +use serde_json::Value; +use std::collections::HashMap; +use std::collections::HashSet; +use std::path::Path; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::time::Duration; +use tokio::time::Instant; +use uuid::Uuid; + +pub struct BatchJobHandler; + +const DEFAULT_AGENT_JOB_CONCURRENCY: usize = 16; +const MAX_AGENT_JOB_CONCURRENCY: usize = 64; +const STATUS_POLL_INTERVAL: Duration = Duration::from_millis(250); +const PROGRESS_EMIT_INTERVAL: Duration = Duration::from_secs(1); +const DEFAULT_AGENT_JOB_ITEM_TIMEOUT: Duration = Duration::from_secs(60 * 30); + +#[derive(Debug, Deserialize)] +struct SpawnAgentsOnCsvArgs { + csv_path: String, + instruction: String, + id_column: Option, + output_csv_path: Option, + output_schema: Option, + max_concurrency: Option, + max_workers: Option, + max_runtime_seconds: Option, +} + +#[derive(Debug, Deserialize)] +struct ReportAgentJobResultArgs { + job_id: String, + item_id: String, + result: Value, + stop: Option, +} + +#[derive(Debug, Serialize)] +struct SpawnAgentsOnCsvResult { + job_id: String, + status: String, + output_csv_path: String, + total_items: usize, + completed_items: usize, + failed_items: usize, + job_error: Option, + failed_item_errors: Option>, +} + +#[derive(Debug, Serialize)] +struct AgentJobFailureSummary { + item_id: String, + source_id: Option, + last_error: String, +} + +#[derive(Debug, Serialize)] +struct AgentJobProgressUpdate { + job_id: String, + total_items: usize, + pending_items: usize, + running_items: usize, + completed_items: usize, + failed_items: usize, + eta_seconds: Option, +} + +#[derive(Debug, Serialize)] +struct ReportAgentJobResultToolResult { + accepted: bool, +} + +#[derive(Debug, Clone)] +struct JobRunnerOptions { + max_concurrency: usize, + spawn_config: Config, +} + +#[derive(Debug, Clone)] +struct ActiveJobItem { + item_id: String, + started_at: Instant, +} + +struct JobProgressEmitter { + started_at: Instant, + last_emit_at: Instant, + last_processed: usize, + last_failed: usize, +} + +impl JobProgressEmitter { + fn new() -> Self { + let now = Instant::now(); + let last_emit_at = now.checked_sub(PROGRESS_EMIT_INTERVAL).unwrap_or(now); + Self { + started_at: now, + last_emit_at, + last_processed: 0, + last_failed: 0, + } + } + + async fn maybe_emit( + &mut self, + session: &Session, + turn: &TurnContext, + job_id: &str, + progress: &codex_state::AgentJobProgress, + force: bool, + ) -> anyhow::Result<()> { + let processed = progress.completed_items + progress.failed_items; + let should_emit = force + || processed != self.last_processed + || progress.failed_items != self.last_failed + || self.last_emit_at.elapsed() >= PROGRESS_EMIT_INTERVAL; + if !should_emit { + return Ok(()); + } + let elapsed = self.started_at.elapsed().as_secs_f64(); + let eta_seconds = if processed > 0 && elapsed > 0.0 { + let remaining = progress.total_items.saturating_sub(processed) as f64; + let rate = processed as f64 / elapsed; + if rate > 0.0 { + Some((remaining / rate).round() as u64) + } else { + None + } + } else { + None + }; + let update = AgentJobProgressUpdate { + job_id: job_id.to_string(), + total_items: progress.total_items, + pending_items: progress.pending_items, + running_items: progress.running_items, + completed_items: progress.completed_items, + failed_items: progress.failed_items, + eta_seconds, + }; + let payload = serde_json::to_string(&update)?; + session + .notify_background_event(turn, format!("agent_job_progress:{payload}")) + .await; + self.last_emit_at = Instant::now(); + self.last_processed = processed; + self.last_failed = progress.failed_items; + Ok(()) + } +} + +#[async_trait] +impl ToolHandler for BatchJobHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tool_name, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "agent jobs handler received unsupported payload".to_string(), + )); + } + }; + + match tool_name.as_str() { + "spawn_agents_on_csv" => spawn_agents_on_csv::handle(session, turn, arguments).await, + "report_agent_job_result" => report_agent_job_result::handle(session, arguments).await, + other => Err(FunctionCallError::RespondToModel(format!( + "unsupported agent job tool {other}" + ))), + } + } +} + +mod spawn_agents_on_csv { + use super::*; + + /// Create a new agent job from a CSV and run it to completion. + /// + /// Each CSV row becomes a job item. The instruction string is a template where `{column}` + /// placeholders are filled with values from that row. Results are reported by workers via + /// `report_agent_job_result`, then exported to CSV on completion. + pub async fn handle( + session: Arc, + turn: Arc, + arguments: String, + ) -> Result { + let args: SpawnAgentsOnCsvArgs = parse_arguments(arguments.as_str())?; + if args.instruction.trim().is_empty() { + return Err(FunctionCallError::RespondToModel( + "instruction must be non-empty".to_string(), + )); + } + + let db = required_state_db(&session)?; + let input_path = turn.resolve_path(Some(args.csv_path)); + let input_path_display = input_path.display().to_string(); + let csv_content = tokio::fs::read_to_string(&input_path) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to read csv input {input_path_display}: {err}" + )) + })?; + let (headers, rows) = parse_csv(csv_content.as_str()).map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to parse csv input: {err}")) + })?; + if headers.is_empty() { + return Err(FunctionCallError::RespondToModel( + "csv input must include a header row".to_string(), + )); + } + ensure_unique_headers(headers.as_slice())?; + + let id_column_index = args.id_column.as_ref().map_or(Ok(None), |column_name| { + headers + .iter() + .position(|header| header == column_name) + .map(Some) + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "id_column {column_name} was not found in csv headers" + )) + }) + })?; + + let mut items = Vec::with_capacity(rows.len()); + let mut seen_ids = HashSet::new(); + for (idx, row) in rows.into_iter().enumerate() { + if row.len() != headers.len() { + let row_index = idx + 2; + let row_len = row.len(); + let header_len = headers.len(); + return Err(FunctionCallError::RespondToModel(format!( + "csv row {row_index} has {row_len} fields but header has {header_len}" + ))); + } + + let source_id = id_column_index + .and_then(|index| row.get(index).cloned()) + .filter(|value| !value.trim().is_empty()); + let row_index = idx + 1; + let base_item_id = source_id + .clone() + .unwrap_or_else(|| format!("row-{row_index}")); + let mut item_id = base_item_id.clone(); + let mut suffix = 2usize; + while !seen_ids.insert(item_id.clone()) { + item_id = format!("{base_item_id}-{suffix}"); + suffix = suffix.saturating_add(1); + } + + let row_object = headers + .iter() + .zip(row.iter()) + .map(|(header, value)| (header.clone(), Value::String(value.clone()))) + .collect::>(); + items.push(codex_state::AgentJobItemCreateParams { + item_id, + row_index: idx as i64, + source_id, + row_json: Value::Object(row_object), + }); + } + + let job_id = Uuid::new_v4().to_string(); + let output_csv_path = args.output_csv_path.map_or_else( + || default_output_csv_path(input_path.as_path(), job_id.as_str()), + |path| turn.resolve_path(Some(path)), + ); + let job_suffix = &job_id[..8]; + let job_name = format!("agent-job-{job_suffix}"); + let max_runtime_seconds = normalize_max_runtime_seconds( + args.max_runtime_seconds + .or(turn.config.agent_job_max_runtime_seconds), + )?; + let _job = db + .create_agent_job( + &codex_state::AgentJobCreateParams { + id: job_id.clone(), + name: job_name, + instruction: args.instruction, + auto_export: true, + max_runtime_seconds, + output_schema_json: args.output_schema, + input_headers: headers, + input_csv_path: input_path.display().to_string(), + output_csv_path: output_csv_path.display().to_string(), + }, + items.as_slice(), + ) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to create agent job: {err}")) + })?; + + let requested_concurrency = args.max_concurrency.or(args.max_workers); + let options = match build_runner_options(&session, &turn, requested_concurrency).await { + Ok(options) => options, + Err(err) => { + let error_message = err.to_string(); + let _ = db + .mark_agent_job_failed(job_id.as_str(), error_message.as_str()) + .await; + return Err(err); + } + }; + db.mark_agent_job_running(job_id.as_str()) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to transition agent job {job_id} to running: {err}" + )) + })?; + let max_threads = turn.config.agent_max_threads; + let effective_concurrency = options.max_concurrency; + let message = format!( + "agent job concurrency: job_id={job_id} requested={requested_concurrency:?} max_threads={max_threads:?} effective={effective_concurrency}" + ); + let _ = session.notify_background_event(&turn, message).await; + if let Err(err) = run_agent_job_loop( + session.clone(), + turn.clone(), + db.clone(), + job_id.clone(), + options, + ) + .await + { + let error_message = format!("job runner failed: {err}"); + let _ = db + .mark_agent_job_failed(job_id.as_str(), error_message.as_str()) + .await; + return Err(FunctionCallError::RespondToModel(format!( + "agent job {job_id} failed: {err}" + ))); + } + + let job = db + .get_agent_job(job_id.as_str()) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to load agent job {job_id}: {err}" + )) + })? + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!("agent job {job_id} not found")) + })?; + let output_path = PathBuf::from(job.output_csv_path.clone()); + if !tokio::fs::try_exists(&output_path).await.unwrap_or(false) { + export_job_csv_snapshot(db.clone(), &job) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to export output csv {job_id}: {err}" + )) + })?; + } + let progress = db + .get_agent_job_progress(job_id.as_str()) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to load agent job progress {job_id}: {err}" + )) + })?; + let mut job_error = job.last_error.clone().filter(|err| !err.trim().is_empty()); + let failed_item_errors = if progress.failed_items > 0 { + let items = db + .list_agent_job_items( + job_id.as_str(), + Some(codex_state::AgentJobItemStatus::Failed), + Some(5), + ) + .await + .unwrap_or_default(); + let summaries: Vec<_> = items + .into_iter() + .filter_map(|item| { + let last_error = item.last_error.unwrap_or_default(); + if last_error.trim().is_empty() { + return None; + } + Some(AgentJobFailureSummary { + item_id: item.item_id, + source_id: item.source_id, + last_error, + }) + }) + .collect(); + if summaries.is_empty() { + if job_error.is_none() { + job_error = Some( + "agent job has failed items but no error details were recorded".to_string(), + ); + } + None + } else { + Some(summaries) + } + } else { + None + }; + let content = serde_json::to_string(&SpawnAgentsOnCsvResult { + job_id, + status: job.status.as_str().to_string(), + output_csv_path: job.output_csv_path, + total_items: progress.total_items, + completed_items: progress.completed_items, + failed_items: progress.failed_items, + job_error, + failed_item_errors, + }) + .map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to serialize spawn_agents_on_csv result: {err}" + )) + })?; + Ok(ToolOutput::Function { + body: FunctionCallOutputBody::Text(content), + success: Some(true), + }) + } +} + +mod report_agent_job_result { + use super::*; + + pub async fn handle( + session: Arc, + arguments: String, + ) -> Result { + let args: ReportAgentJobResultArgs = parse_arguments(arguments.as_str())?; + if !args.result.is_object() { + return Err(FunctionCallError::RespondToModel( + "result must be a JSON object".to_string(), + )); + } + let db = required_state_db(&session)?; + let reporting_thread_id = session.conversation_id.to_string(); + let accepted = db + .report_agent_job_item_result( + args.job_id.as_str(), + args.item_id.as_str(), + reporting_thread_id.as_str(), + &args.result, + ) + .await + .map_err(|err| { + let job_id = args.job_id.as_str(); + let item_id = args.item_id.as_str(); + FunctionCallError::RespondToModel(format!( + "failed to record agent job result for {job_id} / {item_id}: {err}" + )) + })?; + if accepted && args.stop.unwrap_or(false) { + let message = "cancelled by worker request"; + let _ = db + .mark_agent_job_cancelled(args.job_id.as_str(), message) + .await; + } + let content = + serde_json::to_string(&ReportAgentJobResultToolResult { accepted }).map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to serialize report_agent_job_result result: {err}" + )) + })?; + Ok(ToolOutput::Function { + body: FunctionCallOutputBody::Text(content), + success: Some(true), + }) + } +} + +fn required_state_db( + session: &Arc, +) -> Result, FunctionCallError> { + session.state_db().ok_or_else(|| { + FunctionCallError::Fatal( + "sqlite state db is unavailable for this session; enable the sqlite feature" + .to_string(), + ) + }) +} + +async fn build_runner_options( + session: &Arc, + turn: &Arc, + requested_concurrency: Option, +) -> Result { + let session_source = turn.session_source.clone(); + let child_depth = next_thread_spawn_depth(&session_source); + let max_depth = max_thread_spawn_depth(turn.config.agent_max_spawn_depth); + if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { + return Err(FunctionCallError::RespondToModel( + "agent depth limit reached; this session cannot spawn more subagents".to_string(), + )); + } + let max_concurrency = + normalize_concurrency(requested_concurrency, turn.config.agent_max_threads); + let base_instructions = session.get_base_instructions().await; + let spawn_config = build_agent_spawn_config(&base_instructions, turn.as_ref(), child_depth)?; + Ok(JobRunnerOptions { + max_concurrency, + spawn_config, + }) +} + +fn normalize_concurrency(requested: Option, max_threads: Option) -> usize { + let requested = requested.unwrap_or(DEFAULT_AGENT_JOB_CONCURRENCY).max(1); + let requested = requested.min(MAX_AGENT_JOB_CONCURRENCY); + if let Some(max_threads) = max_threads { + requested.min(max_threads.max(1)) + } else { + requested + } +} + +fn normalize_max_runtime_seconds(requested: Option) -> Result, FunctionCallError> { + let Some(requested) = requested else { + return Ok(None); + }; + if requested == 0 { + return Err(FunctionCallError::RespondToModel( + "max_runtime_seconds must be >= 1".to_string(), + )); + } + Ok(Some(requested)) +} + +async fn run_agent_job_loop( + session: Arc, + turn: Arc, + db: Arc, + job_id: String, + options: JobRunnerOptions, +) -> anyhow::Result<()> { + let job = db + .get_agent_job(job_id.as_str()) + .await? + .ok_or_else(|| anyhow::anyhow!("agent job {job_id} was not found"))?; + let runtime_timeout = job_runtime_timeout(&job); + let mut active_items: HashMap = HashMap::new(); + let mut progress_emitter = JobProgressEmitter::new(); + recover_running_items( + session.clone(), + db.clone(), + job_id.as_str(), + &mut active_items, + runtime_timeout, + ) + .await?; + let initial_progress = db.get_agent_job_progress(job_id.as_str()).await?; + progress_emitter + .maybe_emit(&session, &turn, job_id.as_str(), &initial_progress, true) + .await?; + + let mut cancel_requested = db.is_agent_job_cancelled(job_id.as_str()).await?; + loop { + let mut progressed = false; + + if !cancel_requested && db.is_agent_job_cancelled(job_id.as_str()).await? { + cancel_requested = true; + let _ = session + .notify_background_event( + &turn, + format!("agent job {job_id} cancellation requested; stopping new workers"), + ) + .await; + } + + if !cancel_requested && active_items.len() < options.max_concurrency { + let slots = options.max_concurrency - active_items.len(); + let pending_items = db + .list_agent_job_items( + job_id.as_str(), + Some(codex_state::AgentJobItemStatus::Pending), + Some(slots), + ) + .await?; + for item in pending_items { + let prompt = build_worker_prompt(&job, &item)?; + let items = vec![UserInput::Text { + text: prompt, + text_elements: Vec::new(), + }]; + let thread_id = match session + .services + .agent_control + .spawn_agent( + options.spawn_config.clone(), + items, + Some(SessionSource::SubAgent(SubAgentSource::Other(format!( + "agent_job:{job_id}" + )))), + ) + .await + { + Ok(thread_id) => thread_id, + Err(CodexErr::AgentLimitReached { .. }) => { + db.mark_agent_job_item_pending( + job_id.as_str(), + item.item_id.as_str(), + None, + ) + .await?; + break; + } + Err(err) => { + let error_message = format!("failed to spawn worker: {err}"); + db.mark_agent_job_item_failed( + job_id.as_str(), + item.item_id.as_str(), + error_message.as_str(), + ) + .await?; + progressed = true; + continue; + } + }; + let assigned = db + .mark_agent_job_item_running_with_thread( + job_id.as_str(), + item.item_id.as_str(), + thread_id.to_string().as_str(), + ) + .await?; + if !assigned { + let _ = session + .services + .agent_control + .shutdown_agent(thread_id) + .await; + continue; + } + active_items.insert( + thread_id, + ActiveJobItem { + item_id: item.item_id.clone(), + started_at: Instant::now(), + }, + ); + progressed = true; + } + } + + if reap_stale_active_items( + session.clone(), + db.clone(), + job_id.as_str(), + &mut active_items, + runtime_timeout, + ) + .await? + { + progressed = true; + } + + let finished = find_finished_threads(session.clone(), &active_items).await; + if finished.is_empty() { + let progress = db.get_agent_job_progress(job_id.as_str()).await?; + if cancel_requested { + if progress.running_items == 0 && active_items.is_empty() { + break; + } + } else if progress.pending_items == 0 + && progress.running_items == 0 + && active_items.is_empty() + { + break; + } + if !progressed { + tokio::time::sleep(STATUS_POLL_INTERVAL).await; + } + continue; + } + + for (thread_id, item_id) in finished { + finalize_finished_item( + session.clone(), + db.clone(), + job_id.as_str(), + item_id.as_str(), + thread_id, + ) + .await?; + active_items.remove(&thread_id); + let progress = db.get_agent_job_progress(job_id.as_str()).await?; + progress_emitter + .maybe_emit(&session, &turn, job_id.as_str(), &progress, false) + .await?; + } + } + + let progress = db.get_agent_job_progress(job_id.as_str()).await?; + if let Err(err) = export_job_csv_snapshot(db.clone(), &job).await { + let message = format!("auto-export failed: {err}"); + db.mark_agent_job_failed(job_id.as_str(), message.as_str()) + .await?; + return Ok(()); + } + let cancelled = cancel_requested || db.is_agent_job_cancelled(job_id.as_str()).await?; + if cancelled { + let pending_items = progress.pending_items; + let message = + format!("agent job {job_id} cancelled with {pending_items} unprocessed items"); + let _ = session.notify_background_event(&turn, message).await; + progress_emitter + .maybe_emit(&session, &turn, job_id.as_str(), &progress, true) + .await?; + return Ok(()); + } + if progress.failed_items > 0 { + let failed_items = progress.failed_items; + let message = format!("agent job completed with {failed_items} failed items"); + let _ = session.notify_background_event(&turn, message).await; + } + db.mark_agent_job_completed(job_id.as_str()).await?; + let progress = db.get_agent_job_progress(job_id.as_str()).await?; + progress_emitter + .maybe_emit(&session, &turn, job_id.as_str(), &progress, true) + .await?; + Ok(()) +} + +async fn export_job_csv_snapshot( + db: Arc, + job: &codex_state::AgentJob, +) -> anyhow::Result<()> { + let items = db.list_agent_job_items(job.id.as_str(), None, None).await?; + let csv_content = render_job_csv(job.input_headers.as_slice(), items.as_slice()) + .map_err(|err| anyhow::anyhow!("failed to render job csv for auto-export: {err}"))?; + let output_path = PathBuf::from(job.output_csv_path.clone()); + if let Some(parent) = output_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + tokio::fs::write(&output_path, csv_content).await?; + Ok(()) +} + +async fn recover_running_items( + session: Arc, + db: Arc, + job_id: &str, + active_items: &mut HashMap, + runtime_timeout: Duration, +) -> anyhow::Result<()> { + let running_items = db + .list_agent_job_items(job_id, Some(codex_state::AgentJobItemStatus::Running), None) + .await?; + for item in running_items { + if is_item_stale(&item, runtime_timeout) { + let error_message = format!("worker exceeded max runtime of {runtime_timeout:?}"); + db.mark_agent_job_item_failed(job_id, item.item_id.as_str(), error_message.as_str()) + .await?; + if let Some(assigned_thread_id) = item.assigned_thread_id.as_ref() + && let Ok(thread_id) = ThreadId::from_string(assigned_thread_id.as_str()) + { + let _ = session + .services + .agent_control + .shutdown_agent(thread_id) + .await; + } + continue; + } + let Some(assigned_thread_id) = item.assigned_thread_id.clone() else { + db.mark_agent_job_item_failed( + job_id, + item.item_id.as_str(), + "running item is missing assigned_thread_id", + ) + .await?; + continue; + }; + let thread_id = match ThreadId::from_string(assigned_thread_id.as_str()) { + Ok(thread_id) => thread_id, + Err(err) => { + let error_message = format!("invalid assigned_thread_id: {err:?}"); + db.mark_agent_job_item_failed( + job_id, + item.item_id.as_str(), + error_message.as_str(), + ) + .await?; + continue; + } + }; + if is_final(&session.services.agent_control.get_status(thread_id).await) { + finalize_finished_item( + session.clone(), + db.clone(), + job_id, + item.item_id.as_str(), + thread_id, + ) + .await?; + } else { + active_items.insert( + thread_id, + ActiveJobItem { + item_id: item.item_id.clone(), + started_at: started_at_from_item(&item), + }, + ); + } + } + Ok(()) +} + +async fn find_finished_threads( + session: Arc, + active_items: &HashMap, +) -> Vec<(ThreadId, String)> { + let mut finished = Vec::new(); + for (thread_id, item) in active_items { + if is_final(&session.services.agent_control.get_status(*thread_id).await) { + finished.push((*thread_id, item.item_id.clone())); + } + } + finished +} + +async fn reap_stale_active_items( + session: Arc, + db: Arc, + job_id: &str, + active_items: &mut HashMap, + runtime_timeout: Duration, +) -> anyhow::Result { + let mut stale = Vec::new(); + for (thread_id, item) in active_items.iter() { + if item.started_at.elapsed() >= runtime_timeout { + stale.push((*thread_id, item.item_id.clone())); + } + } + if stale.is_empty() { + return Ok(false); + } + for (thread_id, item_id) in stale { + let error_message = format!("worker exceeded max runtime of {runtime_timeout:?}"); + db.mark_agent_job_item_failed(job_id, item_id.as_str(), error_message.as_str()) + .await?; + let _ = session + .services + .agent_control + .shutdown_agent(thread_id) + .await; + active_items.remove(&thread_id); + } + Ok(true) +} + +async fn finalize_finished_item( + session: Arc, + db: Arc, + job_id: &str, + item_id: &str, + thread_id: ThreadId, +) -> anyhow::Result<()> { + let mut item = db + .get_agent_job_item(job_id, item_id) + .await? + .ok_or_else(|| { + anyhow::anyhow!("job item not found for finalization: {job_id}/{item_id}") + })?; + if item.result_json.is_none() { + tokio::time::sleep(Duration::from_millis(250)).await; + item = db + .get_agent_job_item(job_id, item_id) + .await? + .ok_or_else(|| { + anyhow::anyhow!("job item not found after grace period: {job_id}/{item_id}") + })?; + } + if item.result_json.is_some() { + if !db.mark_agent_job_item_completed(job_id, item_id).await? { + db.mark_agent_job_item_failed( + job_id, + item_id, + "worker reported result but item could not transition to completed", + ) + .await?; + } + } else { + db.mark_agent_job_item_failed( + job_id, + item_id, + "worker finished without calling report_agent_job_result", + ) + .await?; + } + let _ = session + .services + .agent_control + .shutdown_agent(thread_id) + .await; + Ok(()) +} + +fn build_worker_prompt( + job: &codex_state::AgentJob, + item: &codex_state::AgentJobItem, +) -> anyhow::Result { + let job_id = job.id.as_str(); + let item_id = item.item_id.as_str(); + let instruction = render_instruction_template(job.instruction.as_str(), &item.row_json); + let output_schema = job + .output_schema_json + .as_ref() + .map(serde_json::to_string_pretty) + .transpose()? + .unwrap_or_else(|| "{}".to_string()); + let row_json = serde_json::to_string_pretty(&item.row_json)?; + Ok(format!( + "You are processing one item for a generic agent job.\n\ +Job ID: {job_id}\n\ +Item ID: {item_id}\n\n\ +Task instruction:\n\ +{instruction}\n\n\ +Input row (JSON):\n\ +{row_json}\n\n\ +Expected result schema (JSON Schema or {{}}):\n\ +{output_schema}\n\n\ +You MUST call the `report_agent_job_result` tool exactly once with:\n\ +1. `job_id` = \"{job_id}\"\n\ +2. `item_id` = \"{item_id}\"\n\ +3. `result` = a JSON object that contains your analysis result for this row.\n\n\ +If you need to stop the job early, include `stop` = true in the tool call.\n\n\ +After the tool call succeeds, stop.", + )) +} + +fn render_instruction_template(instruction: &str, row_json: &Value) -> String { + const OPEN_BRACE_SENTINEL: &str = "__CODEX_OPEN_BRACE__"; + const CLOSE_BRACE_SENTINEL: &str = "__CODEX_CLOSE_BRACE__"; + + let mut rendered = instruction + .replace("{{", OPEN_BRACE_SENTINEL) + .replace("}}", CLOSE_BRACE_SENTINEL); + let Some(row) = row_json.as_object() else { + return rendered + .replace(OPEN_BRACE_SENTINEL, "{") + .replace(CLOSE_BRACE_SENTINEL, "}"); + }; + for (key, value) in row { + let placeholder = format!("{{{key}}}"); + let replacement = value + .as_str() + .map(str::to_string) + .unwrap_or_else(|| value.to_string()); + rendered = rendered.replace(placeholder.as_str(), replacement.as_str()); + } + rendered + .replace(OPEN_BRACE_SENTINEL, "{") + .replace(CLOSE_BRACE_SENTINEL, "}") +} + +fn ensure_unique_headers(headers: &[String]) -> Result<(), FunctionCallError> { + let mut seen = HashSet::new(); + for header in headers { + if !seen.insert(header) { + return Err(FunctionCallError::RespondToModel(format!( + "csv header {header} is duplicated" + ))); + } + } + Ok(()) +} + +fn job_runtime_timeout(job: &codex_state::AgentJob) -> Duration { + job.max_runtime_seconds + .map(Duration::from_secs) + .unwrap_or(DEFAULT_AGENT_JOB_ITEM_TIMEOUT) +} + +fn started_at_from_item(item: &codex_state::AgentJobItem) -> Instant { + let now = chrono::Utc::now(); + let age = now.signed_duration_since(item.updated_at); + if let Ok(age) = age.to_std() { + Instant::now().checked_sub(age).unwrap_or_else(Instant::now) + } else { + Instant::now() + } +} + +fn is_item_stale(item: &codex_state::AgentJobItem, runtime_timeout: Duration) -> bool { + let now = chrono::Utc::now(); + if let Ok(age) = now.signed_duration_since(item.updated_at).to_std() { + age >= runtime_timeout + } else { + false + } +} + +fn default_output_csv_path(input_csv_path: &Path, job_id: &str) -> PathBuf { + let stem = input_csv_path + .file_stem() + .and_then(|stem| stem.to_str()) + .unwrap_or("agent_job_output"); + let job_suffix = &job_id[..8]; + input_csv_path.with_file_name(format!("{stem}.agent-job-{job_suffix}.csv")) +} + +fn parse_csv(content: &str) -> Result<(Vec, Vec>), String> { + let mut reader = csv::ReaderBuilder::new() + .has_headers(true) + .flexible(true) + .from_reader(content.as_bytes()); + let headers_record = reader.headers().map_err(|err| err.to_string())?; + let mut headers: Vec = headers_record.iter().map(str::to_string).collect(); + if let Some(first) = headers.first_mut() { + *first = first.trim_start_matches('\u{feff}').to_string(); + } + let mut rows = Vec::new(); + for record in reader.records() { + let record = record.map_err(|err| err.to_string())?; + let row: Vec = record.iter().map(str::to_string).collect(); + if row.iter().all(std::string::String::is_empty) { + continue; + } + rows.push(row); + } + Ok((headers, rows)) +} + +fn render_job_csv( + headers: &[String], + items: &[codex_state::AgentJobItem], +) -> Result { + let mut csv = String::new(); + let mut output_headers = headers.to_vec(); + output_headers.extend([ + "job_id".to_string(), + "item_id".to_string(), + "row_index".to_string(), + "source_id".to_string(), + "status".to_string(), + "attempt_count".to_string(), + "last_error".to_string(), + "result_json".to_string(), + "reported_at".to_string(), + "completed_at".to_string(), + ]); + csv.push_str( + output_headers + .iter() + .map(|header| csv_escape(header.as_str())) + .collect::>() + .join(",") + .as_str(), + ); + csv.push('\n'); + for item in items { + let row_object = item.row_json.as_object().ok_or_else(|| { + let item_id = item.item_id.as_str(); + FunctionCallError::RespondToModel(format!( + "row_json for item {item_id} is not a JSON object" + )) + })?; + let mut row_values = Vec::new(); + for header in headers { + let value = row_object + .get(header) + .map_or_else(String::new, value_to_csv_string); + row_values.push(csv_escape(value.as_str())); + } + row_values.push(csv_escape(item.job_id.as_str())); + row_values.push(csv_escape(item.item_id.as_str())); + row_values.push(csv_escape(item.row_index.to_string().as_str())); + row_values.push(csv_escape( + item.source_id.clone().unwrap_or_default().as_str(), + )); + row_values.push(csv_escape(item.status.as_str())); + row_values.push(csv_escape(item.attempt_count.to_string().as_str())); + row_values.push(csv_escape( + item.last_error.clone().unwrap_or_default().as_str(), + )); + row_values.push(csv_escape( + item.result_json + .as_ref() + .map_or_else(String::new, std::string::ToString::to_string) + .as_str(), + )); + row_values.push(csv_escape( + item.reported_at + .map(|value| value.to_rfc3339()) + .unwrap_or_default() + .as_str(), + )); + row_values.push(csv_escape( + item.completed_at + .map(|value| value.to_rfc3339()) + .unwrap_or_default() + .as_str(), + )); + csv.push_str(row_values.join(",").as_str()); + csv.push('\n'); + } + Ok(csv) +} + +fn value_to_csv_string(value: &Value) -> String { + match value { + Value::Null => String::new(), + Value::String(s) => s.clone(), + Value::Bool(b) => b.to_string(), + Value::Number(n) => n.to_string(), + Value::Array(_) | Value::Object(_) => value.to_string(), + } +} + +fn csv_escape(value: &str) -> String { + if value.contains(',') || value.contains('\n') || value.contains('\r') || value.contains('"') { + let escaped = value.replace('"', "\"\""); + format!("\"{escaped}\"") + } else { + value.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use serde_json::json; + + #[test] + fn parse_csv_supports_quotes_and_commas() { + let input = "id,name\n1,\"alpha, beta\"\n2,gamma\n"; + let (headers, rows) = parse_csv(input).expect("csv parse"); + assert_eq!(headers, vec!["id".to_string(), "name".to_string()]); + assert_eq!( + rows, + vec![ + vec!["1".to_string(), "alpha, beta".to_string()], + vec!["2".to_string(), "gamma".to_string()] + ] + ); + } + + #[test] + fn csv_escape_quotes_when_needed() { + assert_eq!(csv_escape("simple"), "simple"); + assert_eq!(csv_escape("a,b"), "\"a,b\""); + assert_eq!(csv_escape("a\"b"), "\"a\"\"b\""); + } + + #[test] + fn render_instruction_template_expands_placeholders_and_escapes_braces() { + let row = json!({ + "path": "src/lib.rs", + "area": "test", + "file path": "docs/readme.md", + }); + let rendered = render_instruction_template( + "Review {path} in {area}. Also see {file path}. Use {{literal}}.", + &row, + ); + assert_eq!( + rendered, + "Review src/lib.rs in test. Also see docs/readme.md. Use {literal}." + ); + } + + #[test] + fn render_instruction_template_leaves_unknown_placeholders() { + let row = json!({ + "path": "src/lib.rs", + }); + let rendered = render_instruction_template("Check {path} then {missing}", &row); + assert_eq!(rendered, "Check src/lib.rs then {missing}"); + } + + #[test] + fn ensure_unique_headers_rejects_duplicates() { + let headers = vec!["path".to_string(), "path".to_string()]; + let Err(err) = ensure_unique_headers(headers.as_slice()) else { + panic!("expected duplicate header error"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("csv header path is duplicated".to_string()) + ); + } +} diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index ae93327b8..810c5cb1d 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -86,6 +86,7 @@ impl ToolHandler for ApplyPatchHandler { call_id, tool_name, payload, + .. } = invocation; let patch_input = match payload { diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index ba52a2d01..bdc47a088 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod agent_jobs; pub mod apply_patch; mod dynamic; mod grep_files; diff --git a/codex-rs/core/src/tools/handlers/multi_agents.rs b/codex-rs/core/src/tools/handlers/multi_agents.rs index 090810d9f..8fbbd0d11 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents.rs @@ -1,5 +1,6 @@ use crate::agent::AgentStatus; use crate::agent::exceeds_thread_spawn_depth_limit; +use crate::agent::max_thread_spawn_depth; use crate::codex::Session; use crate::codex::TurnContext; use crate::config::Config; @@ -97,6 +98,7 @@ mod spawn { use crate::agent::role::apply_role_to_config; use crate::agent::exceeds_thread_spawn_depth_limit; + use crate::agent::max_thread_spawn_depth; use crate::agent::next_thread_spawn_depth; use std::sync::Arc; @@ -129,7 +131,8 @@ mod spawn { let prompt = input_preview(&input_items); let session_source = turn.session_source.clone(); let child_depth = next_thread_spawn_depth(&session_source); - if exceeds_thread_spawn_depth_limit(child_depth, turn.config.agent_max_depth) { + let max_depth = max_thread_spawn_depth(turn.config.agent_max_spawn_depth); + if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { return Err(FunctionCallError::RespondToModel( "Agent depth limit reached. Solve the task yourself.".to_string(), )); @@ -345,7 +348,8 @@ mod resume_agent { .await .unwrap_or((None, None)); let child_depth = next_thread_spawn_depth(&turn.session_source); - if exceeds_thread_spawn_depth_limit(child_depth, turn.config.agent_max_depth) { + let max_depth = max_thread_spawn_depth(turn.config.agent_max_spawn_depth); + if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { return Err(FunctionCallError::RespondToModel( "Agent depth limit reached. Solve the task yourself.".to_string(), )); @@ -891,7 +895,7 @@ fn input_preview(items: &[UserInput]) -> String { parts.join("\n") } -fn build_agent_spawn_config( +pub(crate) fn build_agent_spawn_config( base_instructions: &BaseInstructions, turn: &TurnContext, child_depth: i32, @@ -948,7 +952,8 @@ fn apply_spawn_agent_runtime_overrides( fn apply_spawn_agent_overrides(config: &mut Config, child_depth: i32) { config.permissions.approval_policy = Constrained::allow_only(AskForApproval::Never); - if exceeds_thread_spawn_depth_limit(child_depth + 1, config.agent_max_depth) { + let max_depth = max_thread_spawn_depth(config.agent_max_spawn_depth); + if exceeds_thread_spawn_depth_limit(child_depth + 1, max_depth) { config.features.disable(Feature::Collab); } } @@ -959,6 +964,7 @@ mod tests { use crate::AuthManager; use crate::CodexAuth; use crate::ThreadManager; + use crate::agent::max_thread_spawn_depth; use crate::built_in_model_providers; use crate::codex::make_session_and_context; use crate::config::DEFAULT_AGENT_MAX_DEPTH; @@ -998,6 +1004,7 @@ mod tests { call_id: "call-1".to_string(), tool_name: tool_name.to_string(), payload, + source: crate::tools::router::ToolCallSource::Direct, } } @@ -1259,9 +1266,10 @@ mod tests { let manager = thread_manager(); session.services.agent_control = manager.agent_control(); + let max_depth = max_thread_spawn_depth(turn.config.agent_max_spawn_depth); turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { parent_thread_id: session.conversation_id, - depth: DEFAULT_AGENT_MAX_DEPTH, + depth: max_depth, agent_nickname: None, agent_role: None, }); @@ -1689,9 +1697,10 @@ mod tests { let manager = thread_manager(); session.services.agent_control = manager.agent_control(); + let max_depth = max_thread_spawn_depth(turn.config.agent_max_spawn_depth); turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { parent_thread_id: session.conversation_id, - depth: DEFAULT_AGENT_MAX_DEPTH, + depth: max_depth, agent_nickname: None, agent_role: None, }); diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index 7a78ecb1c..0a0bb5347 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -176,6 +176,7 @@ impl ToolHandler for ShellHandler { call_id, tool_name, payload, + .. } = invocation; match payload { @@ -261,6 +262,7 @@ impl ToolHandler for ShellCommandHandler { call_id, tool_name, payload, + .. } = invocation; let ToolPayload::Function { arguments } = payload else { diff --git a/codex-rs/core/src/tools/handlers/view_image.rs b/codex-rs/core/src/tools/handlers/view_image.rs index 6337cef3e..ea05f736e 100644 --- a/codex-rs/core/src/tools/handlers/view_image.rs +++ b/codex-rs/core/src/tools/handlers/view_image.rs @@ -8,6 +8,7 @@ use tokio::fs; use crate::function_tool::FunctionCallError; use crate::protocol::EventMsg; use crate::protocol::ViewImageToolCallEvent; +use crate::tools::context::ToolCallSource; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; @@ -15,6 +16,7 @@ use crate::tools::handlers::parse_arguments; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseInputItem; use codex_protocol::models::local_image_content_items_with_label_number; pub struct ViewImageHandler; @@ -50,6 +52,7 @@ impl ToolHandler for ViewImageHandler { turn, payload, call_id, + source, .. } = invocation; @@ -81,7 +84,26 @@ impl ToolHandler for ViewImageHandler { } let event_path = abs_path.clone(); - let content = local_image_content_items_with_label_number(&abs_path, None) + let content = local_image_content_items_with_label_number(&abs_path, None); + if source == ToolCallSource::JsRepl + && content + .iter() + .any(|item| matches!(item, ContentItem::InputImage { .. })) + { + let input_item = ResponseInputItem::Message { + role: "user".to_string(), + content: content.clone(), + }; + if session + .inject_response_items(vec![input_item]) + .await + .is_err() + { + tracing::warn!("view_image could not find an active turn to attach image input"); + } + } + + let content = content .into_iter() .map(|item| match item { ContentItem::InputText { text } => { diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index fe0827f4f..d95a2c49a 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -22,6 +22,8 @@ use std::collections::HashMap; use std::sync::Arc; use tracing::instrument; +pub use crate::tools::context::ToolCallSource; + #[derive(Clone, Debug)] pub struct ToolCall { pub tool_name: String, @@ -29,12 +31,6 @@ pub struct ToolCall { pub payload: ToolPayload, } -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum ToolCallSource { - Direct, - JsRepl, -} - pub struct ToolRouter { registry: ToolRegistry, specs: Vec, @@ -179,6 +175,7 @@ impl ToolRouter { call_id, tool_name, payload, + source, }; match self.registry.dispatch(invocation).await { diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index ea9ccb3ea..3c9ab4bf3 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -9,6 +9,7 @@ use crate::mcp_connection_manager::ToolInfo; use crate::tools::handlers::PLAN_TOOL; use crate::tools::handlers::SEARCH_TOOL_BM25_DEFAULT_LIMIT; use crate::tools::handlers::SEARCH_TOOL_BM25_TOOL_NAME; +use crate::tools::handlers::agent_jobs::BatchJobHandler; use crate::tools::handlers::apply_patch::create_apply_patch_freeform_tool; use crate::tools::handlers::apply_patch::create_apply_patch_json_tool; use crate::tools::handlers::multi_agents::DEFAULT_WAIT_TIMEOUT_MS; @@ -22,6 +23,8 @@ use codex_protocol::models::VIEW_IMAGE_TOOL_NAME; use codex_protocol::openai_models::ApplyPatchToolType; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::ModelInfo; +use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; use serde::Deserialize; use serde::Serialize; use serde_json::Value as JsonValue; @@ -53,12 +56,15 @@ pub(crate) struct ToolsConfig { pub collab_tools: bool, pub collaboration_modes_tools: bool, pub experimental_supported_tools: Vec, + pub agent_jobs_tools: bool, + pub agent_jobs_worker_tools: bool, } pub(crate) struct ToolsConfigParams<'a> { pub(crate) model_info: &'a ModelInfo, pub(crate) features: &'a Features, pub(crate) web_search_mode: Option, + pub(crate) session_source: SessionSource, } impl ToolsConfig { @@ -67,14 +73,16 @@ impl ToolsConfig { model_info, features, web_search_mode, + session_source, } = params; let include_apply_patch_tool = features.enabled(Feature::ApplyPatchFreeform); let include_js_repl = features.enabled(Feature::JsRepl); let include_js_repl_tools_only = include_js_repl && features.enabled(Feature::JsReplToolsOnly); let include_collab_tools = features.enabled(Feature::Collab); - let include_collaboration_modes_tools = true; + let include_collaboration_modes_tools = features.enabled(Feature::CollaborationModes); let include_search_tool = features.enabled(Feature::Apps); + let include_agent_jobs = include_collab_tools && features.enabled(Feature::Sqlite); let request_permission_enabled = features.enabled(Feature::RequestPermissions); let shell_command_backend = if features.enabled(Feature::ShellTool) && features.enabled(Feature::ShellZshFork) { @@ -110,6 +118,13 @@ impl ToolsConfig { } }; + let agent_jobs_worker_tools = include_agent_jobs + && matches!( + session_source, + SessionSource::SubAgent(SubAgentSource::Other(label)) + if label.starts_with("agent_job:") + ); + Self { shell_type, shell_command_backend, @@ -124,6 +139,8 @@ impl ToolsConfig { collab_tools: include_collab_tools, collaboration_modes_tools: include_collaboration_modes_tools, experimental_supported_tools: model_info.experimental_supported_tools.clone(), + agent_jobs_tools: include_agent_jobs, + agent_jobs_worker_tools, } } @@ -623,6 +640,131 @@ fn create_spawn_agent_tool(config: &ToolsConfig) -> ToolSpec { }) } +fn create_spawn_agents_on_csv_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "csv_path".to_string(), + JsonSchema::String { + description: Some("Path to the CSV file containing input rows.".to_string()), + }, + ); + properties.insert( + "instruction".to_string(), + JsonSchema::String { + description: Some( + "Instruction template to apply to each CSV row. Use {column_name} placeholders to inject values from the row." + .to_string(), + ), + }, + ); + properties.insert( + "id_column".to_string(), + JsonSchema::String { + description: Some("Optional column name to use as stable item id.".to_string()), + }, + ); + properties.insert( + "output_csv_path".to_string(), + JsonSchema::String { + description: Some("Optional output CSV path for exported results.".to_string()), + }, + ); + properties.insert( + "max_concurrency".to_string(), + JsonSchema::Number { + description: Some( + "Maximum concurrent workers for this job. Defaults to 16 and is capped by config." + .to_string(), + ), + }, + ); + properties.insert( + "max_workers".to_string(), + JsonSchema::Number { + description: Some( + "Alias for max_concurrency. Set to 1 to run sequentially.".to_string(), + ), + }, + ); + properties.insert( + "max_runtime_seconds".to_string(), + JsonSchema::Number { + description: Some( + "Maximum runtime per worker before it is failed. Defaults to 1800 seconds." + .to_string(), + ), + }, + ); + properties.insert( + "output_schema".to_string(), + JsonSchema::Object { + properties: BTreeMap::new(), + required: None, + additional_properties: None, + }, + ); + ToolSpec::Function(ResponsesApiTool { + name: "spawn_agents_on_csv".to_string(), + description: "Process a CSV by spawning one worker sub-agent per row. The instruction string is a template where `{column}` placeholders are replaced with row values. Each worker must call `report_agent_job_result` with a JSON object (matching `output_schema` when provided); missing reports are treated as failures. This call blocks until all rows finish and automatically exports results to `output_csv_path` (or a default path)." + .to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["csv_path".to_string(), "instruction".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + +fn create_report_agent_job_result_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "job_id".to_string(), + JsonSchema::String { + description: Some("Identifier of the job.".to_string()), + }, + ); + properties.insert( + "item_id".to_string(), + JsonSchema::String { + description: Some("Identifier of the job item.".to_string()), + }, + ); + properties.insert( + "result".to_string(), + JsonSchema::Object { + properties: BTreeMap::new(), + required: None, + additional_properties: None, + }, + ); + properties.insert( + "stop".to_string(), + JsonSchema::Boolean { + description: Some( + "Optional. When true, cancels the remaining job items after this result is recorded." + .to_string(), + ), + }, + ); + ToolSpec::Function(ResponsesApiTool { + name: "report_agent_job_result".to_string(), + description: + "Worker-only tool to report a result for an agent job item. Main agents should not call this." + .to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec![ + "job_id".to_string(), + "item_id".to_string(), + "result".to_string(), + ]), + additional_properties: Some(false.into()), + }, + }) +} + fn create_send_input_tool() -> ToolSpec { let properties = BTreeMap::from([ ( @@ -1670,6 +1812,16 @@ pub(crate) fn build_specs( builder.register_handler("close_agent", multi_agent_handler); } + if config.agent_jobs_tools { + let agent_jobs_handler = Arc::new(BatchJobHandler); + builder.push_spec(create_spawn_agents_on_csv_tool()); + builder.register_handler("spawn_agents_on_csv", agent_jobs_handler.clone()); + if config.agent_jobs_worker_tools { + builder.push_spec(create_report_agent_job_result_tool()); + builder.register_handler("report_agent_job_result", agent_jobs_handler); + } + } + if let Some(mcp_tools) = mcp_tools { let mut entries: Vec<(String, rmcp::model::Tool)> = mcp_tools.into_iter().collect(); entries.sort_by(|a, b| a.0.cmp(&b.0)); @@ -1870,6 +2022,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&config, None, None, &[]).build(); @@ -1928,10 +2081,42 @@ mod tests { let mut features = Features::with_defaults(); features.enable(Feature::Collab); features.enable(Feature::CollaborationModes); + features.enable(Feature::Sqlite); let tools_config = ToolsConfig::new(&ToolsConfigParams { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_contains_tool_names( + &tools, + &[ + "spawn_agent", + "send_input", + "wait", + "close_agent", + "spawn_agents_on_csv", + ], + ); + } + + #[test] + fn test_build_specs_agent_job_worker_tools_enabled() { + let config = test_config(); + let model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Collab); + features.enable(Feature::CollaborationModes); + features.enable(Feature::Sqlite); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::SubAgent(SubAgentSource::Other( + "agent_job:test".to_string(), + )), }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); assert_contains_tool_names( @@ -1942,10 +2127,62 @@ mod tests { "resume_agent", "wait", "close_agent", + "spawn_agents_on_csv", + "report_agent_job_result", ], ); } + #[test] + fn request_user_input_requires_collaboration_modes_feature() { + let config = test_config(); + let model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.disable(Feature::CollaborationModes); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert!( + !tools.iter().any(|t| t.spec.name() == "request_user_input"), + "request_user_input should be disabled when collaboration_modes feature is off" + ); + + features.enable(Feature::CollaborationModes); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_contains_tool_names(&tools, &["request_user_input"]); + } + + #[test] + fn get_memory_requires_feature_flag() { + let config = test_config(); + let model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.disable(Feature::MemoryTool); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert!( + !tools.iter().any(|t| t.spec.name() == "get_memory"), + "get_memory should be disabled when memory_tool feature is off" + ); + } + #[test] fn js_repl_requires_feature_flag() { let config = test_config(); @@ -1957,6 +2194,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); @@ -1982,6 +2220,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); assert_contains_tool_names(&tools, &["js_repl", "js_repl_reset"]); @@ -2013,6 +2252,7 @@ mod tests { model_info: &model_info, features, web_search_mode, + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); let tool_names = tools.iter().map(|t| t.spec.name()).collect::>(); @@ -2046,6 +2286,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); @@ -2069,6 +2310,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); @@ -2092,6 +2334,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); @@ -2115,6 +2358,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build(); @@ -2314,6 +2558,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build(); @@ -2337,6 +2582,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, }); assert_eq!(tools_config.shell_type, ConfigShellToolType::ShellCommand); @@ -2358,6 +2604,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); @@ -2382,6 +2629,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); @@ -2413,6 +2661,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs( &tools_config, @@ -2499,6 +2748,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); // Intentionally construct a map with keys that would sort alphabetically. @@ -2544,6 +2794,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs( @@ -2611,6 +2862,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs( @@ -2665,6 +2917,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs( @@ -2716,6 +2969,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs( @@ -2769,6 +3023,7 @@ mod tests { model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs( @@ -2901,6 +3156,7 @@ Examples of valid command strings: model_info: &model_info, features: &features, web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, }); let (tools, _) = build_specs( &tools_config, diff --git a/codex-rs/core/tests/common/test_codex.rs b/codex-rs/core/tests/common/test_codex.rs index 820590923..173db5ca7 100644 --- a/codex-rs/core/tests/common/test_codex.rs +++ b/codex-rs/core/tests/common/test_codex.rs @@ -215,6 +215,14 @@ impl TestCodexBuilder { } if let Ok(path) = codex_utils_cargo_bin::cargo_bin("codex") { config.codex_linux_sandbox_exe = Some(path); + } else if let Ok(exe) = std::env::current_exe() + && let Some(path) = exe + .parent() + .and_then(|parent| parent.parent()) + .map(|parent| parent.join("codex")) + && path.is_file() + { + config.codex_linux_sandbox_exe = Some(path); } let mut mutators = vec![]; diff --git a/codex-rs/core/tests/suite/agent_jobs.rs b/codex-rs/core/tests/suite/agent_jobs.rs new file mode 100644 index 000000000..5708f400a --- /dev/null +++ b/codex-rs/core/tests/suite/agent_jobs.rs @@ -0,0 +1,424 @@ +use anyhow::Result; +use codex_core::features::Feature; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::sse; +use core_test_support::responses::sse_response; +use core_test_support::responses::start_mock_server; +use core_test_support::test_codex::test_codex; +use regex_lite::Regex; +use serde_json::Value; +use serde_json::json; +use std::fs; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use wiremock::Mock; +use wiremock::Respond; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path_regex; + +struct AgentJobsResponder { + spawn_args_json: String, + seen_main: AtomicBool, + call_counter: AtomicUsize, +} + +impl AgentJobsResponder { + fn new(spawn_args_json: String) -> Self { + Self { + spawn_args_json, + seen_main: AtomicBool::new(false), + call_counter: AtomicUsize::new(0), + } + } +} + +struct StopAfterFirstResponder { + spawn_args_json: String, + seen_main: AtomicBool, + worker_calls: Arc, +} + +impl StopAfterFirstResponder { + fn new(spawn_args_json: String, worker_calls: Arc) -> Self { + Self { + spawn_args_json, + seen_main: AtomicBool::new(false), + worker_calls, + } + } +} + +impl Respond for StopAfterFirstResponder { + fn respond(&self, request: &wiremock::Request) -> ResponseTemplate { + let body_bytes = decode_body_bytes(request); + let body: Value = serde_json::from_slice(&body_bytes).unwrap_or(Value::Null); + + if has_function_call_output(&body) { + return sse_response(sse(vec![ + ev_response_created("resp-tool"), + ev_completed("resp-tool"), + ])); + } + + if let Some((job_id, item_id)) = extract_job_and_item(&body) { + let call_index = self.worker_calls.fetch_add(1, Ordering::SeqCst); + let call_id = format!("call-worker-{call_index}"); + let stop = call_index == 0; + let args = json!({ + "job_id": job_id, + "item_id": item_id, + "result": { "item_id": item_id }, + "stop": stop, + }); + let args_json = serde_json::to_string(&args).unwrap_or_else(|err| { + panic!("worker args serialize: {err}"); + }); + return sse_response(sse(vec![ + ev_response_created("resp-worker"), + ev_function_call(&call_id, "report_agent_job_result", &args_json), + ev_completed("resp-worker"), + ])); + } + + if !self.seen_main.swap(true, Ordering::SeqCst) { + return sse_response(sse(vec![ + ev_response_created("resp-main"), + ev_function_call("call-spawn", "spawn_agents_on_csv", &self.spawn_args_json), + ev_completed("resp-main"), + ])); + } + + sse_response(sse(vec![ + ev_response_created("resp-default"), + ev_completed("resp-default"), + ])) + } +} + +impl Respond for AgentJobsResponder { + fn respond(&self, request: &wiremock::Request) -> ResponseTemplate { + let body_bytes = decode_body_bytes(request); + let body: Value = serde_json::from_slice(&body_bytes).unwrap_or(Value::Null); + + if has_function_call_output(&body) { + return sse_response(sse(vec![ + ev_response_created("resp-tool"), + ev_completed("resp-tool"), + ])); + } + + if let Some((job_id, item_id)) = extract_job_and_item(&body) { + let call_id = format!( + "call-worker-{}", + self.call_counter.fetch_add(1, Ordering::SeqCst) + ); + let args = json!({ + "job_id": job_id, + "item_id": item_id, + "result": { "item_id": item_id } + }); + let args_json = serde_json::to_string(&args).unwrap_or_else(|err| { + panic!("worker args serialize: {err}"); + }); + return sse_response(sse(vec![ + ev_response_created("resp-worker"), + ev_function_call(&call_id, "report_agent_job_result", &args_json), + ev_completed("resp-worker"), + ])); + } + + if !self.seen_main.swap(true, Ordering::SeqCst) { + return sse_response(sse(vec![ + ev_response_created("resp-main"), + ev_function_call("call-spawn", "spawn_agents_on_csv", &self.spawn_args_json), + ev_completed("resp-main"), + ])); + } + + sse_response(sse(vec![ + ev_response_created("resp-default"), + ev_completed("resp-default"), + ])) + } +} + +fn decode_body_bytes(request: &wiremock::Request) -> Vec { + let Some(encoding) = request + .headers + .get("content-encoding") + .and_then(|value| value.to_str().ok()) + else { + return request.body.clone(); + }; + if encoding + .split(',') + .any(|entry| entry.trim().eq_ignore_ascii_case("zstd")) + { + zstd::stream::decode_all(std::io::Cursor::new(&request.body)) + .unwrap_or_else(|_| request.body.clone()) + } else { + request.body.clone() + } +} + +fn has_function_call_output(body: &Value) -> bool { + body.get("input") + .and_then(Value::as_array) + .is_some_and(|items| { + items.iter().any(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call_output") + }) + }) +} + +fn extract_job_and_item(body: &Value) -> Option<(String, String)> { + let texts = message_input_texts(body); + let mut combined = texts.join("\n"); + if let Some(instructions) = body.get("instructions").and_then(Value::as_str) { + combined.push('\n'); + combined.push_str(instructions); + } + if !combined.contains("You are processing one item for a generic agent job.") { + return None; + } + let job_id = Regex::new(r"Job ID:\s*([^\n]+)") + .ok()? + .captures(&combined) + .and_then(|caps| caps.get(1)) + .map(|m| m.as_str().trim().to_string())?; + let item_id = Regex::new(r"Item ID:\s*([^\n]+)") + .ok()? + .captures(&combined) + .and_then(|caps| caps.get(1)) + .map(|m| m.as_str().trim().to_string())?; + Some((job_id, item_id)) +} + +fn message_input_texts(body: &Value) -> Vec { + let Some(items) = body.get("input").and_then(Value::as_array) else { + return Vec::new(); + }; + items + .iter() + .filter(|item| item.get("type").and_then(Value::as_str) == Some("message")) + .filter_map(|item| item.get("content").and_then(Value::as_array)) + .flatten() + .filter(|span| span.get("type").and_then(Value::as_str) == Some("input_text")) + .filter_map(|span| span.get("text").and_then(Value::as_str)) + .map(str::to_string) + .collect() +} + +fn parse_simple_csv_line(line: &str) -> Vec { + line.split(',').map(str::to_string).collect() +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn report_agent_job_result_rejects_wrong_thread() -> Result<()> { + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Collab); + config.features.enable(Feature::Sqlite); + }); + let test = builder.build(&server).await?; + + let input_path = test.cwd_path().join("agent_jobs_wrong_thread.csv"); + let output_path = test.cwd_path().join("agent_jobs_wrong_thread_out.csv"); + fs::write(&input_path, "path\nfile-1\n")?; + + let args = json!({ + "csv_path": input_path.display().to_string(), + "instruction": "Return {path}", + "output_csv_path": output_path.display().to_string(), + }); + let args_json = serde_json::to_string(&args)?; + + let responder = AgentJobsResponder::new(args_json); + Mock::given(method("POST")) + .and(path_regex(".*/responses$")) + .respond_with(responder) + .mount(&server) + .await; + + test.submit_turn("run job").await?; + + let db = test.codex.state_db().expect("state db"); + let output = fs::read_to_string(&output_path)?; + let rows: Vec<&str> = output.lines().skip(1).collect(); + assert_eq!(rows.len(), 1); + let job_id = rows + .first() + .and_then(|line| { + parse_simple_csv_line(line) + .iter() + .find(|value| value.len() == 36) + .cloned() + }) + .expect("job_id from csv"); + let job = db.get_agent_job(job_id.as_str()).await?.expect("job"); + let items = db + .list_agent_job_items(job.id.as_str(), None, Some(10)) + .await?; + let item = items.first().expect("item"); + let wrong_thread_id = "00000000-0000-0000-0000-000000000000"; + let accepted = db + .report_agent_job_item_result( + job.id.as_str(), + item.item_id.as_str(), + wrong_thread_id, + &json!({ "wrong": true }), + ) + .await?; + assert!(!accepted); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn spawn_agents_on_csv_runs_and_exports() -> Result<()> { + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Collab); + config.features.enable(Feature::Sqlite); + }); + let test = builder.build(&server).await?; + + let input_path = test.cwd_path().join("agent_jobs_input.csv"); + let output_path = test.cwd_path().join("agent_jobs_output.csv"); + fs::write(&input_path, "path,area\nfile-1,test\nfile-2,test\n")?; + + let args = json!({ + "csv_path": input_path.display().to_string(), + "instruction": "Return {path}", + "output_csv_path": output_path.display().to_string(), + }); + let args_json = serde_json::to_string(&args)?; + + let responder = AgentJobsResponder::new(args_json); + Mock::given(method("POST")) + .and(path_regex(".*/responses$")) + .respond_with(responder) + .mount(&server) + .await; + + test.submit_turn("run batch job").await?; + + let output = fs::read_to_string(&output_path)?; + assert!(output.contains("result_json")); + assert!(output.contains("item_id")); + assert!(output.contains("\"item_id\"")); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn spawn_agents_on_csv_dedupes_item_ids() -> Result<()> { + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Collab); + config.features.enable(Feature::Sqlite); + }); + let test = builder.build(&server).await?; + + let input_path = test.cwd_path().join("agent_jobs_dupe.csv"); + let output_path = test.cwd_path().join("agent_jobs_dupe_out.csv"); + fs::write(&input_path, "id,path\nfoo,alpha\nfoo,beta\n")?; + + let args = json!({ + "csv_path": input_path.display().to_string(), + "instruction": "Return {path}", + "id_column": "id", + "output_csv_path": output_path.display().to_string(), + }); + let args_json = serde_json::to_string(&args)?; + + let responder = AgentJobsResponder::new(args_json); + Mock::given(method("POST")) + .and(path_regex(".*/responses$")) + .respond_with(responder) + .mount(&server) + .await; + + test.submit_turn("run batch job with duplicate ids").await?; + + let output = fs::read_to_string(&output_path)?; + let mut lines = output.lines(); + let headers = lines.next().expect("csv headers"); + let header_cols = parse_simple_csv_line(headers); + let item_id_index = header_cols + .iter() + .position(|header| header == "item_id") + .expect("item_id column"); + + let mut item_ids = Vec::new(); + for line in lines { + let cols = parse_simple_csv_line(line); + item_ids.push(cols[item_id_index].clone()); + } + item_ids.sort(); + item_ids.dedup(); + assert_eq!(item_ids.len(), 2); + assert!(item_ids.contains(&"foo".to_string())); + assert!(item_ids.contains(&"foo-2".to_string())); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn spawn_agents_on_csv_stop_halts_future_items() -> Result<()> { + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Collab); + config.features.enable(Feature::Sqlite); + }); + let test = builder.build(&server).await?; + + let input_path = test.cwd_path().join("agent_jobs_stop.csv"); + let output_path = test.cwd_path().join("agent_jobs_stop_out.csv"); + fs::write(&input_path, "path\nfile-1\nfile-2\nfile-3\n")?; + + let args = json!({ + "csv_path": input_path.display().to_string(), + "instruction": "Return {path}", + "output_csv_path": output_path.display().to_string(), + "max_concurrency": 1, + }); + let args_json = serde_json::to_string(&args)?; + + let worker_calls = Arc::new(AtomicUsize::new(0)); + let responder = StopAfterFirstResponder::new(args_json, worker_calls.clone()); + Mock::given(method("POST")) + .and(path_regex(".*/responses$")) + .respond_with(responder) + .mount(&server) + .await; + + test.submit_turn("run job").await?; + + let output = fs::read_to_string(&output_path)?; + let rows: Vec<&str> = output.lines().skip(1).collect(); + assert_eq!(rows.len(), 3); + let job_id = rows + .first() + .and_then(|line| { + parse_simple_csv_line(line) + .iter() + .find(|value| value.len() == 36) + .cloned() + }) + .expect("job_id from csv"); + let db = test.codex.state_db().expect("state db"); + let job = db.get_agent_job(job_id.as_str()).await?.expect("job"); + assert_eq!(job.status, codex_state::AgentJobStatus::Cancelled); + let progress = db.get_agent_job_progress(job_id.as_str()).await?; + assert_eq!(progress.total_items, 3); + assert_eq!(progress.completed_items, 1); + assert_eq!(progress.failed_items, 0); + assert_eq!(progress.running_items, 0); + assert_eq!(progress.pending_items, 2); + assert_eq!(worker_calls.load(Ordering::SeqCst), 1); + Ok(()) +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index c5abe2eaa..f3b64e1e7 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -56,6 +56,7 @@ pub static CODEX_ALIASES_TEMP_DIR: TestCodexAliasesGuard = unsafe { #[cfg(not(target_os = "windows"))] mod abort_tasks; +mod agent_jobs; mod agent_websocket; mod apply_patch_cli; #[cfg(not(target_os = "windows"))] diff --git a/codex-rs/core/tests/suite/shell_snapshot.rs b/codex-rs/core/tests/suite/shell_snapshot.rs index 2a8017b39..3be9d2b57 100644 --- a/codex-rs/core/tests/suite/shell_snapshot.rs +++ b/codex-rs/core/tests/suite/shell_snapshot.rs @@ -52,10 +52,16 @@ async fn wait_for_snapshot(codex_home: &Path) -> Result { let snapshot_dir = codex_home.join("shell_snapshots"); let deadline = Instant::now() + Duration::from_secs(5); loop { - if let Ok(mut entries) = fs::read_dir(&snapshot_dir).await - && let Some(entry) = entries.next_entry().await? - { - return Ok(entry.path()); + if let Ok(mut entries) = fs::read_dir(&snapshot_dir).await { + while let Some(entry) = entries.next_entry().await? { + let path = entry.path(); + let Some(extension) = path.extension().and_then(|ext| ext.to_str()) else { + continue; + }; + if extension == "sh" || extension == "ps1" { + return Ok(path); + } + } } if Instant::now() >= deadline { diff --git a/codex-rs/core/tests/suite/sqlite_state.rs b/codex-rs/core/tests/suite/sqlite_state.rs index 8f0e5834e..113e837b3 100644 --- a/codex-rs/core/tests/suite/sqlite_state.rs +++ b/codex-rs/core/tests/suite/sqlite_state.rs @@ -33,7 +33,7 @@ async fn new_thread_is_recorded_in_state_db() -> Result<()> { let thread_id = test.session_configured.session_id; let rollout_path = test.codex.rollout_path().expect("rollout path"); - let db_path = codex_state::state_db_path(test.config.codex_home.as_path()); + let db_path = codex_state::state_db_path(test.config.sqlite_home.as_path()); for _ in 0..100 { if tokio::fs::try_exists(&db_path).await.unwrap_or(false) { @@ -161,7 +161,7 @@ async fn backfill_scans_existing_rollouts() -> Result<()> { let test = builder.build(&server).await?; - let db_path = codex_state::state_db_path(test.config.codex_home.as_path()); + let db_path = codex_state::state_db_path(test.config.sqlite_home.as_path()); let rollout_path = test.config.codex_home.join(&rollout_rel_path); let default_provider = test.config.model_provider_id.clone(); @@ -220,7 +220,7 @@ async fn user_messages_persist_in_state_db() -> Result<()> { }); let test = builder.build(&server).await?; - let db_path = codex_state::state_db_path(test.config.codex_home.as_path()); + let db_path = codex_state::state_db_path(test.config.sqlite_home.as_path()); for _ in 0..100 { if tokio::fs::try_exists(&db_path).await.unwrap_or(false) { break; diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs index c40237160..6cda7e408 100644 --- a/codex-rs/exec/src/cli.rs +++ b/codex-rs/exec/src/cli.rs @@ -86,6 +86,10 @@ pub struct Cli { #[arg(long = "color", value_enum, default_value_t = Color::Auto)] pub color: Color, + /// Force cursor-based progress updates in exec mode. + #[arg(long = "progress-cursor", default_value_t = false)] + pub progress_cursor: bool, + /// Print events to stdout as JSONL. #[arg( long = "json", diff --git a/codex-rs/exec/src/event_processor_with_human_output.rs b/codex-rs/exec/src/event_processor_with_human_output.rs index f3c92b76d..6e679052e 100644 --- a/codex-rs/exec/src/event_processor_with_human_output.rs +++ b/codex-rs/exec/src/event_processor_with_human_output.rs @@ -38,9 +38,12 @@ use codex_utils_elapsed::format_duration; use codex_utils_elapsed::format_elapsed; use owo_colors::OwoColorize; use owo_colors::Style; +use serde::Deserialize; use shlex::try_join; use std::collections::HashMap; +use std::io::Write; use std::path::PathBuf; +use std::time::Duration; use std::time::Instant; use crate::event_processor::CodexStatus; @@ -76,11 +79,17 @@ pub(crate) struct EventProcessorWithHumanOutput { last_total_token_usage: Option, final_message: Option, last_proposed_plan: Option, + progress_active: bool, + progress_last_len: usize, + use_ansi_cursor: bool, + progress_anchor: bool, + progress_done: bool, } impl EventProcessorWithHumanOutput { pub(crate) fn create_with_ansi( with_ansi: bool, + cursor_ansi: bool, config: &Config, last_message_path: Option, ) -> Self { @@ -103,6 +112,11 @@ impl EventProcessorWithHumanOutput { last_total_token_usage: None, final_message: None, last_proposed_plan: None, + progress_active: false, + progress_last_len: 0, + use_ansi_cursor: cursor_ansi, + progress_anchor: false, + progress_done: false, } } else { Self { @@ -121,11 +135,27 @@ impl EventProcessorWithHumanOutput { last_total_token_usage: None, final_message: None, last_proposed_plan: None, + progress_active: false, + progress_last_len: 0, + use_ansi_cursor: cursor_ansi, + progress_anchor: false, + progress_done: false, } } } } +#[derive(Debug, Deserialize)] +struct AgentJobProgressMessage { + job_id: String, + total_items: usize, + pending_items: usize, + running_items: usize, + completed_items: usize, + failed_items: usize, + eta_seconds: Option, +} + struct PatchApplyBegin { start_time: Instant, auto_approved: bool, @@ -176,6 +206,18 @@ impl EventProcessor for EventProcessorWithHumanOutput { fn process_event(&mut self, event: Event) -> CodexStatus { let Event { id: _, msg } = event; + if let EventMsg::BackgroundEvent(BackgroundEventEvent { message }) = &msg + && let Some(update) = Self::parse_agent_job_progress(message) + { + self.render_agent_job_progress(update); + return CodexStatus::Running; + } + if self.progress_active && !Self::should_interrupt_progress(&msg) { + return CodexStatus::Running; + } + if !Self::is_silent_event(&msg) { + self.finish_progress_line(); + } match msg { EventMsg::Error(ErrorEvent { message, .. }) => { let prefix = "ERROR:".style(self.red); @@ -818,6 +860,7 @@ impl EventProcessor for EventProcessorWithHumanOutput { } fn print_final_output(&mut self) { + self.finish_progress_line(); if let Some(usage_info) = &self.last_total_token_usage { eprintln!( "{}\n{}", @@ -841,6 +884,207 @@ impl EventProcessor for EventProcessorWithHumanOutput { } } +impl EventProcessorWithHumanOutput { + fn parse_agent_job_progress(message: &str) -> Option { + let payload = message.strip_prefix("agent_job_progress:")?; + serde_json::from_str::(payload).ok() + } + + fn is_silent_event(msg: &EventMsg) -> bool { + matches!( + msg, + EventMsg::ThreadNameUpdated(_) + | EventMsg::TokenCount(_) + | EventMsg::TurnStarted(_) + | EventMsg::ExecApprovalRequest(_) + | EventMsg::ApplyPatchApprovalRequest(_) + | EventMsg::TerminalInteraction(_) + | EventMsg::ExecCommandOutputDelta(_) + | EventMsg::GetHistoryEntryResponse(_) + | EventMsg::McpListToolsResponse(_) + | EventMsg::ListCustomPromptsResponse(_) + | EventMsg::ListSkillsResponse(_) + | EventMsg::ListRemoteSkillsResponse(_) + | EventMsg::RemoteSkillDownloaded(_) + | EventMsg::RawResponseItem(_) + | EventMsg::UserMessage(_) + | EventMsg::EnteredReviewMode(_) + | EventMsg::ExitedReviewMode(_) + | EventMsg::AgentMessageDelta(_) + | EventMsg::AgentReasoningDelta(_) + | EventMsg::AgentReasoningRawContentDelta(_) + | EventMsg::ItemStarted(_) + | EventMsg::ItemCompleted(_) + | EventMsg::AgentMessageContentDelta(_) + | EventMsg::PlanDelta(_) + | EventMsg::ReasoningContentDelta(_) + | EventMsg::ReasoningRawContentDelta(_) + | EventMsg::SkillsUpdateAvailable + | EventMsg::UndoCompleted(_) + | EventMsg::UndoStarted(_) + | EventMsg::ThreadRolledBack(_) + | EventMsg::RequestUserInput(_) + | EventMsg::DynamicToolCallRequest(_) + ) + } + + fn should_interrupt_progress(msg: &EventMsg) -> bool { + matches!( + msg, + EventMsg::Error(_) + | EventMsg::Warning(_) + | EventMsg::DeprecationNotice(_) + | EventMsg::StreamError(_) + | EventMsg::TurnComplete(_) + | EventMsg::ShutdownComplete + ) + } + + fn finish_progress_line(&mut self) { + if self.progress_active { + self.progress_active = false; + self.progress_last_len = 0; + self.progress_done = false; + if self.use_ansi_cursor { + if self.progress_anchor { + eprintln!("\u{1b}[1A\u{1b}[1G\u{1b}[2K"); + } else { + eprintln!("\u{1b}[1G\u{1b}[2K"); + } + } else { + eprintln!(); + } + self.progress_anchor = false; + } + } + + fn render_agent_job_progress(&mut self, update: AgentJobProgressMessage) { + let total = update.total_items.max(1); + let processed = update.completed_items + update.failed_items; + let percent = (processed as f64 / total as f64 * 100.0).round() as i64; + let job_label = update.job_id.chars().take(8).collect::(); + let eta = update + .eta_seconds + .map(|secs| format_duration(Duration::from_secs(secs))) + .unwrap_or_else(|| "--".to_string()); + let columns = std::env::var("COLUMNS") + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|value| *value > 0); + let line = format_agent_job_progress_line( + columns, + job_label.as_str(), + AgentJobProgressStats { + processed, + total, + percent, + failed: update.failed_items, + running: update.running_items, + pending: update.pending_items, + }, + eta.as_str(), + ); + let done = processed >= update.total_items; + if !self.use_ansi_cursor { + eprintln!("{line}"); + if done { + self.progress_active = false; + self.progress_last_len = 0; + } + return; + } + if done && self.progress_done { + return; + } + if !self.progress_active { + eprintln!(); + self.progress_anchor = true; + self.progress_done = false; + } + let mut output = String::new(); + if self.progress_anchor { + output.push_str("\u{1b}[1A\u{1b}[1G\u{1b}[2K"); + } else { + output.push_str("\u{1b}[1G\u{1b}[2K"); + } + output.push_str(&line); + if done { + output.push('\n'); + eprint!("{output}"); + self.progress_active = false; + self.progress_last_len = 0; + self.progress_anchor = false; + self.progress_done = true; + return; + } + eprint!("{output}"); + let _ = std::io::stderr().flush(); + self.progress_active = true; + self.progress_last_len = line.len(); + } +} + +struct AgentJobProgressStats { + processed: usize, + total: usize, + percent: i64, + failed: usize, + running: usize, + pending: usize, +} + +fn format_agent_job_progress_line( + columns: Option, + job_label: &str, + stats: AgentJobProgressStats, + eta: &str, +) -> String { + let rest = format!( + "{processed}/{total} {percent}% f{failed} r{running} p{pending} eta {eta}", + processed = stats.processed, + total = stats.total, + percent = stats.percent, + failed = stats.failed, + running = stats.running, + pending = stats.pending + ); + let prefix = format!("job {job_label}"); + let base_len = prefix.len() + rest.len() + 4; + let mut bar_width = columns + .and_then(|columns| columns.checked_sub(base_len)) + .filter(|available| *available > 0) + .unwrap_or(20usize); + let with_bar = |width: usize| { + let filled = ((stats.processed as f64 / stats.total as f64) * width as f64) + .round() + .clamp(0.0, width as f64) as usize; + let mut bar = "#".repeat(filled); + bar.push_str(&"-".repeat(width - filled)); + format!("{prefix} [{bar}] {rest}") + }; + let mut line = with_bar(bar_width); + if let Some(columns) = columns + && line.len() > columns + { + let min_line = format!("{prefix} {rest}"); + if min_line.len() > columns { + let mut truncated = min_line; + if columns > 2 && truncated.len() > columns { + truncated.truncate(columns - 2); + truncated.push_str(".."); + } + return truncated; + } + let available = columns.saturating_sub(base_len); + if available == 0 { + return min_line; + } + bar_width = available.min(bar_width).max(1); + line = with_bar(bar_width); + } + line +} + fn escape_command(command: &[String]) -> String { try_join(command.iter().map(String::as_str)).unwrap_or_else(|_| command.join(" ")) } diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 89180d89e..382081dd3 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -41,6 +41,7 @@ use codex_protocol::protocol::Op; use codex_protocol::protocol::ReviewRequest; use codex_protocol::protocol::ReviewTarget; use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; use codex_protocol::user_input::UserInput; use codex_utils_absolute_path::AbsolutePathBuf; use codex_utils_oss::ensure_oss_provider_ready; @@ -86,6 +87,7 @@ struct ThreadEventEnvelope { thread_id: codex_protocol::ThreadId, thread: Arc, event: Event, + suppress_output: bool, } pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> anyhow::Result<()> { @@ -113,9 +115,10 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any prompt, output_schema: output_schema_path, config_overrides, + progress_cursor, } = cli; - let (stdout_with_ansi, stderr_with_ansi) = match color { + let (_stdout_with_ansi, stderr_with_ansi) = match color { cli::Color::Always => (true, true), cli::Color::Never => (false, false), cli::Color::Auto => ( @@ -123,6 +126,24 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any supports_color::on_cached(Stream::Stderr).is_some(), ), }; + let cursor_ansi = if progress_cursor { + true + } else { + match color { + cli::Color::Never => false, + cli::Color::Always => true, + cli::Color::Auto => { + if stderr_with_ansi || std::io::stderr().is_terminal() { + true + } else { + match std::env::var("TERM") { + Ok(term) => !term.is_empty() && term != "dumb", + Err(_) => false, + } + } + } + } + }; // Build fmt layer (existing logging) to compose with OTEL layer. let default_level = "error"; @@ -318,7 +339,8 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any let mut event_processor: Box = match json_mode { true => Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone())), _ => Box::new(EventProcessorWithHumanOutput::create_with_ansi( - stdout_with_ansi, + stderr_with_ansi, + cursor_ansi, &config, last_message_file.clone(), )), @@ -466,7 +488,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); let attached_threads = Arc::new(Mutex::new(HashSet::from([primary_thread_id]))); - spawn_thread_listener(primary_thread_id, thread.clone(), tx.clone()); + spawn_thread_listener(primary_thread_id, thread.clone(), tx.clone(), false); { let thread = thread.clone(); @@ -494,7 +516,14 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any match thread_manager.get_thread(thread_id).await { Ok(thread) => { attached_threads.lock().await.insert(thread_id); - spawn_thread_listener(thread_id, thread, tx.clone()); + let suppress_output = + is_agent_job_subagent(&thread.config_snapshot().await); + spawn_thread_listener( + thread_id, + thread, + tx.clone(), + suppress_output, + ); } Err(err) => { warn!("failed to attach listener for thread {thread_id}: {err}") @@ -549,7 +578,11 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any thread_id, thread, event, + suppress_output, } = envelope; + if suppress_output && should_suppress_agent_job_event(&event.msg) { + continue; + } if matches!(event.msg, EventMsg::Error(_)) { error_seen = true; } @@ -613,6 +646,7 @@ fn spawn_thread_listener( thread_id: codex_protocol::ThreadId, thread: Arc, tx: tokio::sync::mpsc::UnboundedSender, + suppress_output: bool, ) { tokio::spawn(async move { loop { @@ -625,6 +659,7 @@ fn spawn_thread_listener( thread_id, thread: Arc::clone(&thread), event, + suppress_output, }) { error!("Error sending event: {err:?}"); break; @@ -645,6 +680,29 @@ fn spawn_thread_listener( }); } +fn is_agent_job_subagent(config: &codex_core::ThreadConfigSnapshot) -> bool { + match &config.session_source { + SessionSource::SubAgent(SubAgentSource::Other(source)) => source.starts_with("agent_job:"), + _ => false, + } +} + +fn should_suppress_agent_job_event(msg: &EventMsg) -> bool { + !matches!( + msg, + EventMsg::ExecApprovalRequest(_) + | EventMsg::ApplyPatchApprovalRequest(_) + | EventMsg::RequestUserInput(_) + | EventMsg::DynamicToolCallRequest(_) + | EventMsg::ElicitationRequest(_) + | EventMsg::Error(_) + | EventMsg::Warning(_) + | EventMsg::DeprecationNotice(_) + | EventMsg::StreamError(_) + | EventMsg::ShutdownComplete + ) +} + async fn resolve_resume_path( config: &Config, args: &crate::cli::ResumeArgs, diff --git a/codex-rs/state/migrations/0014_agent_jobs.sql b/codex-rs/state/migrations/0014_agent_jobs.sql new file mode 100644 index 000000000..d8968c405 --- /dev/null +++ b/codex-rs/state/migrations/0014_agent_jobs.sql @@ -0,0 +1,38 @@ +CREATE TABLE agent_jobs ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + status TEXT NOT NULL, + instruction TEXT NOT NULL, + output_schema_json TEXT, + input_headers_json TEXT NOT NULL, + input_csv_path TEXT NOT NULL, + output_csv_path TEXT NOT NULL, + auto_export INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + started_at INTEGER, + completed_at INTEGER, + last_error TEXT +); + +CREATE TABLE agent_job_items ( + job_id TEXT NOT NULL, + item_id TEXT NOT NULL, + row_index INTEGER NOT NULL, + source_id TEXT, + row_json TEXT NOT NULL, + status TEXT NOT NULL, + assigned_thread_id TEXT, + attempt_count INTEGER NOT NULL DEFAULT 0, + result_json TEXT, + last_error TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + completed_at INTEGER, + reported_at INTEGER, + PRIMARY KEY (job_id, item_id), + FOREIGN KEY(job_id) REFERENCES agent_jobs(id) ON DELETE CASCADE +); + +CREATE INDEX idx_agent_jobs_status ON agent_jobs(status, updated_at DESC); +CREATE INDEX idx_agent_job_items_status ON agent_job_items(job_id, status, row_index ASC); diff --git a/codex-rs/state/migrations/0015_agent_jobs_max_runtime_seconds.sql b/codex-rs/state/migrations/0015_agent_jobs_max_runtime_seconds.sql new file mode 100644 index 000000000..ab1006d45 --- /dev/null +++ b/codex-rs/state/migrations/0015_agent_jobs_max_runtime_seconds.sql @@ -0,0 +1,2 @@ +ALTER TABLE agent_jobs +ADD COLUMN max_runtime_seconds INTEGER; diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index 062db796a..59607f3d9 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -22,6 +22,13 @@ pub use runtime::StateRuntime; /// /// Most consumers should prefer [`StateRuntime`]. pub use extract::apply_rollout_item; +pub use model::AgentJob; +pub use model::AgentJobCreateParams; +pub use model::AgentJobItem; +pub use model::AgentJobItemCreateParams; +pub use model::AgentJobItemStatus; +pub use model::AgentJobProgress; +pub use model::AgentJobStatus; pub use model::Anchor; pub use model::BackfillState; pub use model::BackfillStats; @@ -38,6 +45,9 @@ pub use model::ThreadsPage; pub use runtime::state_db_filename; pub use runtime::state_db_path; +/// Environment variable for overriding the SQLite state database home directory. +pub const SQLITE_HOME_ENV: &str = "CODEX_SQLITE_HOME"; + pub const STATE_DB_FILENAME: &str = "state"; pub const STATE_DB_VERSION: u32 = 5; diff --git a/codex-rs/state/src/model/agent_job.rs b/codex-rs/state/src/model/agent_job.rs new file mode 100644 index 000000000..36f9ff12d --- /dev/null +++ b/codex-rs/state/src/model/agent_job.rs @@ -0,0 +1,256 @@ +use anyhow::Result; +use chrono::DateTime; +use chrono::Utc; +use serde_json::Value; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AgentJobStatus { + Pending, + Running, + Completed, + Failed, + Cancelled, +} + +impl AgentJobStatus { + pub const fn as_str(self) -> &'static str { + match self { + AgentJobStatus::Pending => "pending", + AgentJobStatus::Running => "running", + AgentJobStatus::Completed => "completed", + AgentJobStatus::Failed => "failed", + AgentJobStatus::Cancelled => "cancelled", + } + } + + pub fn parse(value: &str) -> Result { + match value { + "pending" => Ok(Self::Pending), + "running" => Ok(Self::Running), + "completed" => Ok(Self::Completed), + "failed" => Ok(Self::Failed), + "cancelled" => Ok(Self::Cancelled), + _ => Err(anyhow::anyhow!("invalid agent job status: {value}")), + } + } + + pub fn is_final(self) -> bool { + matches!( + self, + AgentJobStatus::Completed | AgentJobStatus::Failed | AgentJobStatus::Cancelled + ) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AgentJobItemStatus { + Pending, + Running, + Completed, + Failed, +} + +impl AgentJobItemStatus { + pub const fn as_str(self) -> &'static str { + match self { + AgentJobItemStatus::Pending => "pending", + AgentJobItemStatus::Running => "running", + AgentJobItemStatus::Completed => "completed", + AgentJobItemStatus::Failed => "failed", + } + } + + pub fn parse(value: &str) -> Result { + match value { + "pending" => Ok(Self::Pending), + "running" => Ok(Self::Running), + "completed" => Ok(Self::Completed), + "failed" => Ok(Self::Failed), + _ => Err(anyhow::anyhow!("invalid agent job item status: {value}")), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct AgentJob { + pub id: String, + pub name: String, + pub status: AgentJobStatus, + pub instruction: String, + pub auto_export: bool, + pub max_runtime_seconds: Option, + // TODO(jif-oai): Convert to JSON Schema and enforce structured outputs. + pub output_schema_json: Option, + pub input_headers: Vec, + pub input_csv_path: String, + pub output_csv_path: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub started_at: Option>, + pub completed_at: Option>, + pub last_error: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct AgentJobItem { + pub job_id: String, + pub item_id: String, + pub row_index: i64, + pub source_id: Option, + pub row_json: Value, + pub status: AgentJobItemStatus, + pub assigned_thread_id: Option, + pub attempt_count: i64, + pub result_json: Option, + pub last_error: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + pub completed_at: Option>, + pub reported_at: Option>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AgentJobProgress { + pub total_items: usize, + pub pending_items: usize, + pub running_items: usize, + pub completed_items: usize, + pub failed_items: usize, +} + +#[derive(Debug, Clone)] +pub struct AgentJobCreateParams { + pub id: String, + pub name: String, + pub instruction: String, + pub auto_export: bool, + pub max_runtime_seconds: Option, + pub output_schema_json: Option, + pub input_headers: Vec, + pub input_csv_path: String, + pub output_csv_path: String, +} + +#[derive(Debug, Clone)] +pub struct AgentJobItemCreateParams { + pub item_id: String, + pub row_index: i64, + pub source_id: Option, + pub row_json: Value, +} + +#[derive(Debug, sqlx::FromRow)] +pub(crate) struct AgentJobRow { + pub(crate) id: String, + pub(crate) name: String, + pub(crate) status: String, + pub(crate) instruction: String, + pub(crate) auto_export: i64, + pub(crate) max_runtime_seconds: Option, + pub(crate) output_schema_json: Option, + pub(crate) input_headers_json: String, + pub(crate) input_csv_path: String, + pub(crate) output_csv_path: String, + pub(crate) created_at: i64, + pub(crate) updated_at: i64, + pub(crate) started_at: Option, + pub(crate) completed_at: Option, + pub(crate) last_error: Option, +} + +impl TryFrom for AgentJob { + type Error = anyhow::Error; + + fn try_from(value: AgentJobRow) -> Result { + let output_schema_json = value + .output_schema_json + .as_deref() + .map(serde_json::from_str) + .transpose()?; + let input_headers = serde_json::from_str(value.input_headers_json.as_str())?; + let max_runtime_seconds = value + .max_runtime_seconds + .map(u64::try_from) + .transpose() + .map_err(|_| anyhow::anyhow!("invalid max_runtime_seconds value"))?; + Ok(Self { + id: value.id, + name: value.name, + status: AgentJobStatus::parse(value.status.as_str())?, + instruction: value.instruction, + auto_export: value.auto_export != 0, + max_runtime_seconds, + output_schema_json, + input_headers, + input_csv_path: value.input_csv_path, + output_csv_path: value.output_csv_path, + created_at: epoch_seconds_to_datetime(value.created_at)?, + updated_at: epoch_seconds_to_datetime(value.updated_at)?, + started_at: value + .started_at + .map(epoch_seconds_to_datetime) + .transpose()?, + completed_at: value + .completed_at + .map(epoch_seconds_to_datetime) + .transpose()?, + last_error: value.last_error, + }) + } +} + +#[derive(Debug, sqlx::FromRow)] +pub(crate) struct AgentJobItemRow { + pub(crate) job_id: String, + pub(crate) item_id: String, + pub(crate) row_index: i64, + pub(crate) source_id: Option, + pub(crate) row_json: String, + pub(crate) status: String, + pub(crate) assigned_thread_id: Option, + pub(crate) attempt_count: i64, + pub(crate) result_json: Option, + pub(crate) last_error: Option, + pub(crate) created_at: i64, + pub(crate) updated_at: i64, + pub(crate) completed_at: Option, + pub(crate) reported_at: Option, +} + +impl TryFrom for AgentJobItem { + type Error = anyhow::Error; + + fn try_from(value: AgentJobItemRow) -> Result { + Ok(Self { + job_id: value.job_id, + item_id: value.item_id, + row_index: value.row_index, + source_id: value.source_id, + row_json: serde_json::from_str(value.row_json.as_str())?, + status: AgentJobItemStatus::parse(value.status.as_str())?, + assigned_thread_id: value.assigned_thread_id, + attempt_count: value.attempt_count, + result_json: value + .result_json + .as_deref() + .map(serde_json::from_str) + .transpose()?, + last_error: value.last_error, + created_at: epoch_seconds_to_datetime(value.created_at)?, + updated_at: epoch_seconds_to_datetime(value.updated_at)?, + completed_at: value + .completed_at + .map(epoch_seconds_to_datetime) + .transpose()?, + reported_at: value + .reported_at + .map(epoch_seconds_to_datetime) + .transpose()?, + }) + } +} + +fn epoch_seconds_to_datetime(secs: i64) -> Result> { + DateTime::::from_timestamp(secs, 0) + .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) +} diff --git a/codex-rs/state/src/model/mod.rs b/codex-rs/state/src/model/mod.rs index 2a0009d64..816c036f8 100644 --- a/codex-rs/state/src/model/mod.rs +++ b/codex-rs/state/src/model/mod.rs @@ -1,8 +1,16 @@ +mod agent_job; mod backfill_state; mod log; mod memories; mod thread_metadata; +pub use agent_job::AgentJob; +pub use agent_job::AgentJobCreateParams; +pub use agent_job::AgentJobItem; +pub use agent_job::AgentJobItemCreateParams; +pub use agent_job::AgentJobItemStatus; +pub use agent_job::AgentJobProgress; +pub use agent_job::AgentJobStatus; pub use backfill_state::BackfillState; pub use backfill_state::BackfillStatus; pub use log::LogEntry; @@ -21,6 +29,8 @@ pub use thread_metadata::ThreadMetadata; pub use thread_metadata::ThreadMetadataBuilder; pub use thread_metadata::ThreadsPage; +pub(crate) use agent_job::AgentJobItemRow; +pub(crate) use agent_job::AgentJobRow; pub(crate) use memories::Stage1OutputRow; pub(crate) use thread_metadata::ThreadRow; pub(crate) use thread_metadata::anchor_from_item; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index f027cac6e..29d85cdd5 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -1,3 +1,10 @@ +use crate::AgentJob; +use crate::AgentJobCreateParams; +use crate::AgentJobItem; +use crate::AgentJobItemCreateParams; +use crate::AgentJobItemStatus; +use crate::AgentJobProgress; +use crate::AgentJobStatus; use crate::DB_ERROR_METRIC; use crate::LogEntry; use crate::LogQuery; @@ -11,6 +18,8 @@ use crate::ThreadMetadataBuilder; use crate::ThreadsPage; use crate::apply_rollout_item; use crate::migrations::MIGRATOR; +use crate::model::AgentJobItemRow; +use crate::model::AgentJobRow; use crate::model::ThreadRow; use crate::model::anchor_from_item; use crate::model::datetime_to_epoch_seconds; @@ -901,6 +910,564 @@ ON CONFLICT(thread_id, position) DO NOTHING Ok(result.rows_affected()) } + pub async fn create_agent_job( + &self, + params: &AgentJobCreateParams, + items: &[AgentJobItemCreateParams], + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let input_headers_json = serde_json::to_string(¶ms.input_headers)?; + let output_schema_json = params + .output_schema_json + .as_ref() + .map(serde_json::to_string) + .transpose()?; + let max_runtime_seconds = params + .max_runtime_seconds + .map(i64::try_from) + .transpose() + .map_err(|_| anyhow::anyhow!("invalid max_runtime_seconds value"))?; + let mut tx = self.pool.begin().await?; + sqlx::query( + r#" +INSERT INTO agent_jobs ( + id, + name, + status, + instruction, + auto_export, + max_runtime_seconds, + output_schema_json, + input_headers_json, + input_csv_path, + output_csv_path, + created_at, + updated_at, + started_at, + completed_at, + last_error +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL) + "#, + ) + .bind(params.id.as_str()) + .bind(params.name.as_str()) + .bind(AgentJobStatus::Pending.as_str()) + .bind(params.instruction.as_str()) + .bind(i64::from(params.auto_export)) + .bind(max_runtime_seconds) + .bind(output_schema_json) + .bind(input_headers_json) + .bind(params.input_csv_path.as_str()) + .bind(params.output_csv_path.as_str()) + .bind(now) + .bind(now) + .execute(&mut *tx) + .await?; + + for item in items { + let row_json = serde_json::to_string(&item.row_json)?; + sqlx::query( + r#" +INSERT INTO agent_job_items ( + job_id, + item_id, + row_index, + source_id, + row_json, + status, + assigned_thread_id, + attempt_count, + result_json, + last_error, + created_at, + updated_at, + completed_at, + reported_at +) VALUES (?, ?, ?, ?, ?, ?, NULL, 0, NULL, NULL, ?, ?, NULL, NULL) + "#, + ) + .bind(params.id.as_str()) + .bind(item.item_id.as_str()) + .bind(item.row_index) + .bind(item.source_id.as_deref()) + .bind(row_json) + .bind(AgentJobItemStatus::Pending.as_str()) + .bind(now) + .bind(now) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + let job_id = params.id.as_str(); + self.get_agent_job(job_id) + .await? + .ok_or_else(|| anyhow::anyhow!("failed to load created agent job {job_id}")) + } + + pub async fn get_agent_job(&self, job_id: &str) -> anyhow::Result> { + let row = sqlx::query_as::<_, AgentJobRow>( + r#" +SELECT + id, + name, + status, + instruction, + auto_export, + max_runtime_seconds, + output_schema_json, + input_headers_json, + input_csv_path, + output_csv_path, + created_at, + updated_at, + started_at, + completed_at, + last_error +FROM agent_jobs +WHERE id = ? + "#, + ) + .bind(job_id) + .fetch_optional(self.pool.as_ref()) + .await?; + row.map(AgentJob::try_from).transpose() + } + + pub async fn list_agent_job_items( + &self, + job_id: &str, + status: Option, + limit: Option, + ) -> anyhow::Result> { + let mut builder = QueryBuilder::::new( + r#" +SELECT + job_id, + item_id, + row_index, + source_id, + row_json, + status, + assigned_thread_id, + attempt_count, + result_json, + last_error, + created_at, + updated_at, + completed_at, + reported_at +FROM agent_job_items +WHERE job_id = + "#, + ); + builder.push_bind(job_id); + if let Some(status) = status { + builder.push(" AND status = "); + builder.push_bind(status.as_str()); + } + builder.push(" ORDER BY row_index ASC"); + if let Some(limit) = limit { + builder.push(" LIMIT "); + builder.push_bind(limit as i64); + } + let rows = builder + .build_query_as::() + .fetch_all(self.pool.as_ref()) + .await?; + rows.into_iter().map(AgentJobItem::try_from).collect() + } + + pub async fn get_agent_job_item( + &self, + job_id: &str, + item_id: &str, + ) -> anyhow::Result> { + let row = sqlx::query_as::<_, AgentJobItemRow>( + r#" +SELECT + job_id, + item_id, + row_index, + source_id, + row_json, + status, + assigned_thread_id, + attempt_count, + result_json, + last_error, + created_at, + updated_at, + completed_at, + reported_at +FROM agent_job_items +WHERE job_id = ? AND item_id = ? + "#, + ) + .bind(job_id) + .bind(item_id) + .fetch_optional(self.pool.as_ref()) + .await?; + row.map(AgentJobItem::try_from).transpose() + } + + pub async fn mark_agent_job_running(&self, job_id: &str) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +UPDATE agent_jobs +SET + status = ?, + updated_at = ?, + started_at = COALESCE(started_at, ?), + completed_at = NULL, + last_error = NULL +WHERE id = ? + "#, + ) + .bind(AgentJobStatus::Running.as_str()) + .bind(now) + .bind(now) + .bind(job_id) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +UPDATE agent_jobs +SET status = ?, updated_at = ?, completed_at = ?, last_error = NULL +WHERE id = ? + "#, + ) + .bind(AgentJobStatus::Completed.as_str()) + .bind(now) + .bind(now) + .bind(job_id) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn mark_agent_job_failed( + &self, + job_id: &str, + error_message: &str, + ) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +UPDATE agent_jobs +SET status = ?, updated_at = ?, completed_at = ?, last_error = ? +WHERE id = ? + "#, + ) + .bind(AgentJobStatus::Failed.as_str()) + .bind(now) + .bind(now) + .bind(error_message) + .bind(job_id) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn mark_agent_job_cancelled( + &self, + job_id: &str, + reason: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_jobs +SET status = ?, updated_at = ?, completed_at = ?, last_error = ? +WHERE id = ? AND status IN (?, ?) + "#, + ) + .bind(AgentJobStatus::Cancelled.as_str()) + .bind(now) + .bind(now) + .bind(reason) + .bind(job_id) + .bind(AgentJobStatus::Pending.as_str()) + .bind(AgentJobStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn is_agent_job_cancelled(&self, job_id: &str) -> anyhow::Result { + let row = sqlx::query( + r#" +SELECT status +FROM agent_jobs +WHERE id = ? + "#, + ) + .bind(job_id) + .fetch_optional(self.pool.as_ref()) + .await?; + let Some(row) = row else { + return Ok(false); + }; + let status: String = row.try_get("status")?; + Ok(AgentJobStatus::parse(status.as_str())? == AgentJobStatus::Cancelled) + } + + pub async fn mark_agent_job_item_running( + &self, + job_id: &str, + item_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + assigned_thread_id = NULL, + attempt_count = attempt_count + 1, + updated_at = ?, + last_error = NULL +WHERE job_id = ? AND item_id = ? AND status = ? + "#, + ) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Pending.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_running_with_thread( + &self, + job_id: &str, + item_id: &str, + thread_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + assigned_thread_id = ?, + attempt_count = attempt_count + 1, + updated_at = ?, + last_error = NULL +WHERE job_id = ? AND item_id = ? AND status = ? + "#, + ) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(thread_id) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Pending.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_pending( + &self, + job_id: &str, + item_id: &str, + error_message: Option<&str>, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + assigned_thread_id = NULL, + updated_at = ?, + last_error = ? +WHERE job_id = ? AND item_id = ? AND status = ? + "#, + ) + .bind(AgentJobItemStatus::Pending.as_str()) + .bind(now) + .bind(error_message) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn set_agent_job_item_thread( + &self, + job_id: &str, + item_id: &str, + thread_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET assigned_thread_id = ?, updated_at = ? +WHERE job_id = ? AND item_id = ? AND status = ? + "#, + ) + .bind(thread_id) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn report_agent_job_item_result( + &self, + job_id: &str, + item_id: &str, + reporting_thread_id: &str, + result_json: &Value, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let serialized = serde_json::to_string(result_json)?; + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + result_json = ?, + reported_at = ?, + updated_at = ?, + last_error = NULL +WHERE + job_id = ? + AND item_id = ? + AND status = ? + AND assigned_thread_id = ? + "#, + ) + .bind(serialized) + .bind(now) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(reporting_thread_id) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_completed( + &self, + job_id: &str, + item_id: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + completed_at = ?, + updated_at = ?, + assigned_thread_id = NULL +WHERE + job_id = ? + AND item_id = ? + AND status = ? + AND result_json IS NOT NULL + "#, + ) + .bind(AgentJobItemStatus::Completed.as_str()) + .bind(now) + .bind(now) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn mark_agent_job_item_failed( + &self, + job_id: &str, + item_id: &str, + error_message: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let result = sqlx::query( + r#" +UPDATE agent_job_items +SET + status = ?, + completed_at = ?, + updated_at = ?, + last_error = ?, + assigned_thread_id = NULL +WHERE + job_id = ? + AND item_id = ? + AND status = ? + "#, + ) + .bind(AgentJobItemStatus::Failed.as_str()) + .bind(now) + .bind(now) + .bind(error_message) + .bind(job_id) + .bind(item_id) + .bind(AgentJobItemStatus::Running.as_str()) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected() > 0) + } + + pub async fn get_agent_job_progress(&self, job_id: &str) -> anyhow::Result { + let row = sqlx::query( + r#" +SELECT + COUNT(*) AS total_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS pending_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS running_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS completed_items, + SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS failed_items +FROM agent_job_items +WHERE job_id = ? + "#, + ) + .bind(AgentJobItemStatus::Pending.as_str()) + .bind(AgentJobItemStatus::Running.as_str()) + .bind(AgentJobItemStatus::Completed.as_str()) + .bind(AgentJobItemStatus::Failed.as_str()) + .bind(job_id) + .fetch_one(self.pool.as_ref()) + .await?; + + let total_items: i64 = row.try_get("total_items")?; + let pending_items: Option = row.try_get("pending_items")?; + let running_items: Option = row.try_get("running_items")?; + let completed_items: Option = row.try_get("completed_items")?; + let failed_items: Option = row.try_get("failed_items")?; + Ok(AgentJobProgress { + total_items: usize::try_from(total_items).unwrap_or_default(), + pending_items: usize::try_from(pending_items.unwrap_or_default()).unwrap_or_default(), + running_items: usize::try_from(running_items.unwrap_or_default()).unwrap_or_default(), + completed_items: usize::try_from(completed_items.unwrap_or_default()) + .unwrap_or_default(), + failed_items: usize::try_from(failed_items.unwrap_or_default()).unwrap_or_default(), + }) + } + async fn ensure_backfill_state_row(&self) -> anyhow::Result<()> { sqlx::query( r#" diff --git a/docs/config.md b/docs/config.md index ed80e559c..30665bb11 100644 --- a/docs/config.md +++ b/docs/config.md @@ -28,6 +28,12 @@ Codex can run a notification hook when the agent finishes a turn. See the config The generated JSON Schema for `config.toml` lives at `codex-rs/core/config.schema.json`. +## SQLite State DB + +Codex stores the SQLite-backed state DB under `sqlite_home` (config key) or the +`CODEX_SQLITE_HOME` environment variable. When unset, WorkspaceWrite sandbox +sessions default to a temp directory; other modes default to `CODEX_HOME`. + ## Notices Codex stores "do not show again" flags for some UI prompts under the `[notice]` table.