From 1d5eba0090d61444029fea1b44eae76238dcdd67 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 13:42:09 +0000 Subject: [PATCH] feat: align memory phase 1 and make it stronger (#11300) ## Align with the new phase-1 design Basically we now run phase 1 in parallel by considering: * Max 64 rollouts * Max 1 month old * Consider the most recent first This PR also adds stronger parallelization capabilities by detecting stale jobs, adding retry policies, and tracking ownership of computation to prevent double computations, etc. --- codex-rs/core/src/codex/memory_startup.rs | 667 ++++++++++++--- codex-rs/core/src/memories/mod.rs | 41 +- codex-rs/core/src/memories/selection.rs | 21 +- codex-rs/core/src/memories/storage.rs | 120 +-- codex-rs/core/src/memories/tests.rs | 87 +- codex-rs/core/src/memories/types.rs | 2 - .../migrations/0010_memory_workflow_v2.sql | 85 ++ codex-rs/state/src/lib.rs | 1 + codex-rs/state/src/model/thread_memory.rs | 30 + codex-rs/state/src/runtime.rs | 758 +++++++++++++++++- 10 files changed, 1553 insertions(+), 259 deletions(-) create mode 100644 codex-rs/state/migrations/0010_memory_workflow_v2.sql diff --git a/codex-rs/core/src/codex/memory_startup.rs b/codex-rs/core/src/codex/memory_startup.rs index 4fdfffce4..dda791857 100644 --- a/codex-rs/core/src/codex/memory_startup.rs +++ b/codex-rs/core/src/codex/memory_startup.rs @@ -1,6 +1,27 @@ use super::*; +use chrono::DateTime; +use chrono::Utc; +use sha2::Digest; +use sha2::Sha256; +use std::time::Duration; const MEMORY_STARTUP_STAGE: &str = "run_memories_startup_pipeline"; +const PHASE_ONE_THREAD_SCAN_LIMIT: usize = 5_000; +const PHASE_ONE_DB_LOCK_RETRY_LIMIT: usize = 3; +const PHASE_ONE_DB_LOCK_RETRY_BACKOFF_MS: u64 = 25; + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +struct MemoryScopeTarget { + scope_kind: &'static str, + scope_key: String, + memory_root: PathBuf, +} + +#[derive(Clone, Debug)] +struct ClaimedPhaseOneCandidate { + candidate: memories::RolloutCandidate, + claimed_scopes: 
Vec<(MemoryScopeTarget, String)>, +} #[derive(Clone)] struct StageOneRequestContext { @@ -55,7 +76,7 @@ pub(super) async fn run_memories_startup_pipeline( let Some(page) = state_db::list_threads_db( session.services.state_db.as_deref(), &config.codex_home, - 200, + PHASE_ONE_THREAD_SCAN_LIMIT, None, ThreadSortKey::UpdatedAt, INTERACTIVE_SESSION_SOURCES, @@ -68,32 +89,26 @@ pub(super) async fn run_memories_startup_pipeline( return Ok(()); }; - let mut existing_memories = Vec::new(); - for item in &page.items { - if let Some(memory) = state_db::get_thread_memory( - session.services.state_db.as_deref(), - item.id, - MEMORY_STARTUP_STAGE, - ) - .await - { - existing_memories.push(memory); - } - } - - let candidates = memories::select_rollout_candidates_from_db( + let selection_candidates = memories::select_rollout_candidates_from_db( &page.items, session.conversation_id, - &existing_memories, - memories::MAX_ROLLOUTS_PER_STARTUP, + PHASE_ONE_THREAD_SCAN_LIMIT, + memories::PHASE_ONE_MAX_ROLLOUT_AGE_DAYS, ); + let claimed_candidates = claim_phase_one_candidates( + session, + config.as_ref(), + selection_candidates, + memories::MAX_ROLLOUTS_PER_STARTUP, + ) + .await; info!( - "memory phase-1 candidate selection complete: {} candidate(s) from {} indexed thread(s)", - candidates.len(), + "memory phase-1 candidate selection complete: {} claimed candidate(s) from {} indexed thread(s)", + claimed_candidates.len(), page.items.len() ); - if candidates.is_empty() { + if claimed_candidates.is_empty() { return Ok(()); } @@ -102,62 +117,173 @@ pub(super) async fn run_memories_startup_pipeline( turn_context.resolve_turn_metadata_header().await, ); - let touched_cwds = - futures::stream::iter(candidates.into_iter()) - .map(|candidate| { + let touched_scope_sets = + futures::stream::iter(claimed_candidates.into_iter()) + .map(|claimed_candidate| { let session = Arc::clone(session); - let config = Arc::clone(&config); let stage_one_context = stage_one_context.clone(); async move { - 
process_memory_candidate(session, config, candidate, stage_one_context).await + process_memory_candidate(session, claimed_candidate, stage_one_context).await } }) .buffer_unordered(memories::PHASE_ONE_CONCURRENCY_LIMIT) - .filter_map(futures::future::ready) - .collect::>() + .collect::>>() .await; + let touched_scopes = touched_scope_sets + .into_iter() + .flatten() + .collect::>(); info!( - "memory phase-1 extraction complete: {} cwd(s) touched", - touched_cwds.len() + "memory phase-1 extraction complete: {} scope(s) touched", + touched_scopes.len() ); - if touched_cwds.is_empty() { + if touched_scopes.is_empty() { return Ok(()); } - let consolidation_cwd_count = touched_cwds.len(); - futures::stream::iter(touched_cwds.into_iter()) - .map(|cwd| { + let consolidation_scope_count = touched_scopes.len(); + futures::stream::iter(touched_scopes.into_iter()) + .map(|scope| { let session = Arc::clone(session); let config = Arc::clone(&config); async move { - run_memory_consolidation_for_cwd(session, config, cwd).await; + run_memory_consolidation_for_scope(session, config, scope).await; } }) .buffer_unordered(memories::PHASE_ONE_CONCURRENCY_LIMIT) .collect::>() .await; info!( - "memory phase-2 consolidation dispatch complete: {} cwd(s) scheduled", - consolidation_cwd_count + "memory phase-2 consolidation dispatch complete: {} scope(s) scheduled", + consolidation_scope_count ); Ok(()) } +async fn claim_phase_one_candidates( + session: &Session, + config: &Config, + candidates: Vec, + max_claimed_candidates: usize, +) -> Vec { + if max_claimed_candidates == 0 { + return Vec::new(); + } + + let Some(state_db) = session.services.state_db.as_deref() else { + return Vec::new(); + }; + + let mut claimed_candidates = Vec::new(); + for candidate in candidates { + if claimed_candidates.len() >= max_claimed_candidates { + break; + } + + let source_updated_at = parse_source_updated_at_epoch(&candidate); + let mut claimed_scopes = Vec::<(MemoryScopeTarget, String)>::new(); + for scope 
in memory_scope_targets_for_candidate(config, &candidate) { + let Some(claim) = try_claim_phase1_job_with_retry( + state_db, + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + session.conversation_id, + source_updated_at, + ) + .await + else { + continue; + }; + + if let codex_state::Phase1JobClaimOutcome::Claimed { ownership_token } = claim { + claimed_scopes.push((scope, ownership_token)); + } + } + + if !claimed_scopes.is_empty() { + claimed_candidates.push(ClaimedPhaseOneCandidate { + candidate, + claimed_scopes, + }); + } + } + + claimed_candidates +} + +async fn try_claim_phase1_job_with_retry( + state_db: &codex_state::StateRuntime, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + owner_session_id: ThreadId, + source_updated_at: i64, +) -> Option { + for attempt in 0..=PHASE_ONE_DB_LOCK_RETRY_LIMIT { + match state_db + .try_claim_phase1_job( + thread_id, + scope_kind, + scope_key, + owner_session_id, + source_updated_at, + memories::PHASE_ONE_JOB_LEASE_SECONDS, + ) + .await + { + Ok(claim) => return Some(claim), + Err(err) => { + let is_locked = err.to_string().contains("database is locked"); + if is_locked && attempt < PHASE_ONE_DB_LOCK_RETRY_LIMIT { + tokio::time::sleep(Duration::from_millis( + PHASE_ONE_DB_LOCK_RETRY_BACKOFF_MS * (attempt as u64 + 1), + )) + .await; + continue; + } + warn!("state db try_claim_phase1_job failed during {MEMORY_STARTUP_STAGE}: {err}"); + return None; + } + } + } + None +} + async fn process_memory_candidate( session: Arc, - config: Arc, - candidate: memories::RolloutCandidate, + claimed_candidate: ClaimedPhaseOneCandidate, stage_one_context: StageOneRequestContext, -) -> Option { - let memory_root = memories::memory_root_for_cwd(&config.codex_home, &candidate.cwd); - if let Err(err) = memories::ensure_layout(&memory_root).await { - warn!( - "failed to create memory layout for cwd {}: {err}", - candidate.cwd.display() - ); - return None; +) -> HashSet { + let candidate = 
claimed_candidate.candidate; + let claimed_scopes = claimed_candidate.claimed_scopes; + + let mut ready_scopes = Vec::<(MemoryScopeTarget, String)>::new(); + for (scope, ownership_token) in claimed_scopes { + if let Err(err) = memories::ensure_layout(&scope.memory_root).await { + warn!( + "failed to create memory layout for scope {}:{} root={}: {err}", + scope.scope_kind, + scope.scope_key, + scope.memory_root.display() + ); + mark_phase1_job_failed_best_effort( + session.as_ref(), + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + &ownership_token, + "failed to create memory layout", + ) + .await; + continue; + } + ready_scopes.push((scope, ownership_token)); + } + if ready_scopes.is_empty() { + return HashSet::new(); } let (rollout_items, _thread_id, parse_errors) = @@ -168,7 +294,14 @@ async fn process_memory_candidate( "failed to load rollout {} for memories: {err}", candidate.rollout_path.display() ); - return None; + fail_claimed_phase_one_jobs( + &session, + &candidate, + &ready_scopes, + "failed to load rollout", + ) + .await; + return HashSet::new(); } }; if parse_errors > 0 { @@ -188,7 +321,14 @@ async fn process_memory_candidate( "failed to prepare filtered rollout payload {} for memories: {err}", candidate.rollout_path.display() ); - return None; + fail_claimed_phase_one_jobs( + &session, + &candidate, + &ready_scopes, + "failed to serialize filtered rollout", + ) + .await; + return HashSet::new(); } }; @@ -232,7 +372,14 @@ async fn process_memory_candidate( "stage-1 memory request failed for rollout {}: {err}", candidate.rollout_path.display() ); - return None; + fail_claimed_phase_one_jobs( + &session, + &candidate, + &ready_scopes, + "stage-1 memory request failed", + ) + .await; + return HashSet::new(); } }; @@ -243,7 +390,14 @@ async fn process_memory_candidate( "failed while waiting for stage-1 memory response for rollout {}: {err}", candidate.rollout_path.display() ); - return None; + fail_claimed_phase_one_jobs( + &session, + 
&candidate, + &ready_scopes, + "stage-1 memory response stream failed", + ) + .await; + return HashSet::new(); } }; @@ -254,68 +408,288 @@ async fn process_memory_candidate( "invalid stage-1 memory payload for rollout {}: {err}", candidate.rollout_path.display() ); - return None; + fail_claimed_phase_one_jobs( + &session, + &candidate, + &ready_scopes, + "invalid stage-1 memory payload", + ) + .await; + return HashSet::new(); } }; - let raw_memory_path = - match memories::write_raw_memory(&memory_root, &candidate, &stage_one_output.raw_memory) - .await + let mut touched_scopes = HashSet::new(); + for (scope, ownership_token) in &ready_scopes { + if persist_phase_one_memory_for_scope( + &session, + &candidate, + scope, + ownership_token, + &stage_one_output.raw_memory, + &stage_one_output.summary, + ) + .await { - Ok(path) => path, - Err(err) => { - warn!( - "failed to write raw memory for rollout {}: {err}", - candidate.rollout_path.display() - ); - return None; - } - }; - - if state_db::upsert_thread_memory( - session.services.state_db.as_deref(), - candidate.thread_id, - &stage_one_output.raw_memory, - &stage_one_output.summary, - MEMORY_STARTUP_STAGE, - ) - .await - .is_none() - { - warn!( - "failed to upsert thread memory for rollout {}; removing {}", - candidate.rollout_path.display(), - raw_memory_path.display() - ); - if let Err(err) = tokio::fs::remove_file(&raw_memory_path).await - && err.kind() != std::io::ErrorKind::NotFound - { - warn!( - "failed to remove orphaned raw memory {}: {err}", - raw_memory_path.display() - ); + touched_scopes.insert(scope.clone()); } - return None; } - info!( - "memory phase-1 raw memory persisted: rollout={} cwd={} raw_memory_path={}", - candidate.rollout_path.display(), - candidate.cwd.display(), - raw_memory_path.display() - ); - Some(candidate.cwd) + touched_scopes } -async fn run_memory_consolidation_for_cwd( +fn parse_source_updated_at_epoch(candidate: &memories::RolloutCandidate) -> i64 { + candidate + .updated_at + 
.as_deref() + .and_then(|value| DateTime::parse_from_rfc3339(value).ok()) + .map(|value| value.with_timezone(&Utc).timestamp()) + .unwrap_or_else(|| Utc::now().timestamp()) +} + +fn memory_scope_targets_for_candidate( + config: &Config, + candidate: &memories::RolloutCandidate, +) -> Vec { + vec![ + MemoryScopeTarget { + scope_kind: memories::MEMORY_SCOPE_KIND_CWD, + scope_key: memories::memory_scope_key_for_cwd(&candidate.cwd), + memory_root: memories::memory_root_for_cwd(&config.codex_home, &candidate.cwd), + }, + MemoryScopeTarget { + scope_kind: memories::MEMORY_SCOPE_KIND_USER, + scope_key: memories::MEMORY_SCOPE_KEY_USER.to_string(), + memory_root: memories::memory_root_for_user(&config.codex_home), + }, + ] +} + +async fn fail_claimed_phase_one_jobs( + session: &Session, + candidate: &memories::RolloutCandidate, + claimed_scopes: &[(MemoryScopeTarget, String)], + reason: &str, +) { + for (scope, ownership_token) in claimed_scopes { + mark_phase1_job_failed_best_effort( + session, + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + reason, + ) + .await; + } +} + +async fn persist_phase_one_memory_for_scope( + session: &Session, + candidate: &memories::RolloutCandidate, + scope: &MemoryScopeTarget, + ownership_token: &str, + raw_memory: &str, + summary: &str, +) -> bool { + let Some(state_db) = session.services.state_db.as_deref() else { + mark_phase1_job_failed_best_effort( + session, + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + "state db unavailable for scoped thread memory upsert", + ) + .await; + return false; + }; + + let lease_renewed = match state_db + .renew_phase1_job_lease( + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + ) + .await + { + Ok(renewed) => renewed, + Err(err) => { + warn!("state db renew_phase1_job_lease failed during {MEMORY_STARTUP_STAGE}: {err}"); + return false; + } + }; + if !lease_renewed { + debug!( + "memory phase-1 write 
skipped after ownership changed: rollout={} scope={} scope_key={}", + candidate.rollout_path.display(), + scope.scope_kind, + scope.scope_key + ); + return false; + } + + let upserted = match state_db + .upsert_thread_memory_for_scope_if_phase1_owner( + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + raw_memory, + summary, + ) + .await + { + Ok(upserted) => upserted, + Err(err) => { + warn!( + "state db upsert_thread_memory_for_scope_if_phase1_owner failed during {MEMORY_STARTUP_STAGE}: {err}" + ); + mark_phase1_job_failed_best_effort( + session, + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + "failed to upsert scoped thread memory", + ) + .await; + return false; + } + }; + if upserted.is_none() { + debug!( + "memory phase-1 db upsert skipped after ownership changed: rollout={} scope={} scope_key={}", + candidate.rollout_path.display(), + scope.scope_kind, + scope.scope_key + ); + return false; + } + + let latest_memories = match state_db + .get_last_n_thread_memories_for_scope( + scope.scope_kind, + &scope.scope_key, + memories::MAX_RAW_MEMORIES_PER_SCOPE, + ) + .await + { + Ok(memories) => memories, + Err(err) => { + warn!( + "state db get_last_n_thread_memories_for_scope failed during {MEMORY_STARTUP_STAGE}: {err}" + ); + mark_phase1_job_failed_best_effort( + session, + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + "failed to read scope memories after upsert", + ) + .await; + return false; + } + }; + + if let Err(err) = + memories::sync_raw_memories_from_memories(&scope.memory_root, &latest_memories).await + { + warn!( + "failed syncing raw memories for scope {}:{} root={}: {err}", + scope.scope_kind, + scope.scope_key, + scope.memory_root.display() + ); + mark_phase1_job_failed_best_effort( + session, + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + "failed to sync scope raw memories", + ) + .await; + return false; + } + + 
if let Err(err) = + memories::rebuild_memory_summary_from_memories(&scope.memory_root, &latest_memories).await + { + warn!( + "failed rebuilding memory_summary for scope {}:{} root={}: {err}", + scope.scope_kind, + scope.scope_key, + scope.memory_root.display() + ); + mark_phase1_job_failed_best_effort( + session, + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + "failed to rebuild scope memory summary", + ) + .await; + return false; + } + + let mut hasher = Sha256::new(); + hasher.update(summary.as_bytes()); + let summary_hash = format!("{:x}", hasher.finalize()); + let raw_memory_path = scope + .memory_root + .join("raw_memories") + .join(format!("{}.md", candidate.thread_id)); + let marked_succeeded = match state_db + .mark_phase1_job_succeeded( + candidate.thread_id, + scope.scope_kind, + &scope.scope_key, + ownership_token, + &raw_memory_path.display().to_string(), + &summary_hash, + ) + .await + { + Ok(marked) => marked, + Err(err) => { + warn!("state db mark_phase1_job_succeeded failed during {MEMORY_STARTUP_STAGE}: {err}"); + return false; + } + }; + if !marked_succeeded { + return false; + } + + if let Err(err) = state_db + .mark_memory_scope_dirty(scope.scope_kind, &scope.scope_key, true) + .await + { + warn!("state db mark_memory_scope_dirty failed during {MEMORY_STARTUP_STAGE}: {err}"); + } + + info!( + "memory phase-1 raw memory persisted: rollout={} scope={} scope_key={} raw_memory_path={}", + candidate.rollout_path.display(), + scope.scope_kind, + scope.scope_key, + raw_memory_path.display() + ); + true +} + +async fn run_memory_consolidation_for_scope( session: Arc, config: Arc, - cwd: PathBuf, + scope: MemoryScopeTarget, ) { let lock_owner = session.conversation_id; let Some(lock_acquired) = state_db::try_acquire_memory_consolidation_lock( session.services.state_db.as_deref(), - &cwd, + &scope.memory_root, lock_owner, memories::CONSOLIDATION_LOCK_LEASE_SECONDS, MEMORY_STARTUP_STAGE, @@ -323,34 +697,27 @@ async fn 
run_memory_consolidation_for_cwd( .await else { warn!( - "failed to acquire memory consolidation lock for cwd {}; skipping consolidation", - cwd.display() + "failed to acquire memory consolidation lock for scope {}:{}; skipping consolidation", + scope.scope_kind, scope.scope_key ); return; }; if !lock_acquired { debug!( - "memory consolidation lock already held for cwd {}; skipping", - cwd.display() + "memory consolidation lock already held for scope {}:{}; skipping", + scope.scope_kind, scope.scope_key ); return; } - let Some(latest_memories) = state_db::get_last_n_thread_memories_for_cwd( - session.services.state_db.as_deref(), - &cwd, - memories::MAX_RAW_MEMORIES_PER_CWD, - MEMORY_STARTUP_STAGE, - ) - .await - else { + let Some(state_db) = session.services.state_db.as_deref() else { warn!( - "failed to read recent thread memories for cwd {}; skipping consolidation", - cwd.display() + "state db unavailable for scope {}:{}; skipping consolidation", + scope.scope_kind, scope.scope_key ); let _ = state_db::release_memory_consolidation_lock( session.services.state_db.as_deref(), - &cwd, + &scope.memory_root, lock_owner, MEMORY_STARTUP_STAGE, ) @@ -358,17 +725,41 @@ async fn run_memory_consolidation_for_cwd( return; }; - let memory_root = memories::memory_root_for_cwd(&config.codex_home, &cwd); + let latest_memories = match state_db + .get_last_n_thread_memories_for_scope( + scope.scope_kind, + &scope.scope_key, + memories::MAX_RAW_MEMORIES_PER_SCOPE, + ) + .await + { + Ok(memories) => memories, + Err(err) => { + warn!( + "state db get_last_n_thread_memories_for_scope failed during {MEMORY_STARTUP_STAGE}: {err}" + ); + let _ = state_db::release_memory_consolidation_lock( + session.services.state_db.as_deref(), + &scope.memory_root, + lock_owner, + MEMORY_STARTUP_STAGE, + ) + .await; + return; + } + }; + + let memory_root = scope.memory_root.clone(); if let Err(err) = memories::prune_to_recent_memories_and_rebuild_summary(&memory_root, &latest_memories).await { warn!( 
- "failed to refresh phase-1 memory outputs for cwd {}: {err}", - cwd.display() + "failed to refresh phase-1 memory outputs for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key ); let _ = state_db::release_memory_consolidation_lock( session.services.state_db.as_deref(), - &cwd, + &scope.memory_root, lock_owner, MEMORY_STARTUP_STAGE, ) @@ -378,12 +769,12 @@ async fn run_memory_consolidation_for_cwd( if let Err(err) = memories::wipe_consolidation_outputs(&memory_root).await { warn!( - "failed to wipe previous consolidation outputs for cwd {}: {err}", - cwd.display() + "failed to wipe previous consolidation outputs for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key ); let _ = state_db::release_memory_consolidation_lock( session.services.state_db.as_deref(), - &cwd, + &scope.memory_root, lock_owner, MEMORY_STARTUP_STAGE, ) @@ -409,25 +800,24 @@ async fn run_memory_consolidation_for_cwd( { Ok(consolidation_agent_id) => { info!( - "memory phase-2 consolidation agent started: cwd={} agent_id={}", - cwd.display(), - consolidation_agent_id + "memory phase-2 consolidation agent started: scope={} scope_key={} agent_id={}", + scope.scope_kind, scope.scope_key, consolidation_agent_id ); spawn_memory_lock_release_task( session.as_ref(), - cwd, + scope.memory_root, lock_owner, consolidation_agent_id, ); } Err(err) => { warn!( - "failed to spawn memory consolidation agent for cwd {}: {err}", - cwd.display() + "failed to spawn memory consolidation agent for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key ); let _ = state_db::release_memory_consolidation_lock( session.services.state_db.as_deref(), - &cwd, + &scope.memory_root, lock_owner, MEMORY_STARTUP_STAGE, ) @@ -495,6 +885,31 @@ fn spawn_memory_lock_release_task( }); } +async fn mark_phase1_job_failed_best_effort( + session: &Session, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + failure_reason: &str, +) { + let Some(state_db) = 
session.services.state_db.as_deref() else { + return; + }; + if let Err(err) = state_db + .mark_phase1_job_failed( + thread_id, + scope_kind, + scope_key, + ownership_token, + failure_reason, + ) + .await + { + warn!("state db mark_phase1_job_failed failed during {MEMORY_STARTUP_STAGE}: {err}"); + } +} + async fn collect_response_text_until_completed(stream: &mut ResponseStream) -> CodexResult { let mut output_text = String::new(); diff --git a/codex-rs/core/src/memories/mod.rs b/codex-rs/core/src/memories/mod.rs index e0a6d43ec..9a0ae00d2 100644 --- a/codex-rs/core/src/memories/mod.rs +++ b/codex-rs/core/src/memories/mod.rs @@ -17,13 +17,20 @@ use std::path::PathBuf; /// Subagent source label used to identify consolidation tasks. pub(crate) const MEMORY_CONSOLIDATION_SUBAGENT_LABEL: &str = "memory_consolidation"; /// Maximum number of rollout candidates processed per startup pass. -pub(crate) const MAX_ROLLOUTS_PER_STARTUP: usize = 8; +pub(crate) const MAX_ROLLOUTS_PER_STARTUP: usize = 64; /// Concurrency cap for startup memory extraction and consolidation scheduling. pub(crate) const PHASE_ONE_CONCURRENCY_LIMIT: usize = MAX_ROLLOUTS_PER_STARTUP; -/// Maximum number of recent raw memories retained per working directory. -pub(crate) const MAX_RAW_MEMORIES_PER_CWD: usize = 10; +/// Maximum number of recent raw memories retained per scope. +pub(crate) const MAX_RAW_MEMORIES_PER_SCOPE: usize = 64; +/// Maximum rollout age considered for phase-1 extraction. +pub(crate) const PHASE_ONE_MAX_ROLLOUT_AGE_DAYS: i64 = 30; +/// Lease duration (seconds) for phase-1 job ownership. +pub(crate) const PHASE_ONE_JOB_LEASE_SECONDS: i64 = 3_600; /// Lease duration (seconds) for per-cwd consolidation locks. 
pub(crate) const CONSOLIDATION_LOCK_LEASE_SECONDS: i64 = 600; +pub(crate) const MEMORY_SCOPE_KIND_CWD: &str = "cwd"; +pub(crate) const MEMORY_SCOPE_KIND_USER: &str = "user"; +pub(crate) const MEMORY_SCOPE_KEY_USER: &str = "user"; const MEMORY_SUBDIR: &str = "memory"; const RAW_MEMORIES_SUBDIR: &str = "raw_memories"; @@ -31,6 +38,7 @@ const MEMORY_SUMMARY_FILENAME: &str = "memory_summary.md"; const MEMORY_REGISTRY_FILENAME: &str = "MEMORY.md"; const LEGACY_CONSOLIDATED_FILENAME: &str = "consolidated.md"; const SKILLS_SUBDIR: &str = "skills"; +const CWD_MEMORY_BUCKET_HEX_LEN: usize = 16; pub(crate) use phase_one::RAW_MEMORY_PROMPT; pub(crate) use phase_one::parse_stage_one_output; @@ -43,8 +51,9 @@ pub(crate) use rollout::StageOneRolloutFilter; pub(crate) use rollout::serialize_filtered_rollout_response_items; pub(crate) use selection::select_rollout_candidates_from_db; pub(crate) use storage::prune_to_recent_memories_and_rebuild_summary; +pub(crate) use storage::rebuild_memory_summary_from_memories; +pub(crate) use storage::sync_raw_memories_from_memories; pub(crate) use storage::wipe_consolidation_outputs; -pub(crate) use storage::write_raw_memory; pub(crate) use types::RolloutCandidate; /// Returns the on-disk memory root directory for a given working directory. @@ -56,6 +65,21 @@ pub(crate) fn memory_root_for_cwd(codex_home: &Path, cwd: &Path) -> PathBuf { codex_home.join("memories").join(bucket).join(MEMORY_SUBDIR) } +/// Returns the DB scope key for a cwd-scoped memory entry. +/// +/// This uses the same normalization/fallback behavior as cwd bucket derivation. +pub(crate) fn memory_scope_key_for_cwd(cwd: &Path) -> String { + normalize_cwd_for_memory(cwd).display().to_string() +} + +/// Returns the on-disk user-shared memory root directory. 
+pub(crate) fn memory_root_for_user(codex_home: &Path) -> PathBuf { + codex_home + .join("memories") + .join(MEMORY_SCOPE_KEY_USER) + .join(MEMORY_SUBDIR) +} + fn raw_memories_dir(root: &Path) -> PathBuf { root.join(RAW_MEMORIES_SUBDIR) } @@ -70,9 +94,14 @@ pub(crate) async fn ensure_layout(root: &Path) -> std::io::Result<()> { } fn memory_bucket_for_cwd(cwd: &Path) -> String { - let normalized = normalize_for_path_comparison(cwd).unwrap_or_else(|_| cwd.to_path_buf()); + let normalized = normalize_cwd_for_memory(cwd); let normalized = normalized.to_string_lossy(); let mut hasher = Sha256::new(); hasher.update(normalized.as_bytes()); - format!("{:x}", hasher.finalize()) + let full_hash = format!("{:x}", hasher.finalize()); + full_hash[..CWD_MEMORY_BUCKET_HEX_LEN].to_string() +} + +fn normalize_cwd_for_memory(cwd: &Path) -> PathBuf { + normalize_for_path_comparison(cwd).unwrap_or_else(|_| cwd.to_path_buf()) } diff --git a/codex-rs/core/src/memories/selection.rs b/codex-rs/core/src/memories/selection.rs index 1e707e7a6..9b0814943 100644 --- a/codex-rs/core/src/memories/selection.rs +++ b/codex-rs/core/src/memories/selection.rs @@ -1,28 +1,25 @@ +use chrono::Duration; +use chrono::Utc; use codex_protocol::ThreadId; -use codex_state::ThreadMemory; use codex_state::ThreadMetadata; -use std::collections::BTreeMap; use super::types::RolloutCandidate; /// Selects rollout candidates that need stage-1 memory extraction. /// -/// A rollout is selected when it is not the active thread and has no memory yet -/// (or the stored memory is older than the thread metadata timestamp). +/// A rollout is selected when it is not the active thread and was updated +/// within the configured max age window. 
pub(crate) fn select_rollout_candidates_from_db( items: &[ThreadMetadata], current_thread_id: ThreadId, - existing_memories: &[ThreadMemory], max_items: usize, + max_age_days: i64, ) -> Vec { if max_items == 0 { return Vec::new(); } - let memory_updated_by_thread = existing_memories - .iter() - .map(|memory| (memory.thread_id.to_string(), memory.updated_at)) - .collect::>(); + let cutoff = Utc::now() - Duration::days(max_age_days.max(0)); let mut candidates = Vec::new(); @@ -30,10 +27,7 @@ pub(crate) fn select_rollout_candidates_from_db( if item.id == current_thread_id { continue; } - - let memory_updated_at = memory_updated_by_thread.get(&item.id.to_string()); - if memory_updated_at.is_some_and(|memory_updated_at| *memory_updated_at >= item.updated_at) - { + if item.updated_at < cutoff { continue; } @@ -41,7 +35,6 @@ pub(crate) fn select_rollout_candidates_from_db( thread_id: item.id, rollout_path: item.rollout_path.clone(), cwd: item.cwd.clone(), - title: item.title.clone(), updated_at: Some(item.updated_at.to_rfc3339()), }); diff --git a/codex-rs/core/src/memories/storage.rs b/codex-rs/core/src/memories/storage.rs index 32809dcd9..f1dbd96f8 100644 --- a/codex-rs/core/src/memories/storage.rs +++ b/codex-rs/core/src/memories/storage.rs @@ -6,47 +6,12 @@ use std::path::PathBuf; use tracing::warn; use super::LEGACY_CONSOLIDATED_FILENAME; -use super::MAX_RAW_MEMORIES_PER_CWD; +use super::MAX_RAW_MEMORIES_PER_SCOPE; use super::MEMORY_REGISTRY_FILENAME; use super::SKILLS_SUBDIR; use super::ensure_layout; use super::memory_summary_file; use super::raw_memories_dir; -use super::types::RolloutCandidate; - -/// Writes (or replaces) the per-thread markdown raw memory on disk. -/// -/// This also removes older files for the same thread id to keep one canonical -/// raw memory file per thread. 
-pub(crate) async fn write_raw_memory( - root: &Path, - candidate: &RolloutCandidate, - raw_memory: &str, -) -> std::io::Result { - let slug = build_memory_slug(&candidate.title); - let filename = format!("{}_{}.md", candidate.thread_id, slug); - let path = raw_memories_dir(root).join(filename); - - remove_outdated_thread_raw_memories(root, &candidate.thread_id.to_string(), &path).await?; - - let mut body = String::new(); - writeln!(body, "thread_id: {}", candidate.thread_id) - .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; - writeln!(body, "cwd: {}", candidate.cwd.display()) - .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; - writeln!(body, "rollout_path: {}", candidate.rollout_path.display()) - .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; - if let Some(updated_at) = candidate.updated_at.as_deref() { - writeln!(body, "updated_at: {updated_at}") - .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; - } - writeln!(body).map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; - body.push_str(raw_memory.trim()); - body.push('\n'); - - tokio::fs::write(&path, body).await?; - Ok(path) -} /// Prunes stale raw memory files and rebuilds the routing summary for recent memories. pub(crate) async fn prune_to_recent_memories_and_rebuild_summary( @@ -57,7 +22,7 @@ pub(crate) async fn prune_to_recent_memories_and_rebuild_summary( let keep = memories .iter() - .take(MAX_RAW_MEMORIES_PER_CWD) + .take(MAX_RAW_MEMORIES_PER_SCOPE) .map(|memory| memory.thread_id.to_string()) .collect::>(); @@ -65,6 +30,38 @@ pub(crate) async fn prune_to_recent_memories_and_rebuild_summary( rebuild_memory_summary(root, memories).await } +/// Rebuild `memory_summary.md` for a scope without pruning raw memory files. 
+pub(crate) async fn rebuild_memory_summary_from_memories( + root: &Path, + memories: &[ThreadMemory], +) -> std::io::Result<()> { + ensure_layout(root).await?; + rebuild_memory_summary(root, memories).await +} + +/// Syncs canonical raw memory files from DB-backed memory rows. +pub(crate) async fn sync_raw_memories_from_memories( + root: &Path, + memories: &[ThreadMemory], +) -> std::io::Result<()> { + ensure_layout(root).await?; + + let retained = memories + .iter() + .take(MAX_RAW_MEMORIES_PER_SCOPE) + .collect::>(); + let keep = retained + .iter() + .map(|memory| memory.thread_id.to_string()) + .collect::>(); + prune_raw_memories(root, &keep).await?; + + for memory in retained { + write_raw_memory_for_thread(root, memory).await?; + } + Ok(()) +} + /// Clears consolidation outputs so a fresh consolidation run can regenerate them. /// /// Phase-1 artifacts (`raw_memories/` and `memory_summary.md`) are preserved. @@ -103,7 +100,7 @@ async fn rebuild_memory_summary(root: &Path, memories: &[ThreadMemory]) -> std:: } body.push_str("Map of concise summaries to thread IDs (latest first):\n\n"); - for memory in memories.iter().take(MAX_RAW_MEMORIES_PER_CWD) { + for memory in memories.iter().take(MAX_RAW_MEMORIES_PER_SCOPE) { let summary = compact_summary_for_index(&memory.memory_summary); writeln!(body, "- {summary} (thread: `{}`)", memory.thread_id) .map_err(|err| std::io::Error::other(format!("format memory summary: {err}")))?; @@ -179,27 +176,25 @@ async fn remove_outdated_thread_raw_memories( Ok(()) } -fn build_memory_slug(value: &str) -> String { - let mut slug = String::new(); - let mut last_was_sep = false; +async fn write_raw_memory_for_thread( + root: &Path, + memory: &ThreadMemory, +) -> std::io::Result { + let path = raw_memories_dir(root).join(format!("{}.md", memory.thread_id)); - for ch in value.chars() { - let normalized = ch.to_ascii_lowercase(); - if normalized.is_ascii_alphanumeric() { - slug.push(normalized); - last_was_sep = false; - } else if 
!last_was_sep { - slug.push('_'); - last_was_sep = true; - } - } + remove_outdated_thread_raw_memories(root, &memory.thread_id.to_string(), &path).await?; - let slug = slug.trim_matches('_').to_string(); - if slug.is_empty() { - "memory".to_string() - } else { - slug.chars().take(64).collect() - } + let mut body = String::new(); + writeln!(body, "thread_id: {}", memory.thread_id) + .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; + writeln!(body, "updated_at: {}", memory.updated_at.to_rfc3339()) + .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; + writeln!(body).map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; + body.push_str(memory.raw_memory.trim()); + body.push('\n'); + + tokio::fs::write(&path, body).await?; + Ok(path) } fn compact_summary_for_index(summary: &str) -> String { @@ -208,10 +203,15 @@ fn compact_summary_for_index(summary: &str) -> String { fn extract_thread_id_from_summary_filename(file_name: &str) -> Option<&str> { let stem = file_name.strip_suffix(".md")?; - let (thread_id, _) = stem.split_once('_')?; - if thread_id.is_empty() { + if stem.is_empty() { None + } else if let Some((thread_id, _legacy_slug)) = stem.split_once('_') { + if thread_id.is_empty() { + None + } else { + Some(thread_id) + } } else { - Some(thread_id) + Some(stem) } } diff --git a/codex-rs/core/src/memories/tests.rs b/codex-rs/core/src/memories/tests.rs index efd6d4360..1123a97bb 100644 --- a/codex-rs/core/src/memories/tests.rs +++ b/codex-rs/core/src/memories/tests.rs @@ -1,7 +1,10 @@ +use super::MEMORY_SCOPE_KIND_CWD; +use super::PHASE_ONE_MAX_ROLLOUT_AGE_DAYS; use super::StageOneResponseItemKinds; use super::StageOneRolloutFilter; use super::ensure_layout; use super::memory_root_for_cwd; +use super::memory_scope_key_for_cwd; use super::memory_summary_file; use super::parse_stage_one_output; use super::prune_to_recent_memories_and_rebuild_summary; @@ -77,7 +80,7 @@ fn 
memory_root_varies_by_cwd() { .and_then(std::path::Path::file_name) .and_then(std::ffi::OsStr::to_str) .expect("cwd bucket"); - assert_eq!(bucket_a.len(), 64); + assert_eq!(bucket_a.len(), 16); assert!(bucket_a.chars().all(|ch| ch.is_ascii_hexdigit())); } @@ -97,6 +100,22 @@ fn memory_root_encoding_avoids_component_collisions() { assert!(!root_hash.display().to_string().contains("workspace")); } +#[test] +fn memory_scope_key_uses_normalized_cwd() { + let dir = tempdir().expect("tempdir"); + let workspace = dir.path().join("workspace"); + std::fs::create_dir_all(&workspace).expect("mkdir workspace"); + std::fs::create_dir_all(workspace.join("nested")).expect("mkdir nested"); + + let alias = workspace.join("nested").join(".."); + let normalized = workspace + .canonicalize() + .expect("canonical workspace path should resolve"); + let alias_key = memory_scope_key_for_cwd(&alias); + let normalized_key = memory_scope_key_for_cwd(&normalized); + assert_eq!(alias_key, normalized_key); +} + #[test] fn parse_stage_one_output_accepts_fenced_json() { let raw = "```json\n{\"rawMemory\":\"abc\",\"summary\":\"short\"}\n```"; @@ -206,64 +225,58 @@ fn serialize_filtered_rollout_response_items_filters_by_response_item_kind() { } #[test] -fn select_rollout_candidates_uses_db_memory_recency() { +fn select_rollout_candidates_filters_by_age_window() { let dir = tempdir().expect("tempdir"); let cwd_a = dir.path().join("workspace-a"); let cwd_b = dir.path().join("workspace-b"); std::fs::create_dir_all(&cwd_a).expect("mkdir cwd a"); std::fs::create_dir_all(&cwd_b).expect("mkdir cwd b"); + let now = Utc::now().timestamp(); let current_thread_id = ThreadId::default(); - let stale_thread_id = ThreadId::default(); - let fresh_thread_id = ThreadId::default(); - let missing_thread_id = ThreadId::default(); + let recent_thread_id = ThreadId::default(); + let old_thread_id = ThreadId::default(); + let recent_two_thread_id = ThreadId::default(); let current = thread_metadata( current_thread_id, 
dir.path().join("current.jsonl"), cwd_a.clone(), "current", - 500, + now, ); - let fresh = thread_metadata( - fresh_thread_id, - dir.path().join("fresh.jsonl"), + let recent = thread_metadata( + recent_thread_id, + dir.path().join("recent.jsonl"), cwd_a, - "fresh", - 400, + "recent", + now - 10, ); - let stale = thread_metadata( - stale_thread_id, - dir.path().join("stale.jsonl"), + let old = thread_metadata( + old_thread_id, + dir.path().join("old.jsonl"), cwd_b.clone(), - "stale", - 300, + "old", + now - (PHASE_ONE_MAX_ROLLOUT_AGE_DAYS + 1) * 24 * 60 * 60, ); - let missing = thread_metadata( - missing_thread_id, - dir.path().join("missing.jsonl"), + let recent_two = thread_metadata( + recent_two_thread_id, + dir.path().join("recent-two.jsonl"), cwd_b, - "missing", - 200, + "recent-two", + now - 20, ); - let memories = vec![ThreadMemory { - thread_id: fresh_thread_id, - raw_memory: "raw memory".to_string(), - memory_summary: "memory".to_string(), - updated_at: Utc.timestamp_opt(450, 0).single().expect("timestamp"), - }]; - let candidates = select_rollout_candidates_from_db( - &[current, fresh, stale, missing], + &[current, recent, old, recent_two], current_thread_id, - &memories, 5, + PHASE_ONE_MAX_ROLLOUT_AGE_DAYS, ); assert_eq!(candidates.len(), 2); - assert_eq!(candidates[0].thread_id, stale_thread_id); - assert_eq!(candidates[1].thread_id, missing_thread_id); + assert_eq!(candidates[0].thread_id, recent_thread_id); + assert_eq!(candidates[1].thread_id, recent_two_thread_id); } #[tokio::test] @@ -274,8 +287,8 @@ async fn prune_and_rebuild_summary_keeps_latest_memories_only() { let keep_id = ThreadId::default().to_string(); let drop_id = ThreadId::default().to_string(); - let keep_path = raw_memories_dir(&root).join(format!("{keep_id}_keep.md")); - let drop_path = raw_memories_dir(&root).join(format!("{drop_id}_drop.md")); + let keep_path = raw_memories_dir(&root).join(format!("{keep_id}.md")); + let drop_path = 
raw_memories_dir(&root).join(format!("{drop_id}.md")); tokio::fs::write(&keep_path, "keep") .await .expect("write keep"); @@ -285,9 +298,15 @@ async fn prune_and_rebuild_summary_keeps_latest_memories_only() { let memories = vec![ThreadMemory { thread_id: ThreadId::try_from(keep_id.clone()).expect("thread id"), + scope_kind: MEMORY_SCOPE_KIND_CWD.to_string(), + scope_key: "scope".to_string(), raw_memory: "raw memory".to_string(), memory_summary: "short summary".to_string(), updated_at: Utc.timestamp_opt(100, 0).single().expect("timestamp"), + last_used_at: None, + used_count: 0, + invalidated_at: None, + invalid_reason: None, }]; prune_to_recent_memories_and_rebuild_summary(&root, &memories) diff --git a/codex-rs/core/src/memories/types.rs b/codex-rs/core/src/memories/types.rs index 510c0faf5..7ba66ee81 100644 --- a/codex-rs/core/src/memories/types.rs +++ b/codex-rs/core/src/memories/types.rs @@ -11,8 +11,6 @@ pub(crate) struct RolloutCandidate { pub(crate) rollout_path: PathBuf, /// Thread working directory used for per-project memory bucketing. pub(crate) cwd: PathBuf, - /// Best-effort thread title used to build readable memory filenames. - pub(crate) title: String, /// Last observed thread update timestamp (RFC3339), if available. 
pub(crate) updated_at: Option, } diff --git a/codex-rs/state/migrations/0010_memory_workflow_v2.sql b/codex-rs/state/migrations/0010_memory_workflow_v2.sql new file mode 100644 index 000000000..475b5f6e3 --- /dev/null +++ b/codex-rs/state/migrations/0010_memory_workflow_v2.sql @@ -0,0 +1,85 @@ +DROP TABLE IF EXISTS thread_memory; +DROP TABLE IF EXISTS memory_consolidation_locks; +DROP TABLE IF EXISTS memory_phase1_jobs; +DROP TABLE IF EXISTS memory_scope_dirty; +DROP TABLE IF EXISTS memory_phase2_jobs; + +CREATE TABLE thread_memory ( + thread_id TEXT NOT NULL, + scope_kind TEXT NOT NULL, + scope_key TEXT NOT NULL, + raw_memory TEXT NOT NULL, + memory_summary TEXT NOT NULL, + updated_at INTEGER NOT NULL, + last_used_at INTEGER, + used_count INTEGER NOT NULL DEFAULT 0, + invalidated_at INTEGER, + invalid_reason TEXT, + PRIMARY KEY (thread_id, scope_kind, scope_key), + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE +); + +CREATE INDEX idx_thread_memory_scope_last_used_at + ON thread_memory(scope_kind, scope_key, last_used_at DESC, thread_id DESC); +CREATE INDEX idx_thread_memory_scope_updated_at + ON thread_memory(scope_kind, scope_key, updated_at DESC, thread_id DESC); + +CREATE TABLE memory_phase1_jobs ( + thread_id TEXT NOT NULL, + scope_kind TEXT NOT NULL, + scope_key TEXT NOT NULL, + status TEXT NOT NULL, + owner_session_id TEXT, + started_at INTEGER, + finished_at INTEGER, + failure_reason TEXT, + source_updated_at INTEGER NOT NULL, + raw_memory_path TEXT, + summary_hash TEXT, + ownership_token TEXT, + PRIMARY KEY (thread_id, scope_kind, scope_key), + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE +); + +CREATE INDEX idx_memory_phase1_jobs_status_started_at + ON memory_phase1_jobs(status, started_at DESC); +CREATE INDEX idx_memory_phase1_jobs_scope + ON memory_phase1_jobs(scope_kind, scope_key); + +CREATE TABLE memory_scope_dirty ( + scope_kind TEXT NOT NULL, + scope_key TEXT NOT NULL, + dirty INTEGER NOT NULL, + updated_at 
INTEGER NOT NULL, + PRIMARY KEY (scope_kind, scope_key) +); + +CREATE INDEX idx_memory_scope_dirty_dirty + ON memory_scope_dirty(dirty, updated_at DESC); + +CREATE TABLE memory_phase2_jobs ( + scope_kind TEXT NOT NULL, + scope_key TEXT NOT NULL, + status TEXT NOT NULL, + owner_session_id TEXT, + agent_thread_id TEXT, + started_at INTEGER, + last_heartbeat_at INTEGER, + finished_at INTEGER, + attempt INTEGER NOT NULL DEFAULT 0, + failure_reason TEXT, + ownership_token TEXT, + PRIMARY KEY (scope_kind, scope_key) +); + +CREATE INDEX idx_memory_phase2_jobs_status_heartbeat + ON memory_phase2_jobs(status, last_heartbeat_at DESC); + +CREATE TABLE memory_consolidation_locks ( + cwd TEXT PRIMARY KEY, + working_thread_id TEXT NOT NULL, + updated_at INTEGER NOT NULL +); + +CREATE INDEX idx_memory_consolidation_locks_updated_at + ON memory_consolidation_locks(updated_at DESC); diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index 1625554e2..6db552794 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -31,6 +31,7 @@ pub use model::ThreadMemory; pub use model::ThreadMetadata; pub use model::ThreadMetadataBuilder; pub use model::ThreadsPage; +pub use runtime::Phase1JobClaimOutcome; pub use runtime::STATE_DB_FILENAME; pub use runtime::STATE_DB_VERSION; pub use runtime::state_db_filename; diff --git a/codex-rs/state/src/model/thread_memory.rs b/codex-rs/state/src/model/thread_memory.rs index 63c2b85aa..b0b29ce7e 100644 --- a/codex-rs/state/src/model/thread_memory.rs +++ b/codex-rs/state/src/model/thread_memory.rs @@ -9,26 +9,44 @@ use sqlx::sqlite::SqliteRow; #[derive(Debug, Clone, PartialEq, Eq)] pub struct ThreadMemory { pub thread_id: ThreadId, + pub scope_kind: String, + pub scope_key: String, pub raw_memory: String, pub memory_summary: String, pub updated_at: DateTime, + pub last_used_at: Option>, + pub used_count: i64, + pub invalidated_at: Option>, + pub invalid_reason: Option, } #[derive(Debug)] pub(crate) struct ThreadMemoryRow { 
thread_id: String, + scope_kind: String, + scope_key: String, raw_memory: String, memory_summary: String, updated_at: i64, + last_used_at: Option, + used_count: i64, + invalidated_at: Option, + invalid_reason: Option, } impl ThreadMemoryRow { pub(crate) fn try_from_row(row: &SqliteRow) -> Result { Ok(Self { thread_id: row.try_get("thread_id")?, + scope_kind: row.try_get("scope_kind")?, + scope_key: row.try_get("scope_key")?, raw_memory: row.try_get("raw_memory")?, memory_summary: row.try_get("memory_summary")?, updated_at: row.try_get("updated_at")?, + last_used_at: row.try_get("last_used_at")?, + used_count: row.try_get("used_count")?, + invalidated_at: row.try_get("invalidated_at")?, + invalid_reason: row.try_get("invalid_reason")?, }) } } @@ -39,9 +57,21 @@ impl TryFrom for ThreadMemory { fn try_from(row: ThreadMemoryRow) -> std::result::Result { Ok(Self { thread_id: ThreadId::try_from(row.thread_id)?, + scope_kind: row.scope_kind, + scope_key: row.scope_key, raw_memory: row.raw_memory, memory_summary: row.memory_summary, updated_at: epoch_seconds_to_datetime(row.updated_at)?, + last_used_at: row + .last_used_at + .map(epoch_seconds_to_datetime) + .transpose()?, + used_count: row.used_count, + invalidated_at: row + .invalidated_at + .map(epoch_seconds_to_datetime) + .transpose()?, + invalid_reason: row.invalid_reason, }) } } diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index ff60c300a..d1049a891 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -36,10 +36,13 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use tracing::warn; +use uuid::Uuid; pub const STATE_DB_FILENAME: &str = "state"; pub const STATE_DB_VERSION: u32 = 4; +const MEMORY_SCOPE_KIND_CWD: &str = "cwd"; + const METRIC_DB_INIT: &str = "codex.db.init"; #[derive(Clone)] @@ -49,6 +52,14 @@ pub struct StateRuntime { pool: Arc, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Phase1JobClaimOutcome { + Claimed { 
ownership_token: String }, + SkippedTerminalFailure, + SkippedUpToDate, + SkippedRunning, +} + impl StateRuntime { /// Initialize the state runtime using the provided Codex home and default provider. /// @@ -237,9 +248,21 @@ ORDER BY position ASC ) -> anyhow::Result> { let row = sqlx::query( r#" -SELECT thread_id, trace_summary AS raw_memory, memory_summary, updated_at +SELECT + thread_id, + scope_kind, + scope_key, + raw_memory, + memory_summary, + updated_at, + last_used_at, + used_count, + invalidated_at, + invalid_reason FROM thread_memory WHERE thread_id = ? +ORDER BY updated_at DESC, scope_kind DESC, scope_key DESC +LIMIT 1 "#, ) .bind(thread_id.to_string()) @@ -506,7 +529,7 @@ ON CONFLICT(id) DO UPDATE SET Ok(()) } - /// Insert or update memory summaries for a thread. + /// Insert or update memory summaries for a thread in the cwd scope. /// /// This method always advances `updated_at`, even if summaries are unchanged. pub async fn upsert_thread_memory( @@ -514,6 +537,29 @@ ON CONFLICT(id) DO UPDATE SET thread_id: ThreadId, raw_memory: &str, memory_summary: &str, + ) -> anyhow::Result { + let Some(thread) = self.get_thread(thread_id).await? else { + return Err(anyhow::anyhow!("thread not found: {thread_id}")); + }; + let scope_key = thread.cwd.display().to_string(); + self.upsert_thread_memory_for_scope( + thread_id, + MEMORY_SCOPE_KIND_CWD, + scope_key.as_str(), + raw_memory, + memory_summary, + ) + .await + } + + /// Insert or update memory summaries for a thread in an explicit scope. 
+ pub async fn upsert_thread_memory_for_scope( + &self, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + raw_memory: &str, + memory_summary: &str, ) -> anyhow::Result { if self.get_thread(thread_id).await?.is_none() { return Err(anyhow::anyhow!("thread not found: {thread_id}")); @@ -524,12 +570,14 @@ ON CONFLICT(id) DO UPDATE SET r#" INSERT INTO thread_memory ( thread_id, - trace_summary, + scope_kind, + scope_key, + raw_memory, memory_summary, updated_at -) VALUES (?, ?, ?, ?) -ON CONFLICT(thread_id) DO UPDATE SET - trace_summary = excluded.trace_summary, +) VALUES (?, ?, ?, ?, ?, ?) +ON CONFLICT(thread_id, scope_kind, scope_key) DO UPDATE SET + raw_memory = excluded.raw_memory, memory_summary = excluded.memory_summary, updated_at = CASE WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1 @@ -538,22 +586,149 @@ ON CONFLICT(thread_id) DO UPDATE SET "#, ) .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) .bind(raw_memory) .bind(memory_summary) .bind(updated_at) .execute(self.pool.as_ref()) .await?; - self.get_thread_memory(thread_id) - .await? + let row = sqlx::query( + r#" +SELECT + thread_id, + scope_kind, + scope_key, + raw_memory, + memory_summary, + updated_at, + last_used_at, + used_count, + invalidated_at, + invalid_reason +FROM thread_memory +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + "#, + ) + .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) + .fetch_optional(self.pool.as_ref()) + .await?; + + row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) + .transpose()? .ok_or_else(|| anyhow::anyhow!("failed to load upserted thread memory: {thread_id}")) } + /// Insert or update memory summaries for a thread/scope only if the caller + /// still owns the corresponding phase-1 running job. 
+ pub async fn upsert_thread_memory_for_scope_if_phase1_owner( + &self, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + raw_memory: &str, + memory_summary: &str, + ) -> anyhow::Result> { + if self.get_thread(thread_id).await?.is_none() { + return Err(anyhow::anyhow!("thread not found: {thread_id}")); + } + + let updated_at = Utc::now().timestamp(); + let rows_affected = sqlx::query( + r#" +INSERT INTO thread_memory ( + thread_id, + scope_kind, + scope_key, + raw_memory, + memory_summary, + updated_at +) +SELECT ?, ?, ?, ?, ?, ? +WHERE EXISTS ( + SELECT 1 + FROM memory_phase1_jobs + WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + AND status = 'running' AND ownership_token = ? +) +ON CONFLICT(thread_id, scope_kind, scope_key) DO UPDATE SET + raw_memory = excluded.raw_memory, + memory_summary = excluded.memory_summary, + updated_at = CASE + WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1 + ELSE excluded.updated_at + END + "#, + ) + .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) + .bind(raw_memory) + .bind(memory_summary) + .bind(updated_at) + .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? + .rows_affected(); + + if rows_affected == 0 { + return Ok(None); + } + + let row = sqlx::query( + r#" +SELECT + thread_id, + scope_kind, + scope_key, + raw_memory, + memory_summary, + updated_at, + last_used_at, + used_count, + invalidated_at, + invalid_reason +FROM thread_memory +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + "#, + ) + .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) + .fetch_optional(self.pool.as_ref()) + .await?; + + row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) + .transpose() + } + /// Get the last `n` memories for threads with an exact cwd match. 
pub async fn get_last_n_thread_memories_for_cwd( &self, cwd: &Path, n: usize, + ) -> anyhow::Result> { + self.get_last_n_thread_memories_for_scope( + MEMORY_SCOPE_KIND_CWD, + &cwd.display().to_string(), + n, + ) + .await + } + + /// Get the last `n` memories for a specific memory scope. + pub async fn get_last_n_thread_memories_for_scope( + &self, + scope_kind: &str, + scope_key: &str, + n: usize, ) -> anyhow::Result> { if n == 0 { return Ok(Vec::new()); @@ -562,18 +737,24 @@ ON CONFLICT(thread_id) DO UPDATE SET let rows = sqlx::query( r#" SELECT - m.thread_id, - m.trace_summary AS raw_memory, - m.memory_summary, - m.updated_at -FROM thread_memory AS m -INNER JOIN threads AS t ON t.id = m.thread_id -WHERE t.cwd = ? -ORDER BY m.updated_at DESC, m.thread_id DESC + thread_id, + scope_kind, + scope_key, + raw_memory, + memory_summary, + updated_at, + last_used_at, + used_count, + invalidated_at, + invalid_reason +FROM thread_memory +WHERE scope_kind = ? AND scope_key = ? AND invalidated_at IS NULL +ORDER BY updated_at DESC, thread_id DESC LIMIT ? "#, ) - .bind(cwd.display().to_string()) + .bind(scope_kind) + .bind(scope_key) .bind(n as i64) .fetch_all(self.pool.as_ref()) .await?; @@ -583,6 +764,282 @@ LIMIT ? .collect() } + /// Try to claim a phase-1 memory extraction job for `(thread, scope)`. + pub async fn try_claim_phase1_job( + &self, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + owner_session_id: ThreadId, + source_updated_at: i64, + lease_seconds: i64, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let stale_cutoff = now.saturating_sub(lease_seconds.max(0)); + let ownership_token = Uuid::new_v4().to_string(); + let thread_id = thread_id.to_string(); + let owner_session_id = owner_session_id.to_string(); + + let mut tx = self.pool.begin().await?; + let existing = sqlx::query( + r#" +SELECT status, source_updated_at, started_at +FROM memory_phase1_jobs +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? 
+ "#, + ) + .bind(thread_id.as_str()) + .bind(scope_kind) + .bind(scope_key) + .fetch_optional(&mut *tx) + .await?; + + let Some(existing) = existing else { + sqlx::query( + r#" +INSERT INTO memory_phase1_jobs ( + thread_id, + scope_kind, + scope_key, + status, + owner_session_id, + started_at, + finished_at, + failure_reason, + source_updated_at, + raw_memory_path, + summary_hash, + ownership_token +) VALUES (?, ?, ?, 'running', ?, ?, NULL, NULL, ?, NULL, NULL, ?) + "#, + ) + .bind(thread_id.as_str()) + .bind(scope_kind) + .bind(scope_key) + .bind(owner_session_id.as_str()) + .bind(now) + .bind(source_updated_at) + .bind(ownership_token.as_str()) + .execute(&mut *tx) + .await?; + tx.commit().await?; + return Ok(Phase1JobClaimOutcome::Claimed { ownership_token }); + }; + + let status: String = existing.try_get("status")?; + let existing_source_updated_at: i64 = existing.try_get("source_updated_at")?; + let existing_started_at: Option = existing.try_get("started_at")?; + if status == "failed" { + tx.commit().await?; + return Ok(Phase1JobClaimOutcome::SkippedTerminalFailure); + } + if status == "succeeded" && existing_source_updated_at >= source_updated_at { + tx.commit().await?; + return Ok(Phase1JobClaimOutcome::SkippedUpToDate); + } + if status == "running" && existing_started_at.is_some_and(|started| started > stale_cutoff) + { + tx.commit().await?; + return Ok(Phase1JobClaimOutcome::SkippedRunning); + } + + let rows_affected = if let Some(existing_started_at) = existing_started_at { + sqlx::query( + r#" +UPDATE memory_phase1_jobs +SET + status = 'running', + owner_session_id = ?, + started_at = ?, + finished_at = NULL, + failure_reason = NULL, + source_updated_at = ?, + raw_memory_path = NULL, + summary_hash = NULL, + ownership_token = ? +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + AND status = ? AND source_updated_at = ? AND started_at = ? 
+ "#, + ) + .bind(owner_session_id.as_str()) + .bind(now) + .bind(source_updated_at) + .bind(ownership_token.as_str()) + .bind(thread_id.as_str()) + .bind(scope_kind) + .bind(scope_key) + .bind(status.as_str()) + .bind(existing_source_updated_at) + .bind(existing_started_at) + .execute(&mut *tx) + .await? + .rows_affected() + } else { + sqlx::query( + r#" +UPDATE memory_phase1_jobs +SET + status = 'running', + owner_session_id = ?, + started_at = ?, + finished_at = NULL, + failure_reason = NULL, + source_updated_at = ?, + raw_memory_path = NULL, + summary_hash = NULL, + ownership_token = ? +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + AND status = ? AND source_updated_at = ? AND started_at IS NULL + "#, + ) + .bind(owner_session_id.as_str()) + .bind(now) + .bind(source_updated_at) + .bind(ownership_token.as_str()) + .bind(thread_id.as_str()) + .bind(scope_kind) + .bind(scope_key) + .bind(status.as_str()) + .bind(existing_source_updated_at) + .execute(&mut *tx) + .await? + .rows_affected() + }; + + tx.commit().await?; + if rows_affected == 0 { + Ok(Phase1JobClaimOutcome::SkippedRunning) + } else { + Ok(Phase1JobClaimOutcome::Claimed { ownership_token }) + } + } + + /// Finalize a claimed phase-1 job as succeeded. + pub async fn mark_phase1_job_succeeded( + &self, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + raw_memory_path: &str, + summary_hash: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let rows_affected = sqlx::query( + r#" +UPDATE memory_phase1_jobs +SET + status = 'succeeded', + finished_at = ?, + failure_reason = NULL, + raw_memory_path = ?, + summary_hash = ? +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(now) + .bind(raw_memory_path) + .bind(summary_hash) + .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? 
+ .rows_affected(); + Ok(rows_affected > 0) + } + + /// Finalize a claimed phase-1 job as failed. + pub async fn mark_phase1_job_failed( + &self, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + failure_reason: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let rows_affected = sqlx::query( + r#" +UPDATE memory_phase1_jobs +SET + status = 'failed', + finished_at = ?, + failure_reason = ? +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(now) + .bind(failure_reason) + .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? + .rows_affected(); + Ok(rows_affected > 0) + } + + /// Refresh lease timestamp for a claimed phase-1 job. + /// + /// Returns `true` only when the current owner token still matches. + pub async fn renew_phase1_job_lease( + &self, + thread_id: ThreadId, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let rows_affected = sqlx::query( + r#" +UPDATE memory_phase1_jobs +SET started_at = ? +WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(now) + .bind(thread_id.to_string()) + .bind(scope_kind) + .bind(scope_key) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? + .rows_affected(); + Ok(rows_affected > 0) + } + + /// Mark a memory scope as dirty/clean for phase-2 consolidation scheduling. + pub async fn mark_memory_scope_dirty( + &self, + scope_kind: &str, + scope_key: &str, + dirty: bool, + ) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +INSERT INTO memory_scope_dirty (scope_kind, scope_key, dirty, updated_at) +VALUES (?, ?, ?, ?) 
+ON CONFLICT(scope_kind, scope_key) DO UPDATE SET + dirty = excluded.dirty, + updated_at = excluded.updated_at + "#, + ) + .bind(scope_kind) + .bind(scope_key) + .bind(dirty) + .bind(now) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + /// Try to acquire or renew the per-cwd memory consolidation lock. /// /// Returns `true` when the lock is acquired/renewed for `working_thread_id`. @@ -1020,6 +1477,7 @@ fn push_thread_order_and_limit( #[cfg(test)] mod tests { + use super::Phase1JobClaimOutcome; use super::STATE_DB_FILENAME; use super::STATE_DB_VERSION; use super::StateRuntime; @@ -1471,6 +1929,272 @@ mod tests { let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn phase1_job_claim_and_success_require_current_owner_token() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner, 100, 3600) + .await + .expect("claim phase1 job"); + let ownership_token = match claim { + Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + assert!( + !runtime + .mark_phase1_job_succeeded( + thread_id, + "cwd", + "scope", + "wrong-token", + "/tmp/path", + "summary-hash" + ) + .await + .expect("mark succeeded wrong token should fail"), + "wrong token should not finalize the job" + ); + assert!( + runtime + .mark_phase1_job_succeeded( + thread_id, + "cwd", + "scope", + ownership_token.as_str(), + "/tmp/path", + "summary-hash" + ) + .await + 
.expect("mark succeeded with current token"), + "current token should finalize the job" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase1_job_running_stale_can_be_stolen_but_fresh_running_is_skipped() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + + let first_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) + .await + .expect("first claim"); + assert!( + matches!(first_claim, Phase1JobClaimOutcome::Claimed { .. }), + "first claim should acquire" + ); + + let fresh_second_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 3600) + .await + .expect("fresh second claim"); + assert_eq!(fresh_second_claim, Phase1JobClaimOutcome::SkippedRunning); + + let stale_second_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0) + .await + .expect("stale second claim"); + assert!( + matches!(stale_second_claim, Phase1JobClaimOutcome::Claimed { .. 
}), + "stale running job should be stealable" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase1_job_lease_renewal_requires_current_owner_token() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + + let first_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) + .await + .expect("first claim"); + let owner_a_token = match first_claim { + Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + let stolen_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0) + .await + .expect("stolen claim"); + let owner_b_token = match stolen_claim { + Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + assert!( + !runtime + .renew_phase1_job_lease(thread_id, "cwd", "scope", owner_a_token.as_str()) + .await + .expect("old owner lease renewal should fail"), + "stale owner token should not renew lease" + ); + assert!( + runtime + .renew_phase1_job_lease(thread_id, "cwd", "scope", owner_b_token.as_str()) + .await + .expect("current owner lease renewal should succeed"), + "current owner token should renew lease" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase1_owner_guarded_upsert_rejects_stale_owner() { + let 
codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + + let first_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) + .await + .expect("first claim"); + let owner_a_token = match first_claim { + Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + let stolen_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0) + .await + .expect("stolen claim"); + let owner_b_token = match stolen_claim { + Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + let stale_upsert = runtime + .upsert_thread_memory_for_scope_if_phase1_owner( + thread_id, + "cwd", + "scope", + owner_a_token.as_str(), + "stale raw memory", + "stale summary", + ) + .await + .expect("stale owner upsert"); + assert!( + stale_upsert.is_none(), + "stale owner token should not upsert thread memory" + ); + + let current_upsert = runtime + .upsert_thread_memory_for_scope_if_phase1_owner( + thread_id, + "cwd", + "scope", + owner_b_token.as_str(), + "fresh raw memory", + "fresh summary", + ) + .await + .expect("current owner upsert"); + let current_upsert = current_upsert.expect("current owner should upsert"); + assert_eq!(current_upsert.raw_memory, "fresh raw memory"); + assert_eq!(current_upsert.memory_summary, "fresh summary"); + + let _ = 
tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase1_job_failed_is_terminal() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) + .await + .expect("claim"); + let ownership_token = match claim { + Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_phase1_job_failed( + thread_id, + "cwd", + "scope", + ownership_token.as_str(), + "prompt failed" + ) + .await + .expect("mark failed"), + "owner token should be able to fail job" + ); + + let second_claim = runtime + .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 101, 3600) + .await + .expect("second claim"); + assert_eq!(second_claim, Phase1JobClaimOutcome::SkippedTerminalFailure); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn deleting_thread_cascades_thread_memory() { let codex_home = unique_temp_dir();