feat: add debug clear-memories command to hard-wipe memories state (#13085)
#### What
Adds a `codex debug clear-memories` command that clears all memories state from disk and from the SQLite DB, and marks existing threads as `memory_mode=disabled` so they are not resummarized when the `memories` feature is re-enabled.
#### Tests
Adds tests.
This commit is contained in:
parent
8c1e3f3e64
commit
033ef9cb9d
5 changed files with 344 additions and 0 deletions
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
|
|
@ -1603,6 +1603,7 @@ dependencies = [
|
|||
"codex-protocol",
|
||||
"codex-responses-api-proxy",
|
||||
"codex-rmcp-client",
|
||||
"codex-state",
|
||||
"codex-stdio-to-uds",
|
||||
"codex-tui",
|
||||
"codex-utils-cargo-bin",
|
||||
|
|
@ -1614,6 +1615,7 @@ dependencies = [
|
|||
"pretty_assertions",
|
||||
"regex-lite",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
"supports-color 3.0.2",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ codex-mcp-server = { workspace = true }
|
|||
codex-protocol = { workspace = true }
|
||||
codex-responses-api-proxy = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-state = { workspace = true }
|
||||
codex-stdio-to-uds = { workspace = true }
|
||||
codex-tui = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
|
|
@ -62,3 +63,4 @@ assert_matches = { workspace = true }
|
|||
codex-utils-cargo-bin = { workspace = true }
|
||||
predicates = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ use codex_exec::Command as ExecCommand;
|
|||
use codex_exec::ReviewArgs;
|
||||
use codex_execpolicy::ExecPolicyCheckCommand;
|
||||
use codex_responses_api_proxy::Args as ResponsesApiProxyArgs;
|
||||
use codex_state::StateRuntime;
|
||||
use codex_state::state_db_path;
|
||||
use codex_tui::AppExitInfo;
|
||||
use codex_tui::Cli as TuiCli;
|
||||
use codex_tui::ExitReason;
|
||||
|
|
@ -163,6 +165,10 @@ struct DebugCommand {
|
|||
enum DebugSubcommand {
|
||||
/// Tooling: helps debug the app server.
|
||||
AppServer(DebugAppServerCommand),
|
||||
|
||||
/// Internal: reset local memory state for a fresh start.
|
||||
#[clap(hide = true)]
|
||||
ClearMemories,
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
|
|
@ -751,6 +757,9 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> {
|
|||
DebugSubcommand::AppServer(cmd) => {
|
||||
run_debug_app_server_command(cmd)?;
|
||||
}
|
||||
DebugSubcommand::ClearMemories => {
|
||||
run_debug_clear_memories_command(&root_config_overrides, &interactive).await?;
|
||||
}
|
||||
},
|
||||
Some(Subcommand::Execpolicy(ExecpolicyCommand { sub })) => match sub {
|
||||
ExecpolicySubcommand::Check(cmd) => run_execpolicycheck(cmd)?,
|
||||
|
|
@ -877,6 +886,60 @@ fn maybe_print_under_development_feature_warning(
|
|||
);
|
||||
}
|
||||
|
||||
async fn run_debug_clear_memories_command(
|
||||
root_config_overrides: &CliConfigOverrides,
|
||||
interactive: &TuiCli,
|
||||
) -> anyhow::Result<()> {
|
||||
let cli_kv_overrides = root_config_overrides
|
||||
.parse_overrides()
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let overrides = ConfigOverrides {
|
||||
config_profile: interactive.config_profile.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
let config =
|
||||
Config::load_with_cli_overrides_and_harness_overrides(cli_kv_overrides, overrides).await?;
|
||||
|
||||
let state_path = state_db_path(config.sqlite_home.as_path());
|
||||
let mut cleared_state_db = false;
|
||||
if tokio::fs::try_exists(&state_path).await? {
|
||||
let state_db = StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.model_provider_id.clone(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
state_db.reset_memory_data_for_fresh_start().await?;
|
||||
cleared_state_db = true;
|
||||
}
|
||||
|
||||
let memory_root = config.codex_home.join("memories");
|
||||
let removed_memory_root = match tokio::fs::remove_dir_all(&memory_root).await {
|
||||
Ok(()) => true,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => false,
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
|
||||
let mut message = if cleared_state_db {
|
||||
format!("Cleared memory state from {}.", state_path.display())
|
||||
} else {
|
||||
format!("No state db found at {}.", state_path.display())
|
||||
};
|
||||
|
||||
if removed_memory_root {
|
||||
message.push_str(&format!(" Removed {}.", memory_root.display()));
|
||||
} else {
|
||||
message.push_str(&format!(
|
||||
" No memory directory found at {}.",
|
||||
memory_root.display()
|
||||
));
|
||||
}
|
||||
|
||||
println!("{message}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prepend root-level overrides so they have lower precedence than
|
||||
/// CLI-specific ones specified after the subcommand (if any).
|
||||
fn prepend_config_flags(
|
||||
|
|
|
|||
141
codex-rs/cli/tests/debug_clear_memories.rs
Normal file
141
codex-rs/cli/tests/debug_clear_memories.rs
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_state::StateRuntime;
|
||||
use codex_state::state_db_path;
|
||||
use predicates::str::contains;
|
||||
use sqlx::SqlitePool;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn codex_command(codex_home: &Path) -> Result<assert_cmd::Command> {
|
||||
let mut cmd = assert_cmd::Command::new(codex_utils_cargo_bin::cargo_bin("codex")?);
|
||||
cmd.env("CODEX_HOME", codex_home);
|
||||
Ok(cmd)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn debug_clear_memories_resets_state_and_removes_memory_dir() -> Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let runtime = StateRuntime::init(
|
||||
codex_home.path().to_path_buf(),
|
||||
"test-provider".to_string(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
drop(runtime);
|
||||
|
||||
let thread_id = "00000000-0000-0000-0000-000000000123";
|
||||
let db_path = state_db_path(codex_home.path());
|
||||
let pool = SqlitePool::connect(&format!("sqlite://{}", db_path.display())).await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO threads (
|
||||
id,
|
||||
rollout_path,
|
||||
created_at,
|
||||
updated_at,
|
||||
source,
|
||||
agent_nickname,
|
||||
agent_role,
|
||||
model_provider,
|
||||
cwd,
|
||||
cli_version,
|
||||
title,
|
||||
sandbox_policy,
|
||||
approval_mode,
|
||||
tokens_used,
|
||||
first_user_message,
|
||||
archived,
|
||||
archived_at,
|
||||
git_sha,
|
||||
git_branch,
|
||||
git_origin_url,
|
||||
memory_mode
|
||||
) VALUES (?, ?, 1, 1, 'cli', NULL, NULL, 'test-provider', ?, '', '', 'read-only', 'on-request', 0, '', 0, NULL, NULL, NULL, NULL, 'enabled')
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.bind(codex_home.path().join("session.jsonl").display().to_string())
|
||||
.bind(codex_home.path().display().to_string())
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO stage1_outputs (
|
||||
thread_id,
|
||||
source_updated_at,
|
||||
raw_memory,
|
||||
rollout_summary,
|
||||
generated_at,
|
||||
rollout_slug,
|
||||
usage_count,
|
||||
last_usage,
|
||||
selected_for_phase2,
|
||||
selected_for_phase2_source_updated_at
|
||||
) VALUES (?, 1, 'raw', 'summary', 1, NULL, 0, NULL, 0, NULL)
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO jobs (
|
||||
kind,
|
||||
job_key,
|
||||
status,
|
||||
worker_id,
|
||||
ownership_token,
|
||||
started_at,
|
||||
finished_at,
|
||||
lease_until,
|
||||
retry_at,
|
||||
retry_remaining,
|
||||
last_error,
|
||||
input_watermark,
|
||||
last_success_watermark
|
||||
) VALUES
|
||||
('memory_stage1', ?, 'completed', NULL, NULL, NULL, NULL, NULL, NULL, 3, NULL, NULL, 1),
|
||||
('memory_consolidate_global', 'global', 'completed', NULL, NULL, NULL, NULL, NULL, NULL, 3, NULL, NULL, 1)
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
let memory_root = codex_home.path().join("memories");
|
||||
std::fs::create_dir_all(&memory_root)?;
|
||||
std::fs::write(memory_root.join("memory_summary.md"), "stale memory")?;
|
||||
drop(pool);
|
||||
|
||||
let mut cmd = codex_command(codex_home.path())?;
|
||||
cmd.args(["debug", "clear-memories"])
|
||||
.assert()
|
||||
.success()
|
||||
.stdout(contains("Cleared memory state"));
|
||||
|
||||
let pool = SqlitePool::connect(&format!("sqlite://{}", db_path.display())).await?;
|
||||
let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs")
|
||||
.fetch_one(&pool)
|
||||
.await?;
|
||||
assert_eq!(stage1_outputs_count, 0);
|
||||
|
||||
let memory_jobs_count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1' OR kind = 'memory_consolidate_global'",
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await?;
|
||||
assert_eq!(memory_jobs_count, 0);
|
||||
|
||||
let memory_mode: String = sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
|
||||
.bind(thread_id)
|
||||
.fetch_one(&pool)
|
||||
.await?;
|
||||
assert_eq!(memory_mode, "disabled");
|
||||
assert!(!memory_root.exists());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -30,6 +30,19 @@ impl StateRuntime {
|
|||
/// stage-1 (`memory_stage1`) and phase-2 (`memory_consolidate_global`)
|
||||
/// memory pipelines.
|
||||
pub async fn clear_memory_data(&self) -> anyhow::Result<()> {
    // Plain clear: existing threads keep whatever `memory_mode` they have.
    let disable_existing_threads = false;
    self.clear_memory_data_inner(disable_existing_threads).await
}
|
||||
|
||||
/// Resets persisted memory state for a clean-slate local start.
|
||||
///
|
||||
/// In addition to clearing persisted stage-1 outputs and memory pipeline
|
||||
/// jobs, this disables memory generation for all existing threads so
|
||||
/// historical rollouts are not immediately picked up again.
|
||||
pub async fn reset_memory_data_for_fresh_start(&self) -> anyhow::Result<()> {
|
||||
self.clear_memory_data_inner(true).await
|
||||
}
|
||||
|
||||
async fn clear_memory_data_inner(&self, disable_existing_threads: bool) -> anyhow::Result<()> {
|
||||
let mut tx = self.pool.begin().await?;
|
||||
|
||||
sqlx::query(
|
||||
|
|
@ -51,6 +64,18 @@ WHERE kind = ? OR kind = ?
|
|||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
if disable_existing_threads {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE threads
|
||||
SET memory_mode = 'disabled'
|
||||
WHERE memory_mode = 'enabled'
|
||||
"#,
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1158,6 +1183,8 @@ ON CONFLICT(kind, job_key) DO UPDATE SET
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL;
|
||||
use super::JOB_KIND_MEMORY_STAGE1;
|
||||
use super::StateRuntime;
|
||||
use super::test_support::test_thread_metadata;
|
||||
use super::test_support::unique_temp_dir;
|
||||
|
|
@ -1664,6 +1691,115 @@ mod tests {
|
|||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reset_memory_data_for_fresh_start_clears_rows_and_disables_threads() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
let now = Utc::now() - Duration::hours(13);
|
||||
let worker_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("worker id");
|
||||
let enabled_thread_id =
|
||||
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("enabled thread id");
|
||||
let disabled_thread_id =
|
||||
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("disabled thread id");
|
||||
|
||||
let mut enabled =
|
||||
test_thread_metadata(&codex_home, enabled_thread_id, codex_home.join("enabled"));
|
||||
enabled.created_at = now;
|
||||
enabled.updated_at = now;
|
||||
runtime
|
||||
.upsert_thread(&enabled)
|
||||
.await
|
||||
.expect("upsert enabled thread");
|
||||
|
||||
let claim = runtime
|
||||
.try_claim_stage1_job(
|
||||
enabled_thread_id,
|
||||
worker_id,
|
||||
enabled.updated_at.timestamp(),
|
||||
3600,
|
||||
64,
|
||||
)
|
||||
.await
|
||||
.expect("claim enabled thread");
|
||||
let ownership_token = match claim {
|
||||
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected claim outcome: {other:?}"),
|
||||
};
|
||||
assert!(
|
||||
runtime
|
||||
.mark_stage1_job_succeeded(
|
||||
enabled_thread_id,
|
||||
ownership_token.as_str(),
|
||||
enabled.updated_at.timestamp(),
|
||||
"raw",
|
||||
"summary",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("mark enabled thread succeeded"),
|
||||
"stage1 success should be recorded"
|
||||
);
|
||||
runtime
|
||||
.enqueue_global_consolidation(enabled.updated_at.timestamp())
|
||||
.await
|
||||
.expect("enqueue global consolidation");
|
||||
|
||||
let mut disabled =
|
||||
test_thread_metadata(&codex_home, disabled_thread_id, codex_home.join("disabled"));
|
||||
disabled.created_at = now;
|
||||
disabled.updated_at = now;
|
||||
runtime
|
||||
.upsert_thread(&disabled)
|
||||
.await
|
||||
.expect("upsert disabled thread");
|
||||
sqlx::query("UPDATE threads SET memory_mode = 'disabled' WHERE id = ?")
|
||||
.bind(disabled_thread_id.to_string())
|
||||
.execute(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("disable existing thread");
|
||||
|
||||
runtime
|
||||
.reset_memory_data_for_fresh_start()
|
||||
.await
|
||||
.expect("reset memory data");
|
||||
|
||||
let stage1_outputs_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs")
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("count stage1 outputs");
|
||||
assert_eq!(stage1_outputs_count, 0);
|
||||
|
||||
let memory_jobs_count: i64 =
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM jobs WHERE kind = ? OR kind = ?")
|
||||
.bind(JOB_KIND_MEMORY_STAGE1)
|
||||
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("count memory jobs");
|
||||
assert_eq!(memory_jobs_count, 0);
|
||||
|
||||
let enabled_memory_mode: String =
|
||||
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
|
||||
.bind(enabled_thread_id.to_string())
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("read enabled thread memory mode");
|
||||
assert_eq!(enabled_memory_mode, "disabled");
|
||||
|
||||
let disabled_memory_mode: String =
|
||||
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
|
||||
.bind(disabled_thread_id.to_string())
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("read disabled thread memory mode");
|
||||
assert_eq!(disabled_memory_mode, "disabled");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn claim_stage1_jobs_enforces_global_running_cap() {
|
||||
let codex_home = unique_temp_dir();
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue