diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 8398d0b99..2ab1d5943 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1603,6 +1603,7 @@ dependencies = [ "codex-protocol", "codex-responses-api-proxy", "codex-rmcp-client", + "codex-state", "codex-stdio-to-uds", "codex-tui", "codex-utils-cargo-bin", @@ -1614,6 +1615,7 @@ dependencies = [ "pretty_assertions", "regex-lite", "serde_json", + "sqlx", "supports-color 3.0.2", "tempfile", "tokio", diff --git a/codex-rs/cli/Cargo.toml b/codex-rs/cli/Cargo.toml index 1d955bbaa..b197cc3d5 100644 --- a/codex-rs/cli/Cargo.toml +++ b/codex-rs/cli/Cargo.toml @@ -35,6 +35,7 @@ codex-mcp-server = { workspace = true } codex-protocol = { workspace = true } codex-responses-api-proxy = { workspace = true } codex-rmcp-client = { workspace = true } +codex-state = { workspace = true } codex-stdio-to-uds = { workspace = true } codex-tui = { workspace = true } libc = { workspace = true } @@ -62,3 +63,4 @@ assert_matches = { workspace = true } codex-utils-cargo-bin = { workspace = true } predicates = { workspace = true } pretty_assertions = { workspace = true } +sqlx = { workspace = true } diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index c271fe0b9..f818b743e 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -22,6 +22,8 @@ use codex_exec::Command as ExecCommand; use codex_exec::ReviewArgs; use codex_execpolicy::ExecPolicyCheckCommand; use codex_responses_api_proxy::Args as ResponsesApiProxyArgs; +use codex_state::StateRuntime; +use codex_state::state_db_path; use codex_tui::AppExitInfo; use codex_tui::Cli as TuiCli; use codex_tui::ExitReason; @@ -163,6 +165,10 @@ struct DebugCommand { enum DebugSubcommand { /// Tooling: helps debug the app server. AppServer(DebugAppServerCommand), + + /// Internal: reset local memory state for a fresh start. 
+ #[clap(hide = true)] + ClearMemories, } #[derive(Debug, Parser)] @@ -751,6 +757,9 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> { DebugSubcommand::AppServer(cmd) => { run_debug_app_server_command(cmd)?; } + DebugSubcommand::ClearMemories => { + run_debug_clear_memories_command(&root_config_overrides, &interactive).await?; + } }, Some(Subcommand::Execpolicy(ExecpolicyCommand { sub })) => match sub { ExecpolicySubcommand::Check(cmd) => run_execpolicycheck(cmd)?, @@ -877,6 +886,60 @@ fn maybe_print_under_development_feature_warning( ); } +async fn run_debug_clear_memories_command( + root_config_overrides: &CliConfigOverrides, + interactive: &TuiCli, +) -> anyhow::Result<()> { + let cli_kv_overrides = root_config_overrides + .parse_overrides() + .map_err(anyhow::Error::msg)?; + let overrides = ConfigOverrides { + config_profile: interactive.config_profile.clone(), + ..Default::default() + }; + let config = + Config::load_with_cli_overrides_and_harness_overrides(cli_kv_overrides, overrides).await?; + + let state_path = state_db_path(config.sqlite_home.as_path()); + let mut cleared_state_db = false; + if tokio::fs::try_exists(&state_path).await? 
{ + let state_db = StateRuntime::init( + config.sqlite_home.clone(), + config.model_provider_id.clone(), + None, + ) + .await?; + state_db.reset_memory_data_for_fresh_start().await?; + cleared_state_db = true; + } + + let memory_root = config.codex_home.join("memories"); + let removed_memory_root = match tokio::fs::remove_dir_all(&memory_root).await { + Ok(()) => true, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => false, + Err(err) => return Err(err.into()), + }; + + let mut message = if cleared_state_db { + format!("Cleared memory state from {}.", state_path.display()) + } else { + format!("No state db found at {}.", state_path.display()) + }; + + if removed_memory_root { + message.push_str(&format!(" Removed {}.", memory_root.display())); + } else { + message.push_str(&format!( + " No memory directory found at {}.", + memory_root.display() + )); + } + + println!("{message}"); + + Ok(()) +} + /// Prepend root-level overrides so they have lower precedence than /// CLI-specific ones specified after the subcommand (if any). 
fn prepend_config_flags( diff --git a/codex-rs/cli/tests/debug_clear_memories.rs b/codex-rs/cli/tests/debug_clear_memories.rs new file mode 100644 index 000000000..d8db1ebbc --- /dev/null +++ b/codex-rs/cli/tests/debug_clear_memories.rs @@ -0,0 +1,141 @@ +use std::path::Path; + +use anyhow::Result; +use codex_state::StateRuntime; +use codex_state::state_db_path; +use predicates::str::contains; +use sqlx::SqlitePool; +use tempfile::TempDir; + +fn codex_command(codex_home: &Path) -> Result<assert_cmd::Command> { + let mut cmd = assert_cmd::Command::new(codex_utils_cargo_bin::cargo_bin("codex")?); + cmd.env("CODEX_HOME", codex_home); + Ok(cmd) +} + +#[tokio::test] +async fn debug_clear_memories_resets_state_and_removes_memory_dir() -> Result<()> { + let codex_home = TempDir::new()?; + let runtime = StateRuntime::init( + codex_home.path().to_path_buf(), + "test-provider".to_string(), + None, + ) + .await?; + drop(runtime); + + let thread_id = "00000000-0000-0000-0000-000000000123"; + let db_path = state_db_path(codex_home.path()); + let pool = SqlitePool::connect(&format!("sqlite://{}", db_path.display())).await?; + + sqlx::query( + r#" +INSERT INTO threads ( + id, + rollout_path, + created_at, + updated_at, + source, + agent_nickname, + agent_role, + model_provider, + cwd, + cli_version, + title, + sandbox_policy, + approval_mode, + tokens_used, + first_user_message, + archived, + archived_at, + git_sha, + git_branch, + git_origin_url, + memory_mode +) VALUES (?, ?, 1, 1, 'cli', NULL, NULL, 'test-provider', ?, '', '', 'read-only', 'on-request', 0, '', 0, NULL, NULL, NULL, NULL, 'enabled') + "#, + ) + .bind(thread_id) + .bind(codex_home.path().join("session.jsonl").display().to_string()) + .bind(codex_home.path().display().to_string()) + .execute(&pool) + .await?; + + sqlx::query( + r#" +INSERT INTO stage1_outputs ( + thread_id, + source_updated_at, + raw_memory, + rollout_summary, + generated_at, + rollout_slug, + usage_count, + last_usage, + selected_for_phase2, + 
selected_for_phase2_source_updated_at +) VALUES (?, 1, 'raw', 'summary', 1, NULL, 0, NULL, 0, NULL) + "#, + ) + .bind(thread_id) + .execute(&pool) + .await?; + + sqlx::query( + r#" +INSERT INTO jobs ( + kind, + job_key, + status, + worker_id, + ownership_token, + started_at, + finished_at, + lease_until, + retry_at, + retry_remaining, + last_error, + input_watermark, + last_success_watermark +) VALUES + ('memory_stage1', ?, 'completed', NULL, NULL, NULL, NULL, NULL, NULL, 3, NULL, NULL, 1), + ('memory_consolidate_global', 'global', 'completed', NULL, NULL, NULL, NULL, NULL, NULL, 3, NULL, NULL, 1) + "#, + ) + .bind(thread_id) + .execute(&pool) + .await?; + + let memory_root = codex_home.path().join("memories"); + std::fs::create_dir_all(&memory_root)?; + std::fs::write(memory_root.join("memory_summary.md"), "stale memory")?; + drop(pool); + + let mut cmd = codex_command(codex_home.path())?; + cmd.args(["debug", "clear-memories"]) + .assert() + .success() + .stdout(contains("Cleared memory state")); + + let pool = SqlitePool::connect(&format!("sqlite://{}", db_path.display())).await?; + let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs") + .fetch_one(&pool) + .await?; + assert_eq!(stage1_outputs_count, 0); + + let memory_jobs_count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1' OR kind = 'memory_consolidate_global'", + ) + .fetch_one(&pool) + .await?; + assert_eq!(memory_jobs_count, 0); + + let memory_mode: String = sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?") + .bind(thread_id) + .fetch_one(&pool) + .await?; + assert_eq!(memory_mode, "disabled"); + assert!(!memory_root.exists()); + + Ok(()) +} diff --git a/codex-rs/state/src/runtime/memories.rs b/codex-rs/state/src/runtime/memories.rs index 908919d0a..a36f9e5aa 100644 --- a/codex-rs/state/src/runtime/memories.rs +++ b/codex-rs/state/src/runtime/memories.rs @@ -30,6 +30,19 @@ impl StateRuntime { /// stage-1 
(`memory_stage1`) and phase-2 (`memory_consolidate_global`) /// memory pipelines. pub async fn clear_memory_data(&self) -> anyhow::Result<()> { + self.clear_memory_data_inner(false).await + } + + /// Resets persisted memory state for a clean-slate local start. + /// + /// In addition to clearing persisted stage-1 outputs and memory pipeline + /// jobs, this disables memory generation for all existing threads so + /// historical rollouts are not immediately picked up again. + pub async fn reset_memory_data_for_fresh_start(&self) -> anyhow::Result<()> { + self.clear_memory_data_inner(true).await + } + + async fn clear_memory_data_inner(&self, disable_existing_threads: bool) -> anyhow::Result<()> { let mut tx = self.pool.begin().await?; sqlx::query( @@ -51,6 +64,18 @@ WHERE kind = ? OR kind = ? .execute(&mut *tx) .await?; + if disable_existing_threads { + sqlx::query( + r#" +UPDATE threads +SET memory_mode = 'disabled' +WHERE memory_mode = 'enabled' + "#, + ) + .execute(&mut *tx) + .await?; + } + tx.commit().await?; Ok(()) } @@ -1158,6 +1183,8 @@ ON CONFLICT(kind, job_key) DO UPDATE SET #[cfg(test)] mod tests { + use super::JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL; + use super::JOB_KIND_MEMORY_STAGE1; use super::StateRuntime; use super::test_support::test_thread_metadata; use super::test_support::unique_temp_dir; @@ -1664,6 +1691,115 @@ mod tests { let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn reset_memory_data_for_fresh_start_clears_rows_and_disables_threads() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let now = Utc::now() - Duration::hours(13); + let worker_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("worker id"); + let enabled_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("enabled thread id"); + let disabled_thread_id = + 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("disabled thread id"); + + let mut enabled = + test_thread_metadata(&codex_home, enabled_thread_id, codex_home.join("enabled")); + enabled.created_at = now; + enabled.updated_at = now; + runtime + .upsert_thread(&enabled) + .await + .expect("upsert enabled thread"); + + let claim = runtime + .try_claim_stage1_job( + enabled_thread_id, + worker_id, + enabled.updated_at.timestamp(), + 3600, + 64, + ) + .await + .expect("claim enabled thread"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + enabled_thread_id, + ownership_token.as_str(), + enabled.updated_at.timestamp(), + "raw", + "summary", + None, + ) + .await + .expect("mark enabled thread succeeded"), + "stage1 success should be recorded" + ); + runtime + .enqueue_global_consolidation(enabled.updated_at.timestamp()) + .await + .expect("enqueue global consolidation"); + + let mut disabled = + test_thread_metadata(&codex_home, disabled_thread_id, codex_home.join("disabled")); + disabled.created_at = now; + disabled.updated_at = now; + runtime + .upsert_thread(&disabled) + .await + .expect("upsert disabled thread"); + sqlx::query("UPDATE threads SET memory_mode = 'disabled' WHERE id = ?") + .bind(disabled_thread_id.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("disable existing thread"); + + runtime + .reset_memory_data_for_fresh_start() + .await + .expect("reset memory data"); + + let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs") + .fetch_one(runtime.pool.as_ref()) + .await + .expect("count stage1 outputs"); + assert_eq!(stage1_outputs_count, 0); + + let memory_jobs_count: i64 = + sqlx::query_scalar("SELECT COUNT(*) FROM jobs WHERE kind = ? 
OR kind = ?") + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("count memory jobs"); + assert_eq!(memory_jobs_count, 0); + + let enabled_memory_mode: String = + sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?") + .bind(enabled_thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("read enabled thread memory mode"); + assert_eq!(enabled_memory_mode, "disabled"); + + let disabled_memory_mode: String = + sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?") + .bind(disabled_thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("read disabled thread memory mode"); + assert_eq!(disabled_memory_mode, "disabled"); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn claim_stage1_jobs_enforces_global_running_cap() { let codex_home = unique_temp_dir();