From a6e9469fa4dc19d3e30093fb8e182f9d89a94bbe Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 20:26:39 +0000 Subject: [PATCH] chore: unify memory job flow (#11334) --- codex-rs/core/src/codex.rs | 8 +- codex-rs/core/src/codex/memory_startup.rs | 1169 ---------- codex-rs/core/src/memories/layout.rs | 59 + codex-rs/core/src/memories/mod.rs | 119 +- codex-rs/core/src/memories/prompts.rs | 36 +- codex-rs/core/src/memories/rollout.rs | 22 +- codex-rs/core/src/memories/scope.rs | 3 + codex-rs/core/src/memories/selection.rs | 47 - .../memories/{phase_one.rs => stage_one.rs} | 55 +- .../core/src/memories/startup/dispatch.rs | 221 ++ codex-rs/core/src/memories/startup/extract.rs | 150 ++ codex-rs/core/src/memories/startup/mod.rs | 352 +++ codex-rs/core/src/memories/startup/watch.rs | 188 ++ codex-rs/core/src/memories/storage.rs | 60 +- codex-rs/core/src/memories/tests.rs | 152 +- codex-rs/core/src/memories/text.rs | 50 + codex-rs/core/src/memories/types.rs | 21 +- codex-rs/core/src/state_db.rs | 58 - .../0011_generic_jobs_and_stage1_outputs.sql | 37 + codex-rs/state/src/lib.rs | 7 +- codex-rs/state/src/model/mod.rs | 6 +- codex-rs/state/src/model/stage1_output.rs | 56 + codex-rs/state/src/model/thread_memory.rs | 82 - codex-rs/state/src/runtime.rs | 1989 ++++------------- codex-rs/state/src/runtime/memory.rs | 800 +++++++ 25 files changed, 2455 insertions(+), 3292 deletions(-) delete mode 100644 codex-rs/core/src/codex/memory_startup.rs create mode 100644 codex-rs/core/src/memories/layout.rs create mode 100644 codex-rs/core/src/memories/scope.rs delete mode 100644 codex-rs/core/src/memories/selection.rs rename codex-rs/core/src/memories/{phase_one.rs => stage_one.rs} (82%) create mode 100644 codex-rs/core/src/memories/startup/dispatch.rs create mode 100644 codex-rs/core/src/memories/startup/extract.rs create mode 100644 codex-rs/core/src/memories/startup/mod.rs create mode 100644 codex-rs/core/src/memories/startup/watch.rs create mode 100644 codex-rs/core/src/memories/text.rs create mode 100644 codex-rs/state/migrations/0011_generic_jobs_and_stage1_outputs.sql create mode 100644 codex-rs/state/src/model/stage1_output.rs delete mode 100644 codex-rs/state/src/model/thread_memory.rs create mode 100644 codex-rs/state/src/runtime/memory.rs diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 69c16cc66..cd653793a 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -13,7 +13,6 @@ use crate::agent::AgentControl; use crate::agent::AgentStatus; use crate::agent::MAX_THREAD_SPAWN_DEPTH; use crate::agent::agent_status_from_event; -use crate::agent::status::is_final as is_final_agent_status; use crate::analytics_client::AnalyticsEventsClient; use crate::analytics_client::build_track_events_context; use crate::apps::render_apps_section; @@ -111,7 +110,6 @@ use crate::client::ModelClient; use crate::client::ModelClientSession; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; -use crate::client_common::ResponseStream; use crate::codex_thread::ThreadConfigSnapshot; use crate::compact::collect_user_messages; use crate::config::Config; @@ -192,10 +190,8 @@ use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; use crate::protocol::TurnDiffEvent; use crate::protocol::WarningEvent; -use crate::rollout::INTERACTIVE_SESSION_SOURCES; use crate::rollout::RolloutRecorder; use crate::rollout::RolloutRecorderParams; -use crate::rollout::list::ThreadSortKey; use crate::rollout::map_session_init_error; use crate::rollout::metadata; use crate::shell; @@ -249,8 +245,6 @@ use codex_protocol::user_input::UserInput; use codex_utils_readiness::Readiness; use codex_utils_readiness::ReadinessFlag; -mod memory_startup; - /// The high-level interface to the Codex system. /// It operates as a queue pair where you send submissions and receive events. pub struct Codex { @@ -1241,7 +1235,7 @@ impl Session { // record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted. sess.record_initial_history(initial_history).await; - memory_startup::start_memories_startup_task( + memories::start_memories_startup_task( &sess, Arc::clone(&config), &session_configuration.session_source, diff --git a/codex-rs/core/src/codex/memory_startup.rs b/codex-rs/core/src/codex/memory_startup.rs deleted file mode 100644 index f2a0bae27..000000000 --- a/codex-rs/core/src/codex/memory_startup.rs +++ /dev/null @@ -1,1169 +0,0 @@ -use super::*; -use chrono::DateTime; -use chrono::Utc; -use sha2::Digest; -use sha2::Sha256; -use std::time::Duration; - -const MEMORY_STARTUP_STAGE: &str = "run_memories_startup_pipeline"; -const PHASE_ONE_THREAD_SCAN_LIMIT: usize = 5_000; -const PHASE_ONE_DB_LOCK_RETRY_LIMIT: usize = 3; -const PHASE_ONE_DB_LOCK_RETRY_BACKOFF_MS: u64 = 25; -const PHASE_TWO_DB_LOCK_RETRY_LIMIT: usize = 3; -const PHASE_TWO_DB_LOCK_RETRY_BACKOFF_MS: u64 = 25; - -#[derive(Clone, Debug, PartialEq, Eq)] -struct MemoryScopeTarget { - scope_kind: &'static str, - scope_key: String, - memory_root: PathBuf, -} - -#[derive(Clone, Debug)] -struct ClaimedPhaseOneCandidate { - candidate: memories::RolloutCandidate, - claimed_scopes: Vec<(MemoryScopeTarget, String)>, -} - -#[derive(Clone)] -struct StageOneRequestContext { - model_info: ModelInfo, - otel_manager: OtelManager, - reasoning_effort: Option, - reasoning_summary: ReasoningSummaryConfig, - turn_metadata_header: Option, -} - -impl StageOneRequestContext { - fn from_turn_context(turn_context: &TurnContext, turn_metadata_header: Option) -> Self { - Self { - model_info: turn_context.model_info.clone(), - otel_manager: turn_context.otel_manager.clone(), - reasoning_effort: turn_context.reasoning_effort, - reasoning_summary: turn_context.reasoning_summary, - turn_metadata_header, - } - } -} - -pub(super) fn start_memories_startup_task( - session: &Arc, - config: Arc, - source: &SessionSource, -) { - if config.ephemeral - || !config.features.enabled(Feature::MemoryTool) - || matches!(source, SessionSource::SubAgent(_)) - { - return; - } - - let weak_session = Arc::downgrade(session); - tokio::spawn(async move { - let Some(session) = weak_session.upgrade() else { - return; - }; - if let Err(err) = run_memories_startup_pipeline(&session, config).await { - warn!("memories startup pipeline failed: {err}"); - } - }); -} - -pub(super) async fn run_memories_startup_pipeline( - session: &Arc, - config: Arc, -) -> CodexResult<()> { - let turn_context = session.new_default_turn().await; - - let Some(page) = state_db::list_threads_db( - session.services.state_db.as_deref(), - &config.codex_home, - PHASE_ONE_THREAD_SCAN_LIMIT, - None, - ThreadSortKey::UpdatedAt, - INTERACTIVE_SESSION_SOURCES, - None, - false, - ) - .await - else { - warn!("state db unavailable for memories startup pipeline; skipping"); - return Ok(()); - }; - - let selection_candidates = memories::select_rollout_candidates_from_db( - &page.items, - session.conversation_id, - PHASE_ONE_THREAD_SCAN_LIMIT, - memories::PHASE_ONE_MAX_ROLLOUT_AGE_DAYS, - ); - let claimed_candidates = claim_phase_one_candidates( - session, - config.as_ref(), - selection_candidates, - memories::MAX_ROLLOUTS_PER_STARTUP, - ) - .await; - info!( - "memory phase-1 candidate selection complete: {} claimed candidate(s) from {} indexed thread(s)", - claimed_candidates.len(), - page.items.len() - ); - - let touched_scope_count = if claimed_candidates.is_empty() { - 0 - } else { - let stage_one_context = StageOneRequestContext::from_turn_context( - turn_context.as_ref(), - turn_context.resolve_turn_metadata_header().await, - ); - let touched_scope_counts = futures::stream::iter(claimed_candidates.into_iter()) - .map(|claimed_candidate| { - let session = Arc::clone(session); - let stage_one_context = stage_one_context.clone(); - async move { - process_memory_candidate(session, claimed_candidate, stage_one_context).await - } - }) - .buffer_unordered(memories::PHASE_ONE_CONCURRENCY_LIMIT) - .collect::>() - .await; - touched_scope_counts.into_iter().sum::() - }; - info!( - "memory phase-1 extraction complete: {} scope(s) touched", - touched_scope_count - ); - - let dirty_scopes = - list_phase2_dirty_scopes(session, config.as_ref(), memories::MAX_ROLLOUTS_PER_STARTUP) - .await; - let consolidation_scope_count = dirty_scopes.len(); - futures::stream::iter(dirty_scopes.into_iter()) - .map(|scope| { - let session = Arc::clone(session); - let config = Arc::clone(&config); - async move { - run_memory_consolidation_for_scope(session, config, scope).await; - } - }) - .buffer_unordered(memories::PHASE_ONE_CONCURRENCY_LIMIT) - .collect::>() - .await; - info!( - "memory phase-2 consolidation dispatch complete: {} scope(s) scheduled", - consolidation_scope_count - ); - - Ok(()) -} - -async fn claim_phase_one_candidates( - session: &Session, - config: &Config, - candidates: Vec, - max_claimed_candidates: usize, -) -> Vec { - if max_claimed_candidates == 0 { - return Vec::new(); - } - - let Some(state_db) = session.services.state_db.as_deref() else { - return Vec::new(); - }; - - let mut claimed_candidates = Vec::new(); - for candidate in candidates { - if claimed_candidates.len() >= max_claimed_candidates { - break; - } - - let source_updated_at = parse_source_updated_at_epoch(&candidate); - let mut claimed_scopes = Vec::<(MemoryScopeTarget, String)>::new(); - for scope in memory_scope_targets_for_candidate(config, &candidate) { - let Some(claim) = try_claim_phase1_job_with_retry( - state_db, - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - session.conversation_id, - source_updated_at, - ) - .await - else { - continue; - }; - - if let codex_state::Phase1JobClaimOutcome::Claimed { ownership_token } = claim { - claimed_scopes.push((scope, ownership_token)); - } - } - - if !claimed_scopes.is_empty() { - claimed_candidates.push(ClaimedPhaseOneCandidate { - candidate, - claimed_scopes, - }); - } - } - - claimed_candidates -} - -async fn try_claim_phase1_job_with_retry( - state_db: &codex_state::StateRuntime, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - owner_session_id: ThreadId, - source_updated_at: i64, -) -> Option { - for attempt in 0..=PHASE_ONE_DB_LOCK_RETRY_LIMIT { - match state_db - .try_claim_phase1_job( - thread_id, - scope_kind, - scope_key, - owner_session_id, - source_updated_at, - memories::PHASE_ONE_JOB_LEASE_SECONDS, - ) - .await - { - Ok(claim) => return Some(claim), - Err(err) => { - let is_locked = err.to_string().contains("database is locked"); - if is_locked && attempt < PHASE_ONE_DB_LOCK_RETRY_LIMIT { - tokio::time::sleep(Duration::from_millis( - PHASE_ONE_DB_LOCK_RETRY_BACKOFF_MS * (attempt as u64 + 1), - )) - .await; - continue; - } - warn!("state db try_claim_phase1_job failed during {MEMORY_STARTUP_STAGE}: {err}"); - return None; - } - } - } - None -} - -async fn list_phase2_dirty_scopes( - session: &Session, - config: &Config, - limit: usize, -) -> Vec { - if limit == 0 { - return Vec::new(); - } - - let Some(state_db) = session.services.state_db.as_deref() else { - return Vec::new(); - }; - - let dirty_scopes = match state_db.list_dirty_memory_scopes(limit).await { - Ok(scopes) => scopes, - Err(err) => { - warn!("state db list_dirty_memory_scopes failed during {MEMORY_STARTUP_STAGE}: {err}"); - return Vec::new(); - } - }; - - dirty_scopes - .into_iter() - .filter_map(|dirty_scope| memory_scope_target_for_dirty_scope(config, dirty_scope)) - .collect() -} - -fn memory_scope_target_for_dirty_scope( - config: &Config, - dirty_scope: codex_state::DirtyMemoryScope, -) -> Option { - let scope_kind = dirty_scope.scope_kind; - let scope_key = dirty_scope.scope_key; - match scope_kind.as_str() { - memories::MEMORY_SCOPE_KIND_CWD => { - let cwd = PathBuf::from(&scope_key); - Some(MemoryScopeTarget { - scope_kind: memories::MEMORY_SCOPE_KIND_CWD, - scope_key, - memory_root: memories::memory_root_for_cwd(&config.codex_home, &cwd), - }) - } - memories::MEMORY_SCOPE_KIND_USER => { - if scope_key != memories::MEMORY_SCOPE_KEY_USER { - warn!( - "skipping unsupported user memory scope key for phase-2: {}:{}", - scope_kind, scope_key - ); - return None; - } - Some(MemoryScopeTarget { - scope_kind: memories::MEMORY_SCOPE_KIND_USER, - scope_key, - memory_root: memories::memory_root_for_user(&config.codex_home), - }) - } - _ => { - warn!( - "skipping unsupported memory scope for phase-2 consolidation: {}:{}", - scope_kind, scope_key - ); - None - } - } -} - -async fn try_claim_phase2_job_with_retry( - state_db: &codex_state::StateRuntime, - scope_kind: &str, - scope_key: &str, - owner_session_id: ThreadId, -) -> Option { - for attempt in 0..=PHASE_TWO_DB_LOCK_RETRY_LIMIT { - match state_db - .try_claim_phase2_job( - scope_kind, - scope_key, - owner_session_id, - memories::PHASE_TWO_JOB_LEASE_SECONDS, - ) - .await - { - Ok(claim) => return Some(claim), - Err(err) => { - let is_locked = err.to_string().contains("database is locked"); - if is_locked && attempt < PHASE_TWO_DB_LOCK_RETRY_LIMIT { - tokio::time::sleep(Duration::from_millis( - PHASE_TWO_DB_LOCK_RETRY_BACKOFF_MS * (attempt as u64 + 1), - )) - .await; - continue; - } - warn!("state db try_claim_phase2_job failed during {MEMORY_STARTUP_STAGE}: {err}"); - return None; - } - } - } - None -} - -async fn process_memory_candidate( - session: Arc, - claimed_candidate: ClaimedPhaseOneCandidate, - stage_one_context: StageOneRequestContext, -) -> usize { - let candidate = claimed_candidate.candidate; - let claimed_scopes = claimed_candidate.claimed_scopes; - - let mut ready_scopes = Vec::<(MemoryScopeTarget, String)>::new(); - for (scope, ownership_token) in claimed_scopes { - if let Err(err) = memories::ensure_layout(&scope.memory_root).await { - warn!( - "failed to create memory layout for scope {}:{} root={}: {err}", - scope.scope_kind, - scope.scope_key, - scope.memory_root.display() - ); - mark_phase1_job_failed_best_effort( - session.as_ref(), - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - &ownership_token, - "failed to create memory layout", - ) - .await; - continue; - } - ready_scopes.push((scope, ownership_token)); - } - if ready_scopes.is_empty() { - return 0; - } - - let (rollout_items, _thread_id, parse_errors) = - match RolloutRecorder::load_rollout_items(&candidate.rollout_path).await { - Ok(result) => result, - Err(err) => { - warn!( - "failed to load rollout {} for memories: {err}", - candidate.rollout_path.display() - ); - fail_claimed_phase_one_jobs( - &session, - &candidate, - &ready_scopes, - "failed to load rollout", - ) - .await; - return 0; - } - }; - if parse_errors > 0 { - warn!( - "rollout {} had {parse_errors} parse errors while preparing stage-1 memory input", - candidate.rollout_path.display() - ); - } - - let rollout_contents = match memories::serialize_filtered_rollout_response_items( - &rollout_items, - memories::StageOneRolloutFilter::default(), - ) { - Ok(contents) => contents, - Err(err) => { - warn!( - "failed to prepare filtered rollout payload {} for memories: {err}", - candidate.rollout_path.display() - ); - fail_claimed_phase_one_jobs( - &session, - &candidate, - &ready_scopes, - "failed to serialize filtered rollout", - ) - .await; - return 0; - } - }; - - let prompt = Prompt { - input: vec![ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: memories::build_stage_one_input_message( - &candidate.rollout_path, - &rollout_contents, - ), - }], - end_turn: None, - phase: None, - }], - tools: Vec::new(), - parallel_tool_calls: false, - base_instructions: BaseInstructions { - text: memories::RAW_MEMORY_PROMPT.to_string(), - }, - personality: None, - output_schema: Some(memories::stage_one_output_schema()), - }; - - let mut client_session = session.services.model_client.new_session(); - let mut stream = match client_session - .stream( - &prompt, - &stage_one_context.model_info, - &stage_one_context.otel_manager, - stage_one_context.reasoning_effort, - stage_one_context.reasoning_summary, - stage_one_context.turn_metadata_header.as_deref(), - ) - .await - { - Ok(stream) => stream, - Err(err) => { - warn!( - "stage-1 memory request failed for rollout {}: {err}", - candidate.rollout_path.display() - ); - fail_claimed_phase_one_jobs( - &session, - &candidate, - &ready_scopes, - "stage-1 memory request failed", - ) - .await; - return 0; - } - }; - - let output_text = match collect_response_text_until_completed(&mut stream).await { - Ok(text) => text, - Err(err) => { - warn!( - "failed while waiting for stage-1 memory response for rollout {}: {err}", - candidate.rollout_path.display() - ); - fail_claimed_phase_one_jobs( - &session, - &candidate, - &ready_scopes, - "stage-1 memory response stream failed", - ) - .await; - return 0; - } - }; - - let stage_one_output = match memories::parse_stage_one_output(&output_text) { - Ok(output) => output, - Err(err) => { - warn!( - "invalid stage-1 memory payload for rollout {}: {err}", - candidate.rollout_path.display() - ); - fail_claimed_phase_one_jobs( - &session, - &candidate, - &ready_scopes, - "invalid stage-1 memory payload", - ) - .await; - return 0; - } - }; - - let mut touched_scope_count = 0; - for (scope, ownership_token) in &ready_scopes { - if persist_phase_one_memory_for_scope( - &session, - &candidate, - scope, - ownership_token, - &stage_one_output.raw_memory, - &stage_one_output.summary, - ) - .await - { - touched_scope_count += 1; - } - } - - touched_scope_count -} - -fn parse_source_updated_at_epoch(candidate: &memories::RolloutCandidate) -> i64 { - candidate - .updated_at - .as_deref() - .and_then(|value| DateTime::parse_from_rfc3339(value).ok()) - .map(|value| value.with_timezone(&Utc).timestamp()) - .unwrap_or_else(|| Utc::now().timestamp()) -} - -fn memory_scope_targets_for_candidate( - config: &Config, - candidate: &memories::RolloutCandidate, -) -> Vec { - vec![ - MemoryScopeTarget { - scope_kind: memories::MEMORY_SCOPE_KIND_CWD, - scope_key: memories::memory_scope_key_for_cwd(&candidate.cwd), - memory_root: memories::memory_root_for_cwd(&config.codex_home, &candidate.cwd), - }, - MemoryScopeTarget { - scope_kind: memories::MEMORY_SCOPE_KIND_USER, - scope_key: memories::MEMORY_SCOPE_KEY_USER.to_string(), - memory_root: memories::memory_root_for_user(&config.codex_home), - }, - ] -} - -async fn fail_claimed_phase_one_jobs( - session: &Session, - candidate: &memories::RolloutCandidate, - claimed_scopes: &[(MemoryScopeTarget, String)], - reason: &str, -) { - for (scope, ownership_token) in claimed_scopes { - mark_phase1_job_failed_best_effort( - session, - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - reason, - ) - .await; - } -} - -async fn persist_phase_one_memory_for_scope( - session: &Session, - candidate: &memories::RolloutCandidate, - scope: &MemoryScopeTarget, - ownership_token: &str, - raw_memory: &str, - summary: &str, -) -> bool { - let Some(state_db) = session.services.state_db.as_deref() else { - mark_phase1_job_failed_best_effort( - session, - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - "state db unavailable for scoped thread memory upsert", - ) - .await; - return false; - }; - - let lease_renewed = match state_db - .renew_phase1_job_lease( - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - ) - .await - { - Ok(renewed) => renewed, - Err(err) => { - warn!("state db renew_phase1_job_lease failed during {MEMORY_STARTUP_STAGE}: {err}"); - return false; - } - }; - if !lease_renewed { - debug!( - "memory phase-1 write skipped after ownership changed: rollout={} scope={} scope_key={}", - candidate.rollout_path.display(), - scope.scope_kind, - scope.scope_key - ); - return false; - } - - let upserted = match state_db - .upsert_thread_memory_for_scope_if_phase1_owner( - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - raw_memory, - summary, - ) - .await - { - Ok(upserted) => upserted, - Err(err) => { - warn!( - "state db upsert_thread_memory_for_scope_if_phase1_owner failed during {MEMORY_STARTUP_STAGE}: {err}" - ); - mark_phase1_job_failed_best_effort( - session, - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - "failed to upsert scoped thread memory", - ) - .await; - return false; - } - }; - if upserted.is_none() { - debug!( - "memory phase-1 db upsert skipped after ownership changed: rollout={} scope={} scope_key={}", - candidate.rollout_path.display(), - scope.scope_kind, - scope.scope_key - ); - return false; - } - - let latest_memories = match state_db - .get_last_n_thread_memories_for_scope( - scope.scope_kind, - &scope.scope_key, - memories::MAX_RAW_MEMORIES_PER_SCOPE, - ) - .await - { - Ok(memories) => memories, - Err(err) => { - warn!( - "state db get_last_n_thread_memories_for_scope failed during {MEMORY_STARTUP_STAGE}: {err}" - ); - mark_phase1_job_failed_best_effort( - session, - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - "failed to read scope memories after upsert", - ) - .await; - return false; - } - }; - - if let Err(err) = - memories::sync_raw_memories_from_memories(&scope.memory_root, &latest_memories).await - { - warn!( - "failed syncing raw memories for scope {}:{} root={}: {err}", - scope.scope_kind, - scope.scope_key, - scope.memory_root.display() - ); - mark_phase1_job_failed_best_effort( - session, - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - "failed to sync scope raw memories", - ) - .await; - return false; - } - - if let Err(err) = - memories::rebuild_memory_summary_from_memories(&scope.memory_root, &latest_memories).await - { - warn!( - "failed rebuilding memory_summary for scope {}:{} root={}: {err}", - scope.scope_kind, - scope.scope_key, - scope.memory_root.display() - ); - mark_phase1_job_failed_best_effort( - session, - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - "failed to rebuild scope memory summary", - ) - .await; - return false; - } - - let mut hasher = Sha256::new(); - hasher.update(summary.as_bytes()); - let summary_hash = format!("{:x}", hasher.finalize()); - let raw_memory_path = scope - .memory_root - .join("raw_memories") - .join(format!("{}.md", candidate.thread_id)); - let marked_succeeded = match state_db - .mark_phase1_job_succeeded( - candidate.thread_id, - scope.scope_kind, - &scope.scope_key, - ownership_token, - &raw_memory_path.display().to_string(), - &summary_hash, - ) - .await - { - Ok(marked) => marked, - Err(err) => { - warn!("state db mark_phase1_job_succeeded failed during {MEMORY_STARTUP_STAGE}: {err}"); - return false; - } - }; - if !marked_succeeded { - return false; - } - - if let Err(err) = state_db - .mark_memory_scope_dirty(scope.scope_kind, &scope.scope_key, true) - .await - { - warn!("state db mark_memory_scope_dirty failed during {MEMORY_STARTUP_STAGE}: {err}"); - } - - info!( - "memory phase-1 raw memory persisted: rollout={} scope={} scope_key={} raw_memory_path={}", - candidate.rollout_path.display(), - scope.scope_kind, - scope.scope_key, - raw_memory_path.display() - ); - true -} - -async fn run_memory_consolidation_for_scope( - session: Arc, - config: Arc, - scope: MemoryScopeTarget, -) { - let Some(state_db) = session.services.state_db.as_deref() else { - warn!( - "state db unavailable for scope {}:{}; skipping consolidation", - scope.scope_kind, scope.scope_key - ); - return; - }; - - let Some(claim) = try_claim_phase2_job_with_retry( - state_db, - scope.scope_kind, - &scope.scope_key, - session.conversation_id, - ) - .await - else { - return; - }; - let ownership_token = match claim { - codex_state::Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token, - codex_state::Phase2JobClaimOutcome::SkippedNotDirty => { - debug!( - "memory phase-2 scope no longer dirty; skipping consolidation: {}:{}", - scope.scope_kind, scope.scope_key - ); - return; - } - codex_state::Phase2JobClaimOutcome::SkippedRunning => { - debug!( - "memory phase-2 job already running for scope {}:{}; skipping", - scope.scope_kind, scope.scope_key - ); - return; - } - }; - - if let Err(err) = memories::ensure_layout(&scope.memory_root).await { - warn!( - "failed to create memory layout for phase-2 scope {}:{} root={}: {err}", - scope.scope_kind, - scope.scope_key, - scope.memory_root.display() - ); - mark_phase2_job_failed_best_effort( - session.services.state_db.as_deref(), - scope.scope_kind, - &scope.scope_key, - &ownership_token, - "failed to create memory layout", - ) - .await; - return; - } - - let latest_memories = match state_db - .get_last_n_thread_memories_for_scope( - scope.scope_kind, - &scope.scope_key, - memories::MAX_RAW_MEMORIES_PER_SCOPE, - ) - .await - { - Ok(memories) => memories, - Err(err) => { - warn!( - "state db get_last_n_thread_memories_for_scope failed during {MEMORY_STARTUP_STAGE}: {err}" - ); - mark_phase2_job_failed_best_effort( - session.services.state_db.as_deref(), - scope.scope_kind, - &scope.scope_key, - &ownership_token, - "failed to read scope memories before consolidation", - ) - .await; - return; - } - }; - - let memory_root = scope.memory_root.clone(); - if let Err(err) = - memories::prune_to_recent_memories_and_rebuild_summary(&memory_root, &latest_memories).await - { - warn!( - "failed to refresh phase-1 memory outputs for scope {}:{}: {err}", - scope.scope_kind, scope.scope_key - ); - mark_phase2_job_failed_best_effort( - session.services.state_db.as_deref(), - scope.scope_kind, - &scope.scope_key, - &ownership_token, - "failed to refresh phase-1 memory outputs", - ) - .await; - return; - } - - if let Err(err) = memories::wipe_consolidation_outputs(&memory_root).await { - warn!( - "failed to wipe previous consolidation outputs for scope {}:{}: {err}", - scope.scope_kind, scope.scope_key - ); - mark_phase2_job_failed_best_effort( - session.services.state_db.as_deref(), - scope.scope_kind, - &scope.scope_key, - &ownership_token, - "failed to wipe previous consolidation outputs", - ) - .await; - return; - } - - let prompt = memories::build_consolidation_prompt(&memory_root); - let input = vec![UserInput::Text { - text: prompt, - text_elements: vec![], - }]; - let mut consolidation_config = config.as_ref().clone(); - consolidation_config.cwd = memory_root.clone(); - let source = SessionSource::SubAgent(SubAgentSource::Other( - memories::MEMORY_CONSOLIDATION_SUBAGENT_LABEL.to_string(), - )); - match session - .services - .agent_control - .spawn_agent(consolidation_config, input, Some(source)) - .await - { - Ok(consolidation_agent_id) => { - match state_db - .set_phase2_job_agent_thread_id( - scope.scope_kind, - &scope.scope_key, - &ownership_token, - consolidation_agent_id, - ) - .await - { - Ok(true) => {} - Ok(false) => { - debug!( - "memory phase-2 job lost ownership before agent registration: {}:{}", - scope.scope_kind, scope.scope_key - ); - return; - } - Err(err) => { - warn!( - "state db set_phase2_job_agent_thread_id failed during {MEMORY_STARTUP_STAGE}: {err}" - ); - } - } - info!( - "memory phase-2 consolidation agent started: scope={} scope_key={} agent_id={}", - scope.scope_kind, scope.scope_key, consolidation_agent_id - ); - spawn_phase2_completion_task( - session.as_ref(), - scope, - ownership_token, - consolidation_agent_id, - ); - } - Err(err) => { - warn!( - "failed to spawn memory consolidation agent for scope {}:{}: {err}", - scope.scope_kind, scope.scope_key - ); - mark_phase2_job_failed_best_effort( - session.services.state_db.as_deref(), - scope.scope_kind, - &scope.scope_key, - &ownership_token, - "failed to spawn consolidation agent", - ) - .await; - } - } -} - -fn spawn_phase2_completion_task( - session: &Session, - scope: MemoryScopeTarget, - ownership_token: String, - consolidation_agent_id: ThreadId, -) { - let state_db = session.services.state_db.clone(); - let agent_control = session.services.agent_control.clone(); - tokio::spawn(async move { - let Some(state_db) = state_db.as_deref() else { - return; - }; - let mut status_rx = match agent_control.subscribe_status(consolidation_agent_id).await { - Ok(status_rx) => status_rx, - Err(err) => { - warn!( - "failed to subscribe to memory consolidation agent {} for scope {}:{}: {err}", - consolidation_agent_id, scope.scope_kind, scope.scope_key - ); - if let Err(mark_err) = state_db - .mark_phase2_job_failed( - scope.scope_kind, - &scope.scope_key, - &ownership_token, - "failed to subscribe to consolidation agent status", - ) - .await - { - warn!( - "state db mark_phase2_job_failed failed during {MEMORY_STARTUP_STAGE}: {mark_err}" - ); - } - return; - } - }; - let mut heartbeat_interval = tokio::time::interval(Duration::from_secs( - memories::PHASE_TWO_JOB_HEARTBEAT_SECONDS, - )); - heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - let final_status = loop { - let status = status_rx.borrow().clone(); - if is_final_agent_status(&status) { - break status; - } - - tokio::select! { - changed = status_rx.changed() => { - if changed.is_err() { - warn!( - "lost status updates for memory consolidation agent {} in scope {}:{}", - consolidation_agent_id, scope.scope_kind, scope.scope_key - ); - break status; - } - } - _ = heartbeat_interval.tick() => { - match state_db - .heartbeat_phase2_job(scope.scope_kind, &scope.scope_key, &ownership_token) - .await - { - Ok(true) => {} - Ok(false) => { - debug!( - "memory phase-2 heartbeat lost ownership for scope {}:{}; skipping finalization", - scope.scope_kind, scope.scope_key - ); - return; - } - Err(err) => { - warn!("state db heartbeat_phase2_job failed during {MEMORY_STARTUP_STAGE}: {err}"); - return; - } - } - } - } - }; - - let phase2_succeeded = matches!(&final_status, AgentStatus::Completed(_)); - if phase2_succeeded { - match state_db - .mark_phase2_job_succeeded(scope.scope_kind, &scope.scope_key, &ownership_token) - .await - { - Ok(true) => {} - Ok(false) => { - debug!( - "memory phase-2 success finalization skipped after ownership changed: scope={} scope_key={}", - scope.scope_kind, scope.scope_key - ); - } - Err(err) => { - warn!( - "state db mark_phase2_job_succeeded failed during {MEMORY_STARTUP_STAGE}: {err}" - ); - } - } - info!( - "memory phase-2 consolidation agent finished: scope={} scope_key={} agent_id={} final_status={final_status:?}", - scope.scope_kind, scope.scope_key, consolidation_agent_id - ); - return; - } - - let failure_reason = format!("consolidation agent finished with status {final_status:?}"); - match state_db - .mark_phase2_job_failed( - scope.scope_kind, - &scope.scope_key, - &ownership_token, - &failure_reason, - ) - .await - { - Ok(true) => {} - Ok(false) => { - debug!( - "memory phase-2 failure finalization skipped after ownership changed: scope={} scope_key={}", - scope.scope_kind, scope.scope_key - ) - } - Err(err) => { - warn!( - "state db mark_phase2_job_failed failed during {MEMORY_STARTUP_STAGE}: {err}" - ); - } - } - warn!( - "memory phase-2 consolidation agent finished with non-success status: scope={} scope_key={} agent_id={} final_status={final_status:?}", - scope.scope_kind, scope.scope_key, consolidation_agent_id - ); - }); -} - -async fn mark_phase1_job_failed_best_effort( - session: &Session, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - failure_reason: &str, -) { - let Some(state_db) = session.services.state_db.as_deref() else { - return; - }; - if let Err(err) = state_db - .mark_phase1_job_failed( - thread_id, - scope_kind, - scope_key, - ownership_token, - failure_reason, - ) - .await - { - warn!("state db mark_phase1_job_failed failed during {MEMORY_STARTUP_STAGE}: {err}"); - } -} - -async fn mark_phase2_job_failed_best_effort( - state_db: Option<&codex_state::StateRuntime>, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - failure_reason: &str, -) { - let Some(state_db) = state_db else { - return; - }; - if let Err(err) = state_db - .mark_phase2_job_failed(scope_kind, scope_key, ownership_token, failure_reason) - .await - { - warn!("state db mark_phase2_job_failed failed during {MEMORY_STARTUP_STAGE}: {err}"); - } -} - -async fn collect_response_text_until_completed(stream: &mut ResponseStream) -> CodexResult { - let mut output_text = String::new(); - - loop { - let Some(event) = stream.next().await else { - return Err(CodexErr::Stream( - "stream closed before response.completed".to_string(), - None, - )); - }; - - match event? { - ResponseEvent::OutputTextDelta(delta) => output_text.push_str(&delta), - ResponseEvent::OutputItemDone(item) => { - if output_text.is_empty() - && let ResponseItem::Message { content, .. } = item - && let Some(text) = crate::compact::content_items_to_text(&content) - { - output_text.push_str(&text); - } - } - ResponseEvent::Completed { .. } => return Ok(output_text), - _ => {} - } - } -} diff --git a/codex-rs/core/src/memories/layout.rs b/codex-rs/core/src/memories/layout.rs new file mode 100644 index 000000000..8df5d2340 --- /dev/null +++ b/codex-rs/core/src/memories/layout.rs @@ -0,0 +1,59 @@ +use crate::path_utils::normalize_for_path_comparison; +use sha2::Digest; +use sha2::Sha256; +use std::path::Path; +use std::path::PathBuf; + +use super::scope::MEMORY_SCOPE_KEY_USER; + +pub(super) const MEMORY_SUBDIR: &str = "memory"; +pub(super) const RAW_MEMORIES_SUBDIR: &str = "raw_memories"; +pub(super) const MEMORY_SUMMARY_FILENAME: &str = "memory_summary.md"; +pub(super) const MEMORY_REGISTRY_FILENAME: &str = "MEMORY.md"; +pub(super) const LEGACY_CONSOLIDATED_FILENAME: &str = "consolidated.md"; +pub(super) const SKILLS_SUBDIR: &str = "skills"; + +const CWD_MEMORY_BUCKET_HEX_LEN: usize = 16; + +/// Returns the on-disk memory root directory for a given working directory. +/// +/// The cwd is normalized and hashed into a deterministic bucket under +/// `/memories//memory`. +pub(super) fn memory_root_for_cwd(codex_home: &Path, cwd: &Path) -> PathBuf { + let bucket = memory_bucket_for_cwd(cwd); + codex_home.join("memories").join(bucket).join(MEMORY_SUBDIR) +} + +/// Returns the on-disk user-shared memory root directory. +pub(super) fn memory_root_for_user(codex_home: &Path) -> PathBuf { + codex_home + .join("memories") + .join(MEMORY_SCOPE_KEY_USER) + .join(MEMORY_SUBDIR) +} + +pub(super) fn raw_memories_dir(root: &Path) -> PathBuf { + root.join(RAW_MEMORIES_SUBDIR) +} + +pub(super) fn memory_summary_file(root: &Path) -> PathBuf { + root.join(MEMORY_SUMMARY_FILENAME) +} + +/// Ensures the phase-1 memory directory layout exists for the given root. +pub(super) async fn ensure_layout(root: &Path) -> std::io::Result<()> { + tokio::fs::create_dir_all(raw_memories_dir(root)).await +} + +fn memory_bucket_for_cwd(cwd: &Path) -> String { + let normalized = normalize_cwd_for_memory(cwd); + let normalized = normalized.to_string_lossy(); + let mut hasher = Sha256::new(); + hasher.update(normalized.as_bytes()); + let full_hash = format!("{:x}", hasher.finalize()); + full_hash[..CWD_MEMORY_BUCKET_HEX_LEN].to_string() +} + +fn normalize_cwd_for_memory(cwd: &Path) -> PathBuf { + normalize_for_path_comparison(cwd).unwrap_or_else(|_| cwd.to_path_buf()) +} diff --git a/codex-rs/core/src/memories/mod.rs b/codex-rs/core/src/memories/mod.rs index 2767a85d4..71ea14afa 100644 --- a/codex-rs/core/src/memories/mod.rs +++ b/codex-rs/core/src/memories/mod.rs @@ -1,109 +1,46 @@ -mod phase_one; +//! Memory subsystem for startup extraction and consolidation. +//! +//! The startup memory pipeline is split into two phases: +//! - Phase 1: select rollouts, extract stage-1 raw memories, persist stage-1 outputs, and enqueue consolidation. +//! - Phase 2: claim scopes, materialize consolidation inputs, and dispatch consolidation agents. + +mod layout; mod prompts; mod rollout; -mod selection; +mod scope; +mod stage_one; +mod startup; mod storage; +mod text; mod types; #[cfg(test)] mod tests; -use crate::path_utils::normalize_for_path_comparison; -use sha2::Digest; -use sha2::Sha256; -use std::path::Path; -use std::path::PathBuf; - /// Subagent source label used to identify consolidation tasks. -pub(crate) const MEMORY_CONSOLIDATION_SUBAGENT_LABEL: &str = "memory_consolidation"; +const MEMORY_CONSOLIDATION_SUBAGENT_LABEL: &str = "memory_consolidation"; /// Maximum number of rollout candidates processed per startup pass. -pub(crate) const MAX_ROLLOUTS_PER_STARTUP: usize = 64; +const MAX_ROLLOUTS_PER_STARTUP: usize = 64; /// Concurrency cap for startup memory extraction and consolidation scheduling. -pub(crate) const PHASE_ONE_CONCURRENCY_LIMIT: usize = MAX_ROLLOUTS_PER_STARTUP; +const PHASE_ONE_CONCURRENCY_LIMIT: usize = MAX_ROLLOUTS_PER_STARTUP; +/// Concurrency cap for phase-2 consolidation dispatch. +const PHASE_TWO_CONCURRENCY_LIMIT: usize = MAX_ROLLOUTS_PER_STARTUP; /// Maximum number of recent raw memories retained per scope. -pub(crate) const MAX_RAW_MEMORIES_PER_SCOPE: usize = 64; +const MAX_RAW_MEMORIES_PER_SCOPE: usize = 64; /// Maximum rollout age considered for phase-1 extraction. -pub(crate) const PHASE_ONE_MAX_ROLLOUT_AGE_DAYS: i64 = 30; +const PHASE_ONE_MAX_ROLLOUT_AGE_DAYS: i64 = 30; /// Lease duration (seconds) for phase-1 job ownership. -pub(crate) const PHASE_ONE_JOB_LEASE_SECONDS: i64 = 3_600; +const PHASE_ONE_JOB_LEASE_SECONDS: i64 = 3_600; +/// Backoff delay (seconds) before retrying a failed stage-1 extraction job. +const PHASE_ONE_JOB_RETRY_DELAY_SECONDS: i64 = 3_600; /// Lease duration (seconds) for phase-2 consolidation job ownership. -pub(crate) const PHASE_TWO_JOB_LEASE_SECONDS: i64 = 3_600; +const PHASE_TWO_JOB_LEASE_SECONDS: i64 = 3_600; +/// Backoff delay (seconds) before retrying a failed phase-2 consolidation job. +const PHASE_TWO_JOB_RETRY_DELAY_SECONDS: i64 = 3_600; /// Heartbeat interval (seconds) for phase-2 running jobs. -pub(crate) const PHASE_TWO_JOB_HEARTBEAT_SECONDS: u64 = 30; -pub(crate) const MEMORY_SCOPE_KIND_CWD: &str = "cwd"; -pub(crate) const MEMORY_SCOPE_KIND_USER: &str = "user"; -pub(crate) const MEMORY_SCOPE_KEY_USER: &str = "user"; +const PHASE_TWO_JOB_HEARTBEAT_SECONDS: u64 = 30; -const MEMORY_SUBDIR: &str = "memory"; -const RAW_MEMORIES_SUBDIR: &str = "raw_memories"; -const MEMORY_SUMMARY_FILENAME: &str = "memory_summary.md"; -const MEMORY_REGISTRY_FILENAME: &str = "MEMORY.md"; -const LEGACY_CONSOLIDATED_FILENAME: &str = "consolidated.md"; -const SKILLS_SUBDIR: &str = "skills"; -const CWD_MEMORY_BUCKET_HEX_LEN: usize = 16; - -pub(crate) use phase_one::RAW_MEMORY_PROMPT; -pub(crate) use phase_one::parse_stage_one_output; -pub(crate) use phase_one::stage_one_output_schema; -pub(crate) use prompts::build_consolidation_prompt; -pub(crate) use prompts::build_stage_one_input_message; -#[cfg(test)] -pub(crate) use rollout::StageOneResponseItemKinds; -pub(crate) use rollout::StageOneRolloutFilter; -pub(crate) use rollout::serialize_filtered_rollout_response_items; -pub(crate) use selection::select_rollout_candidates_from_db; -pub(crate) use storage::prune_to_recent_memories_and_rebuild_summary; -pub(crate) use storage::rebuild_memory_summary_from_memories; -pub(crate) use storage::sync_raw_memories_from_memories; -pub(crate) use storage::wipe_consolidation_outputs; -pub(crate) use types::RolloutCandidate; - -/// Returns the on-disk memory root directory for a given working directory. +/// Starts the memory startup pipeline for eligible root sessions. /// -/// The cwd is normalized and hashed into a deterministic bucket under -/// `/memories//memory`. -pub(crate) fn memory_root_for_cwd(codex_home: &Path, cwd: &Path) -> PathBuf { - let bucket = memory_bucket_for_cwd(cwd); - codex_home.join("memories").join(bucket).join(MEMORY_SUBDIR) -} - -/// Returns the DB scope key for a cwd-scoped memory entry. -/// -/// This uses the same normalization/fallback behavior as cwd bucket derivation. -pub(crate) fn memory_scope_key_for_cwd(cwd: &Path) -> String { - normalize_cwd_for_memory(cwd).display().to_string() -} - -/// Returns the on-disk user-shared memory root directory. -pub(crate) fn memory_root_for_user(codex_home: &Path) -> PathBuf { - codex_home - .join("memories") - .join(MEMORY_SCOPE_KEY_USER) - .join(MEMORY_SUBDIR) -} - -fn raw_memories_dir(root: &Path) -> PathBuf { - root.join(RAW_MEMORIES_SUBDIR) -} - -fn memory_summary_file(root: &Path) -> PathBuf { - root.join(MEMORY_SUMMARY_FILENAME) -} - -/// Ensures the phase-1 memory directory layout exists for the given root. -pub(crate) async fn ensure_layout(root: &Path) -> std::io::Result<()> { - tokio::fs::create_dir_all(raw_memories_dir(root)).await -} - -fn memory_bucket_for_cwd(cwd: &Path) -> String { - let normalized = normalize_cwd_for_memory(cwd); - let normalized = normalized.to_string_lossy(); - let mut hasher = Sha256::new(); - hasher.update(normalized.as_bytes()); - let full_hash = format!("{:x}", hasher.finalize()); - full_hash[..CWD_MEMORY_BUCKET_HEX_LEN].to_string() -} - -fn normalize_cwd_for_memory(cwd: &Path) -> PathBuf { - normalize_for_path_comparison(cwd).unwrap_or_else(|_| cwd.to_path_buf()) -} +/// This is the single entrypoint that `codex` uses to trigger memory startup. +pub(crate) use startup::start_memories_startup_task; diff --git a/codex-rs/core/src/memories/prompts.rs b/codex-rs/core/src/memories/prompts.rs index 35d951add..c3faf1287 100644 --- a/codex-rs/core/src/memories/prompts.rs +++ b/codex-rs/core/src/memories/prompts.rs @@ -2,6 +2,9 @@ use askama::Template; use std::path::Path; use tracing::warn; +use super::text::prefix_at_char_boundary; +use super::text::suffix_at_char_boundary; + const MAX_ROLLOUT_BYTES_FOR_PROMPT: usize = 1_000_000; #[derive(Template)] @@ -20,7 +23,7 @@ struct StageOneInputTemplate<'a> { /// Builds the consolidation subagent prompt for a specific memory root. /// /// Falls back to a simple string replacement if Askama rendering fails. -pub(crate) fn build_consolidation_prompt(memory_root: &Path) -> String { +pub(super) fn build_consolidation_prompt(memory_root: &Path) -> String { let memory_root = memory_root.display().to_string(); let template = ConsolidationPromptTemplate { memory_root: &memory_root, @@ -39,7 +42,7 @@ pub(crate) fn build_consolidation_prompt(memory_root: &Path) -> String { /// /// Large rollout payloads are truncated to a bounded byte budget while keeping /// both head and tail context. -pub(crate) fn build_stage_one_input_message(rollout_path: &Path, rollout_contents: &str) -> String { +pub(super) fn build_stage_one_input_message(rollout_path: &Path, rollout_contents: &str) -> String { let (rollout_contents, truncated) = truncate_rollout_for_prompt(rollout_contents); if truncated { warn!( @@ -82,35 +85,6 @@ fn truncate_rollout_for_prompt(input: &str) -> (String, bool) { (truncated, true) } -fn prefix_at_char_boundary(input: &str, max_bytes: usize) -> &str { - if max_bytes >= input.len() { - return input; - } - let mut end = 0; - for (idx, _) in input.char_indices() { - if idx > max_bytes { - break; - } - end = idx; - } - &input[..end] -} - -fn suffix_at_char_boundary(input: &str, max_bytes: usize) -> &str { - if max_bytes >= input.len() { - return input; - } - let start_limit = input.len().saturating_sub(max_bytes); - let mut start = input.len(); - for (idx, _) in input.char_indices().rev() { - if idx < start_limit { - break; - } - start = idx; - } - &input[start..] -} - #[cfg(test)] mod tests { use super::*; diff --git a/codex-rs/core/src/memories/rollout.rs b/codex-rs/core/src/memories/rollout.rs index aa036f87e..1f126c4e4 100644 --- a/codex-rs/core/src/memories/rollout.rs +++ b/codex-rs/core/src/memories/rollout.rs @@ -5,7 +5,7 @@ use codex_protocol::protocol::RolloutItem; /// Bitmask selector for `ResponseItem` variants retained from rollout JSONL. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct StageOneResponseItemKinds(u16); +pub(super) struct StageOneResponseItemKinds(u16); impl StageOneResponseItemKinds { const MESSAGE: u16 = 1 << 0; @@ -20,7 +20,7 @@ impl StageOneResponseItemKinds { const COMPACTION: u16 = 1 << 9; const OTHER: u16 = 1 << 10; - pub(crate) const fn all() -> Self { + pub(super) const fn all() -> Self { Self( Self::MESSAGE | Self::REASONING @@ -37,7 +37,7 @@ impl StageOneResponseItemKinds { } #[cfg(test)] - pub(crate) const fn messages_only() -> Self { + pub(super) const fn messages_only() -> Self { Self(Self::MESSAGE) } @@ -72,19 +72,19 @@ impl Default for StageOneResponseItemKinds { /// Controls which rollout item kinds are retained for stage-1 memory extraction. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct StageOneRolloutFilter { +pub(super) struct StageOneRolloutFilter { /// Keep `RolloutItem::ResponseItem` entries. - pub(crate) keep_response_items: bool, + pub(super) keep_response_items: bool, /// Keep `RolloutItem::Compacted` entries (converted to assistant messages). - pub(crate) keep_compacted_items: bool, + pub(super) keep_compacted_items: bool, /// Restricts kept `ResponseItem` entries by variant. - pub(crate) response_item_kinds: StageOneResponseItemKinds, + pub(super) response_item_kinds: StageOneResponseItemKinds, /// Optional cap on retained items after filtering. - pub(crate) max_items: Option, + pub(super) max_items: Option, } impl StageOneRolloutFilter { - pub(crate) const fn response_and_compacted_items() -> Self { + pub(super) const fn response_and_compacted_items() -> Self { Self { keep_response_items: true, keep_compacted_items: true, @@ -104,7 +104,7 @@ impl Default for StageOneRolloutFilter { /// /// `RolloutItem::Compacted` entries are converted to assistant messages so the /// model sees the same response-item shape as normal transcript content. -pub(crate) fn filter_rollout_response_items( +pub(super) fn filter_rollout_response_items( items: &[RolloutItem], filter: StageOneRolloutFilter, ) -> Vec { @@ -139,7 +139,7 @@ pub(crate) fn filter_rollout_response_items( } /// Serializes filtered stage-1 memory items for prompt inclusion. -pub(crate) fn serialize_filtered_rollout_response_items( +pub(super) fn serialize_filtered_rollout_response_items( items: &[RolloutItem], filter: StageOneRolloutFilter, ) -> Result { diff --git a/codex-rs/core/src/memories/scope.rs b/codex-rs/core/src/memories/scope.rs new file mode 100644 index 000000000..b29bc67f9 --- /dev/null +++ b/codex-rs/core/src/memories/scope.rs @@ -0,0 +1,3 @@ +pub(super) const MEMORY_SCOPE_KIND_CWD: &str = "cwd"; +pub(super) const MEMORY_SCOPE_KIND_USER: &str = "user"; +pub(super) const MEMORY_SCOPE_KEY_USER: &str = "user"; diff --git a/codex-rs/core/src/memories/selection.rs b/codex-rs/core/src/memories/selection.rs deleted file mode 100644 index 9b0814943..000000000 --- a/codex-rs/core/src/memories/selection.rs +++ /dev/null @@ -1,47 +0,0 @@ -use chrono::Duration; -use chrono::Utc; -use codex_protocol::ThreadId; -use codex_state::ThreadMetadata; - -use super::types::RolloutCandidate; - -/// Selects rollout candidates that need stage-1 memory extraction. -/// -/// A rollout is selected when it is not the active thread and was updated -/// within the configured max age window. -pub(crate) fn select_rollout_candidates_from_db( - items: &[ThreadMetadata], - current_thread_id: ThreadId, - max_items: usize, - max_age_days: i64, -) -> Vec { - if max_items == 0 { - return Vec::new(); - } - - let cutoff = Utc::now() - Duration::days(max_age_days.max(0)); - - let mut candidates = Vec::new(); - - for item in items { - if item.id == current_thread_id { - continue; - } - if item.updated_at < cutoff { - continue; - } - - candidates.push(RolloutCandidate { - thread_id: item.id, - rollout_path: item.rollout_path.clone(), - cwd: item.cwd.clone(), - updated_at: Some(item.updated_at.to_rfc3339()), - }); - - if candidates.len() >= max_items { - break; - } - } - - candidates -} diff --git a/codex-rs/core/src/memories/phase_one.rs b/codex-rs/core/src/memories/stage_one.rs similarity index 82% rename from codex-rs/core/src/memories/phase_one.rs rename to codex-rs/core/src/memories/stage_one.rs index 544b20bff..0c5540f24 100644 --- a/codex-rs/core/src/memories/phase_one.rs +++ b/codex-rs/core/src/memories/stage_one.rs @@ -5,10 +5,12 @@ use regex::Regex; use serde_json::Value; use serde_json::json; +use super::text::compact_whitespace; +use super::text::truncate_text_for_storage; use super::types::StageOneOutput; /// System prompt for stage-1 raw memory extraction. -pub(crate) const RAW_MEMORY_PROMPT: &str = +pub(super) const RAW_MEMORY_PROMPT: &str = include_str!("../../templates/memories/stage_one_system.md"); const MAX_STAGE_ONE_RAW_MEMORY_CHARS: usize = 300_000; const MAX_STAGE_ONE_SUMMARY_CHARS: usize = 1_200; @@ -22,7 +24,7 @@ static SECRET_ASSIGNMENT_REGEX: Lazy = Lazy::new(|| { }); /// JSON schema used to constrain stage-1 model output. -pub(crate) fn stage_one_output_schema() -> Value { +pub(super) fn stage_one_output_schema() -> Value { json!({ "type": "object", "properties": { @@ -38,7 +40,7 @@ pub(crate) fn stage_one_output_schema() -> Value { /// /// Accepts plain JSON objects, fenced JSON, and object snippets embedded in /// extra text, then enforces redaction and size limits. -pub(crate) fn parse_stage_one_output(raw: &str) -> Result { +pub(super) fn parse_stage_one_output(raw: &str) -> Result { let parsed = parse_json_object_loose(raw)?; let output: StageOneOutput = serde_json::from_value(parsed).map_err(|err| { CodexErr::InvalidRequest(format!("invalid stage-1 memory output JSON payload: {err}")) @@ -91,35 +93,6 @@ fn parse_json_object_loose(raw: &str) -> Result { )) } -fn prefix_at_char_boundary(input: &str, max_bytes: usize) -> &str { - if max_bytes >= input.len() { - return input; - } - let mut end = 0; - for (idx, _) in input.char_indices() { - if idx > max_bytes { - break; - } - end = idx; - } - &input[..end] -} - -fn suffix_at_char_boundary(input: &str, max_bytes: usize) -> &str { - if max_bytes >= input.len() { - return input; - } - let start_limit = input.len().saturating_sub(max_bytes); - let mut start = input.len(); - for (idx, _) in input.char_indices().rev() { - if idx < start_limit { - break; - } - start = idx; - } - &input[start..] -} - fn normalize_stage_one_output(mut output: StageOneOutput) -> Result { output.raw_memory = output.raw_memory.trim().to_string(); output.summary = output.summary.trim().to_string(); @@ -157,10 +130,6 @@ fn normalize_stage_one_output(mut output: StageOneOutput) -> Result String { - input.split_whitespace().collect::>().join(" ") -} - fn redact_secrets(input: &str) -> String { let redacted = OPENAI_KEY_REGEX.replace_all(input, "[REDACTED_SECRET]"); let redacted = AWS_ACCESS_KEY_ID_REGEX.replace_all(&redacted, "[REDACTED_SECRET]"); @@ -204,20 +173,6 @@ fn has_raw_memory_structure(input: &str) -> bool { && trimmed.contains("Outcome:") } -fn truncate_text_for_storage(input: &str, max_bytes: usize, marker: &str) -> String { - if input.len() <= max_bytes { - return input.to_string(); - } - - let budget_without_marker = max_bytes.saturating_sub(marker.len()); - let head_budget = budget_without_marker / 2; - let tail_budget = budget_without_marker.saturating_sub(head_budget); - let head = prefix_at_char_boundary(input, head_budget); - let tail = suffix_at_char_boundary(input, tail_budget); - - format!("{head}{marker}{tail}") -} - fn compile_regex(pattern: &str) -> Regex { match Regex::new(pattern) { Ok(regex) => regex, diff --git a/codex-rs/core/src/memories/startup/dispatch.rs b/codex-rs/core/src/memories/startup/dispatch.rs new file mode 100644 index 000000000..c97bae80a --- /dev/null +++ b/codex-rs/core/src/memories/startup/dispatch.rs @@ -0,0 +1,221 @@ +use crate::codex::Session; +use crate::config::Config; +use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; +use codex_protocol::user_input::UserInput; +use std::sync::Arc; +use tracing::debug; +use tracing::info; +use tracing::warn; + +use super::super::MAX_RAW_MEMORIES_PER_SCOPE; +use super::super::MEMORY_CONSOLIDATION_SUBAGENT_LABEL; +use super::super::PHASE_TWO_JOB_LEASE_SECONDS; +use super::super::PHASE_TWO_JOB_RETRY_DELAY_SECONDS; +use super::super::prompts::build_consolidation_prompt; +use super::super::storage::rebuild_memory_summary_from_memories; +use super::super::storage::sync_raw_memories_from_memories; +use super::super::storage::wipe_consolidation_outputs; +use super::MemoryScopeTarget; +use super::watch::spawn_phase2_completion_task; + +pub(super) async fn run_memory_consolidation_for_scope( + session: Arc, + config: Arc, + scope: MemoryScopeTarget, +) { + let Some(state_db) = session.services.state_db.as_deref() else { + warn!( + "state db unavailable for scope {}:{}; skipping consolidation", + scope.scope_kind, scope.scope_key + ); + return; + }; + + let claim = match state_db + .try_claim_phase2_job( + scope.scope_kind, + &scope.scope_key, + session.conversation_id, + PHASE_TWO_JOB_LEASE_SECONDS, + ) + .await + { + Ok(claim) => claim, + Err(err) => { + warn!( + "state db try_claim_phase2_job failed for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key + ); + return; + } + }; + let (ownership_token, claimed_watermark) = match claim { + codex_state::Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + codex_state::Phase2JobClaimOutcome::SkippedNotDirty => { + debug!( + "memory phase-2 scope not pending (or already up to date); skipping consolidation: {}:{}", + scope.scope_kind, scope.scope_key + ); + return; + } + codex_state::Phase2JobClaimOutcome::SkippedRunning => { + debug!( + "memory phase-2 job already running for scope {}:{}; skipping", + scope.scope_kind, scope.scope_key + ); + return; + } + }; + + let latest_memories = match state_db + .list_stage1_outputs_for_scope( + scope.scope_kind, + &scope.scope_key, + MAX_RAW_MEMORIES_PER_SCOPE, + ) + .await + { + Ok(memories) => memories, + Err(err) => { + warn!( + "state db list_stage1_outputs_for_scope failed during consolidation for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key + ); + let _ = state_db + .mark_phase2_job_failed( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + "failed to read scope stage-1 outputs before consolidation", + PHASE_TWO_JOB_RETRY_DELAY_SECONDS, + ) + .await; + return; + } + }; + if latest_memories.is_empty() { + debug!( + "memory phase-2 scope has no stage-1 outputs; skipping consolidation: {}:{}", + scope.scope_kind, scope.scope_key + ); + let _ = state_db + .mark_phase2_job_succeeded( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + claimed_watermark, + ) + .await; + return; + }; + + let materialized_watermark = latest_memories + .iter() + .map(|memory| memory.source_updated_at.timestamp()) + .max() + .unwrap_or(claimed_watermark); + + if let Err(err) = sync_raw_memories_from_memories(&scope.memory_root, &latest_memories).await { + warn!( + "failed syncing phase-1 raw memories for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key + ); + let _ = state_db + .mark_phase2_job_failed( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + "failed syncing phase-1 raw memories", + PHASE_TWO_JOB_RETRY_DELAY_SECONDS, + ) + .await; + return; + } + + if let Err(err) = + rebuild_memory_summary_from_memories(&scope.memory_root, &latest_memories).await + { + warn!( + "failed rebuilding memory summary for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key + ); + let _ = state_db + .mark_phase2_job_failed( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + "failed rebuilding memory summary", + PHASE_TWO_JOB_RETRY_DELAY_SECONDS, + ) + .await; + return; + } + + if let Err(err) = wipe_consolidation_outputs(&scope.memory_root).await { + warn!( + "failed to wipe previous consolidation outputs for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key + ); + let _ = state_db + .mark_phase2_job_failed( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + "failed to wipe previous consolidation outputs", + PHASE_TWO_JOB_RETRY_DELAY_SECONDS, + ) + .await; + return; + } + + let prompt = build_consolidation_prompt(&scope.memory_root); + let input = vec![UserInput::Text { + text: prompt, + text_elements: vec![], + }]; + let mut consolidation_config = config.as_ref().clone(); + consolidation_config.cwd = scope.memory_root.clone(); + let source = SessionSource::SubAgent(SubAgentSource::Other( + MEMORY_CONSOLIDATION_SUBAGENT_LABEL.to_string(), + )); + + match session + .services + .agent_control + .spawn_agent(consolidation_config, input, Some(source)) + .await + { + Ok(consolidation_agent_id) => { + info!( + "memory phase-2 consolidation agent started: scope={} scope_key={} agent_id={}", + scope.scope_kind, scope.scope_key, consolidation_agent_id + ); + spawn_phase2_completion_task( + session.as_ref(), + scope, + ownership_token, + materialized_watermark, + consolidation_agent_id, + ); + } + Err(err) => { + warn!( + "failed to spawn memory consolidation agent for scope {}:{}: {err}", + scope.scope_kind, scope.scope_key + ); + let _ = state_db + .mark_phase2_job_failed( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + "failed to spawn consolidation agent", + PHASE_TWO_JOB_RETRY_DELAY_SECONDS, + ) + .await; + } + } +} diff --git a/codex-rs/core/src/memories/startup/extract.rs b/codex-rs/core/src/memories/startup/extract.rs new file mode 100644 index 000000000..dd5f25854 --- /dev/null +++ b/codex-rs/core/src/memories/startup/extract.rs @@ -0,0 +1,150 @@ +use crate::client_common::Prompt; +use crate::client_common::ResponseEvent; +use crate::client_common::ResponseStream; +use crate::codex::Session; +use crate::error::CodexErr; +use crate::error::Result as CodexResult; +use crate::rollout::RolloutRecorder; +use codex_protocol::models::BaseInstructions; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use futures::StreamExt; +use tracing::warn; + +use super::StageOneRequestContext; +use crate::memories::prompts::build_stage_one_input_message; +use crate::memories::rollout::StageOneRolloutFilter; +use crate::memories::rollout::serialize_filtered_rollout_response_items; +use crate::memories::stage_one::RAW_MEMORY_PROMPT; +use crate::memories::stage_one::parse_stage_one_output; +use crate::memories::stage_one::stage_one_output_schema; +use crate::memories::types::StageOneOutput; +use std::path::Path; + +pub(super) async fn extract_stage_one_output( + session: &Session, + rollout_path: &Path, + stage_one_context: &StageOneRequestContext, +) -> Result { + let (rollout_items, _thread_id, parse_errors) = + match RolloutRecorder::load_rollout_items(rollout_path).await { + Ok(result) => result, + Err(err) => { + warn!( + "failed to load rollout {} for memories: {err}", + rollout_path.display() + ); + return Err("failed to load rollout"); + } + }; + if parse_errors > 0 { + warn!( + "rollout {} had {parse_errors} parse errors while preparing stage-1 memory input", + rollout_path.display() + ); + } + + let rollout_contents = match serialize_filtered_rollout_response_items( + &rollout_items, + StageOneRolloutFilter::default(), + ) { + Ok(contents) => contents, + Err(err) => { + warn!( + "failed to prepare filtered rollout payload {} for memories: {err}", + rollout_path.display() + ); + return Err("failed to serialize filtered rollout"); + } + }; + + let prompt = Prompt { + input: vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: build_stage_one_input_message(rollout_path, &rollout_contents), + }], + end_turn: None, + phase: None, + }], + tools: Vec::new(), + parallel_tool_calls: false, + base_instructions: BaseInstructions { + text: RAW_MEMORY_PROMPT.to_string(), + }, + personality: None, + output_schema: Some(stage_one_output_schema()), + }; + + let mut client_session = session.services.model_client.new_session(); + let mut stream = match client_session + .stream( + &prompt, + &stage_one_context.model_info, + &stage_one_context.otel_manager, + stage_one_context.reasoning_effort, + stage_one_context.reasoning_summary, + stage_one_context.turn_metadata_header.as_deref(), + ) + .await + { + Ok(stream) => stream, + Err(err) => { + warn!( + "stage-1 memory request failed for rollout {}: {err}", + rollout_path.display() + ); + return Err("stage-1 memory request failed"); + } + }; + + let output_text = match collect_response_text_until_completed(&mut stream).await { + Ok(text) => text, + Err(err) => { + warn!( + "failed while waiting for stage-1 memory response for rollout {}: {err}", + rollout_path.display() + ); + return Err("stage-1 memory response stream failed"); + } + }; + + match parse_stage_one_output(&output_text) { + Ok(output) => Ok(output), + Err(err) => { + warn!( + "invalid stage-1 memory payload for rollout {}: {err}", + rollout_path.display() + ); + Err("invalid stage-1 memory payload") + } + } +} + +async fn collect_response_text_until_completed(stream: &mut ResponseStream) -> CodexResult { + let mut output_text = String::new(); + + loop { + let Some(event) = stream.next().await else { + return Err(CodexErr::Stream( + "stream closed before response.completed".to_string(), + None, + )); + }; + + match event? { + ResponseEvent::OutputTextDelta(delta) => output_text.push_str(&delta), + ResponseEvent::OutputItemDone(item) => { + if output_text.is_empty() + && let ResponseItem::Message { content, .. } = item + && let Some(text) = crate::compact::content_items_to_text(&content) + { + output_text.push_str(&text); + } + } + ResponseEvent::Completed { .. } => return Ok(output_text), + _ => {} + } + } +} diff --git a/codex-rs/core/src/memories/startup/mod.rs b/codex-rs/core/src/memories/startup/mod.rs new file mode 100644 index 000000000..df908bac9 --- /dev/null +++ b/codex-rs/core/src/memories/startup/mod.rs @@ -0,0 +1,352 @@ +mod dispatch; +mod extract; +mod watch; + +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::config::Config; +use crate::error::Result as CodexResult; +use crate::features::Feature; +use crate::memories::layout::memory_root_for_cwd; +use crate::memories::layout::memory_root_for_user; +use crate::memories::scope::MEMORY_SCOPE_KEY_USER; +use crate::memories::scope::MEMORY_SCOPE_KIND_CWD; +use crate::memories::scope::MEMORY_SCOPE_KIND_USER; +use crate::rollout::INTERACTIVE_SESSION_SOURCES; +use codex_otel::OtelManager; +use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig; +use codex_protocol::protocol::SessionSource; +use futures::StreamExt; +use serde_json::Value; +use std::path::PathBuf; +use std::sync::Arc; +use tracing::info; +use tracing::warn; + +pub(super) const PHASE_ONE_THREAD_SCAN_LIMIT: usize = 5_000; + +#[derive(Clone)] +struct StageOneRequestContext { + model_info: ModelInfo, + otel_manager: OtelManager, + reasoning_effort: Option, + reasoning_summary: ReasoningSummaryConfig, + turn_metadata_header: Option, +} + +impl StageOneRequestContext { + fn from_turn_context(turn_context: &TurnContext, turn_metadata_header: Option) -> Self { + Self { + model_info: turn_context.model_info.clone(), + otel_manager: turn_context.otel_manager.clone(), + reasoning_effort: turn_context.reasoning_effort, + reasoning_summary: turn_context.reasoning_summary, + turn_metadata_header, + } + } +} + +/// Canonical memory scope metadata used by both startup phases. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(super) struct MemoryScopeTarget { + /// Scope family used for DB ownership and dirty-state tracking. + pub(super) scope_kind: &'static str, + /// Scope identifier used for DB keys. + pub(super) scope_key: String, + /// On-disk root where phase-1 artifacts and phase-2 outputs live. + pub(super) memory_root: PathBuf, +} + +/// Converts a pending scope consolidation row into a concrete filesystem target for phase 2. +/// +/// Unsupported scope kinds or malformed user-scope keys are ignored. +pub(super) fn memory_scope_target_for_pending_scope( + config: &Config, + pending_scope: codex_state::PendingScopeConsolidation, +) -> Option { + let scope_kind = pending_scope.scope_kind; + let scope_key = pending_scope.scope_key; + + match scope_kind.as_str() { + MEMORY_SCOPE_KIND_CWD => { + let cwd = PathBuf::from(&scope_key); + Some(MemoryScopeTarget { + scope_kind: MEMORY_SCOPE_KIND_CWD, + scope_key, + memory_root: memory_root_for_cwd(&config.codex_home, &cwd), + }) + } + MEMORY_SCOPE_KIND_USER => { + if scope_key != MEMORY_SCOPE_KEY_USER { + warn!( + "skipping unsupported user memory scope key for phase-2: {}:{}", + scope_kind, scope_key + ); + return None; + } + Some(MemoryScopeTarget { + scope_kind: MEMORY_SCOPE_KIND_USER, + scope_key, + memory_root: memory_root_for_user(&config.codex_home), + }) + } + _ => { + warn!( + "skipping unsupported memory scope for phase-2 consolidation: {}:{}", + scope_kind, scope_key + ); + None + } + } +} + +/// Starts the asynchronous startup memory pipeline for an eligible root session. +/// +/// The pipeline is skipped for ephemeral sessions, disabled feature flags, and +/// subagent sessions. +pub(crate) fn start_memories_startup_task( + session: &Arc, + config: Arc, + source: &SessionSource, +) { + if config.ephemeral + || !config.features.enabled(Feature::MemoryTool) + || matches!(source, SessionSource::SubAgent(_)) + { + return; + } + + let weak_session = Arc::downgrade(session); + tokio::spawn(async move { + let Some(session) = weak_session.upgrade() else { + return; + }; + if let Err(err) = run_memories_startup_pipeline(&session, config).await { + warn!("memories startup pipeline failed: {err}"); + } + }); +} + +/// Runs the startup memory pipeline. +/// +/// Phase 1 selects rollout candidates, performs stage-1 extraction requests in +/// parallel, persists stage-1 outputs, and enqueues consolidation work. +/// +/// Phase 2 claims pending scopes and spawns consolidation agents. +pub(super) async fn run_memories_startup_pipeline( + session: &Arc, + config: Arc, +) -> CodexResult<()> { + let Some(state_db) = session.services.state_db.as_deref() else { + warn!("state db unavailable for memories startup pipeline; skipping"); + return Ok(()); + }; + + let allowed_sources = INTERACTIVE_SESSION_SOURCES + .iter() + .map(|value| match serde_json::to_value(value) { + Ok(Value::String(s)) => s, + Ok(other) => other.to_string(), + Err(_) => String::new(), + }) + .collect::>(); + + let claimed_candidates = match state_db + .claim_stage1_jobs_for_startup( + session.conversation_id, + PHASE_ONE_THREAD_SCAN_LIMIT, + super::MAX_ROLLOUTS_PER_STARTUP, + super::PHASE_ONE_MAX_ROLLOUT_AGE_DAYS, + allowed_sources.as_slice(), + super::PHASE_ONE_JOB_LEASE_SECONDS, + ) + .await + { + Ok(claims) => claims, + Err(err) => { + warn!("state db claim_stage1_jobs_for_startup failed during memories startup: {err}"); + Vec::new() + } + }; + + let claimed_count = claimed_candidates.len(); + let mut succeeded_count = 0; + if claimed_count > 0 { + let turn_context = session.new_default_turn().await; + let stage_one_context = StageOneRequestContext::from_turn_context( + turn_context.as_ref(), + turn_context.resolve_turn_metadata_header().await, + ); + + succeeded_count = futures::stream::iter(claimed_candidates.into_iter()) + .map(|claim| { + let session = Arc::clone(session); + let stage_one_context = stage_one_context.clone(); + async move { + let thread = claim.thread; + let stage_one_output = match extract::extract_stage_one_output( + session.as_ref(), + &thread.rollout_path, + &stage_one_context, + ) + .await + { + Ok(output) => output, + Err(reason) => { + if let Some(state_db) = session.services.state_db.as_deref() { + let _ = state_db + .mark_stage1_job_failed( + thread.id, + &claim.ownership_token, + reason, + super::PHASE_ONE_JOB_RETRY_DELAY_SECONDS, + ) + .await; + } + return false; + } + }; + + let Some(state_db) = session.services.state_db.as_deref() else { + return false; + }; + + state_db + .mark_stage1_job_succeeded( + thread.id, + &claim.ownership_token, + thread.updated_at.timestamp(), + &stage_one_output.raw_memory, + &stage_one_output.summary, + ) + .await + .unwrap_or(false) + } + }) + .buffer_unordered(super::PHASE_ONE_CONCURRENCY_LIMIT) + .collect::>() + .await + .into_iter() + .filter(|ok| *ok) + .count(); + } + + info!( + "memory stage-1 extraction complete: {} job(s) claimed, {} succeeded", + claimed_count, succeeded_count + ); + + let consolidation_scope_count = run_consolidation_dispatch(session, config).await; + info!( + "memory consolidation dispatch complete: {} scope(s) scheduled", + consolidation_scope_count + ); + + Ok(()) +} + +async fn run_consolidation_dispatch(session: &Arc, config: Arc) -> usize { + let scopes = list_consolidation_scopes( + session.as_ref(), + config.as_ref(), + super::MAX_ROLLOUTS_PER_STARTUP, + ) + .await; + let consolidation_scope_count = scopes.len(); + + futures::stream::iter(scopes.into_iter()) + .map(|scope| { + let session = Arc::clone(session); + let config = Arc::clone(&config); + async move { + dispatch::run_memory_consolidation_for_scope(session, config, scope).await; + } + }) + .buffer_unordered(super::PHASE_TWO_CONCURRENCY_LIMIT) + .collect::>() + .await; + + consolidation_scope_count +} + +async fn list_consolidation_scopes( + session: &Session, + config: &Config, + limit: usize, +) -> Vec { + if limit == 0 { + return Vec::new(); + } + + let Some(state_db) = session.services.state_db.as_deref() else { + return Vec::new(); + }; + + let pending_scopes = match state_db.list_pending_scope_consolidations(limit).await { + Ok(scopes) => scopes, + Err(_) => return Vec::new(), + }; + + pending_scopes + .into_iter() + .filter_map(|scope| memory_scope_target_for_pending_scope(config, scope)) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::test_config; + use std::path::PathBuf; + + /// Verifies that phase-2 pending scope rows are translated only for supported scopes. + #[test] + fn pending_scope_mapping_accepts_supported_scopes_only() { + let mut config = test_config(); + config.codex_home = PathBuf::from("/tmp/memory-startup-test-home"); + + let cwd_target = memory_scope_target_for_pending_scope( + &config, + codex_state::PendingScopeConsolidation { + scope_kind: MEMORY_SCOPE_KIND_CWD.to_string(), + scope_key: "/tmp/project-a".to_string(), + }, + ) + .expect("cwd scope should map"); + assert_eq!(cwd_target.scope_kind, MEMORY_SCOPE_KIND_CWD); + + let user_target = memory_scope_target_for_pending_scope( + &config, + codex_state::PendingScopeConsolidation { + scope_kind: MEMORY_SCOPE_KIND_USER.to_string(), + scope_key: MEMORY_SCOPE_KEY_USER.to_string(), + }, + ) + .expect("valid user scope should map"); + assert_eq!(user_target.scope_kind, MEMORY_SCOPE_KIND_USER); + + assert!( + memory_scope_target_for_pending_scope( + &config, + codex_state::PendingScopeConsolidation { + scope_kind: MEMORY_SCOPE_KIND_USER.to_string(), + scope_key: "unexpected-user-key".to_string(), + }, + ) + .is_none() + ); + + assert!( + memory_scope_target_for_pending_scope( + &config, + codex_state::PendingScopeConsolidation { + scope_kind: "unknown".to_string(), + scope_key: "scope".to_string(), + }, + ) + .is_none() + ); + } +} diff --git a/codex-rs/core/src/memories/startup/watch.rs b/codex-rs/core/src/memories/startup/watch.rs new file mode 100644 index 000000000..ea4341e11 --- /dev/null +++ b/codex-rs/core/src/memories/startup/watch.rs @@ -0,0 +1,188 @@ +use crate::agent::AgentStatus; +use crate::agent::status::is_final as is_final_agent_status; +use crate::codex::Session; +use codex_protocol::ThreadId; +use std::time::Duration; +use tracing::debug; +use tracing::info; +use tracing::warn; + +use super::super::PHASE_TWO_JOB_HEARTBEAT_SECONDS; +use super::super::PHASE_TWO_JOB_LEASE_SECONDS; +use super::super::PHASE_TWO_JOB_RETRY_DELAY_SECONDS; +use super::MemoryScopeTarget; + +pub(super) fn spawn_phase2_completion_task( + session: &Session, + scope: MemoryScopeTarget, + ownership_token: String, + completion_watermark: i64, + consolidation_agent_id: ThreadId, +) { + let state_db = session.services.state_db.clone(); + let agent_control = session.services.agent_control.clone(); + + tokio::spawn(async move { + let Some(state_db) = state_db.as_deref() else { + return; + }; + + let mut status_rx = match agent_control.subscribe_status(consolidation_agent_id).await { + Ok(status_rx) => status_rx, + Err(err) => { + warn!( + "failed to subscribe to memory consolidation agent {} for scope {}:{}: {err}", + consolidation_agent_id, scope.scope_kind, scope.scope_key + ); + let _ = state_db + .mark_phase2_job_failed( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + "failed to subscribe to consolidation agent status", + PHASE_TWO_JOB_RETRY_DELAY_SECONDS, + ) + .await; + return; + } + }; + + let mut heartbeat_interval = + tokio::time::interval(Duration::from_secs(PHASE_TWO_JOB_HEARTBEAT_SECONDS)); + heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + let final_status = loop { + let status = status_rx.borrow().clone(); + if is_final_agent_status(&status) { + break status; + } + + tokio::select! { + changed = status_rx.changed() => { + if changed.is_err() { + warn!( + "lost status updates for memory consolidation agent {} in scope {}:{}", + consolidation_agent_id, scope.scope_kind, scope.scope_key + ); + break status; + } + } + _ = heartbeat_interval.tick() => { + match state_db + .heartbeat_phase2_job( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + PHASE_TWO_JOB_LEASE_SECONDS, + ) + .await + { + Ok(true) => {} + Ok(false) => { + debug!( + "memory phase-2 heartbeat lost ownership for scope {}:{}; skipping finalization", + scope.scope_kind, scope.scope_key + ); + return; + } + Err(err) => { + warn!( + "state db heartbeat_phase2_job failed during memories startup: {err}" + ); + return; + } + } + } + } + }; + + if is_phase2_success(&final_status) { + match state_db + .mark_phase2_job_succeeded( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + completion_watermark, + ) + .await + { + Ok(true) => {} + Ok(false) => { + debug!( + "memory phase-2 success finalization skipped after ownership changed: scope={} scope_key={}", + scope.scope_kind, scope.scope_key + ); + } + Err(err) => { + warn!( + "state db mark_phase2_job_succeeded failed during memories startup: {err}" + ); + } + } + info!( + "memory phase-2 consolidation agent finished: scope={} scope_key={} agent_id={} final_status={final_status:?}", + scope.scope_kind, scope.scope_key, consolidation_agent_id + ); + return; + } + + let failure_reason = phase2_failure_reason(&final_status); + match state_db + .mark_phase2_job_failed( + scope.scope_kind, + &scope.scope_key, + &ownership_token, + &failure_reason, + PHASE_TWO_JOB_RETRY_DELAY_SECONDS, + ) + .await + { + Ok(true) => {} + Ok(false) => { + debug!( + "memory phase-2 failure finalization skipped after ownership changed: scope={} scope_key={}", + scope.scope_kind, scope.scope_key + ); + } + Err(err) => { + warn!("state db mark_phase2_job_failed failed during memories startup: {err}"); + } + } + warn!( + "memory phase-2 consolidation agent finished with non-success status: scope={} scope_key={} agent_id={} final_status={final_status:?}", + scope.scope_kind, scope.scope_key, consolidation_agent_id + ); + }); +} + +fn is_phase2_success(final_status: &AgentStatus) -> bool { + matches!(final_status, AgentStatus::Completed(_)) +} + +fn phase2_failure_reason(final_status: &AgentStatus) -> String { + format!("consolidation agent finished with status {final_status:?}") +} + +#[cfg(test)] +mod tests { + use super::is_phase2_success; + use super::phase2_failure_reason; + use crate::agent::AgentStatus; + + #[test] + fn phase2_success_only_for_completed_status() { + assert!(is_phase2_success(&AgentStatus::Completed(None))); + assert!(!is_phase2_success(&AgentStatus::Running)); + assert!(!is_phase2_success(&AgentStatus::Errored( + "oops".to_string() + ))); + } + + #[test] + fn phase2_failure_reason_includes_status() { + let status = AgentStatus::Errored("boom".to_string()); + let reason = phase2_failure_reason(&status); + assert!(reason.contains("consolidation agent finished with status")); + assert!(reason.contains("boom")); + } +} diff --git a/codex-rs/core/src/memories/storage.rs b/codex-rs/core/src/memories/storage.rs index f1dbd96f8..43c0ff50b 100644 --- a/codex-rs/core/src/memories/storage.rs +++ b/codex-rs/core/src/memories/storage.rs @@ -1,48 +1,32 @@ -use codex_state::ThreadMemory; +use codex_state::Stage1Output; use std::collections::BTreeSet; use std::fmt::Write as _; use std::path::Path; use std::path::PathBuf; use tracing::warn; -use super::LEGACY_CONSOLIDATED_FILENAME; use super::MAX_RAW_MEMORIES_PER_SCOPE; -use super::MEMORY_REGISTRY_FILENAME; -use super::SKILLS_SUBDIR; -use super::ensure_layout; -use super::memory_summary_file; -use super::raw_memories_dir; - -/// Prunes stale raw memory files and rebuilds the routing summary for recent memories. -pub(crate) async fn prune_to_recent_memories_and_rebuild_summary( - root: &Path, - memories: &[ThreadMemory], -) -> std::io::Result<()> { - ensure_layout(root).await?; - - let keep = memories - .iter() - .take(MAX_RAW_MEMORIES_PER_SCOPE) - .map(|memory| memory.thread_id.to_string()) - .collect::>(); - - prune_raw_memories(root, &keep).await?; - rebuild_memory_summary(root, memories).await -} +use super::text::compact_whitespace; +use crate::memories::layout::LEGACY_CONSOLIDATED_FILENAME; +use crate::memories::layout::MEMORY_REGISTRY_FILENAME; +use crate::memories::layout::SKILLS_SUBDIR; +use crate::memories::layout::ensure_layout; +use crate::memories::layout::memory_summary_file; +use crate::memories::layout::raw_memories_dir; /// Rebuild `memory_summary.md` for a scope without pruning raw memory files. -pub(crate) async fn rebuild_memory_summary_from_memories( +pub(super) async fn rebuild_memory_summary_from_memories( root: &Path, - memories: &[ThreadMemory], + memories: &[Stage1Output], ) -> std::io::Result<()> { ensure_layout(root).await?; rebuild_memory_summary(root, memories).await } /// Syncs canonical raw memory files from DB-backed memory rows. -pub(crate) async fn sync_raw_memories_from_memories( +pub(super) async fn sync_raw_memories_from_memories( root: &Path, - memories: &[ThreadMemory], + memories: &[Stage1Output], ) -> std::io::Result<()> { ensure_layout(root).await?; @@ -65,7 +49,7 @@ pub(crate) async fn sync_raw_memories_from_memories( /// Clears consolidation outputs so a fresh consolidation run can regenerate them. /// /// Phase-1 artifacts (`raw_memories/` and `memory_summary.md`) are preserved. -pub(crate) async fn wipe_consolidation_outputs(root: &Path) -> std::io::Result<()> { +pub(super) async fn wipe_consolidation_outputs(root: &Path) -> std::io::Result<()> { for file_name in [MEMORY_REGISTRY_FILENAME, LEGACY_CONSOLIDATED_FILENAME] { let path = root.join(file_name); if let Err(err) = tokio::fs::remove_file(&path).await @@ -91,7 +75,7 @@ pub(crate) async fn wipe_consolidation_outputs(root: &Path) -> std::io::Result<( Ok(()) } -async fn rebuild_memory_summary(root: &Path, memories: &[ThreadMemory]) -> std::io::Result<()> { +async fn rebuild_memory_summary(root: &Path, memories: &[Stage1Output]) -> std::io::Result<()> { let mut body = String::from("# Memory Summary\n\n"); if memories.is_empty() { @@ -101,7 +85,7 @@ async fn rebuild_memory_summary(root: &Path, memories: &[ThreadMemory]) -> std:: body.push_str("Map of concise summaries to thread IDs (latest first):\n\n"); for memory in memories.iter().take(MAX_RAW_MEMORIES_PER_SCOPE) { - let summary = compact_summary_for_index(&memory.memory_summary); + let summary = compact_whitespace(&memory.summary); writeln!(body, "- {summary} (thread: `{}`)", memory.thread_id) .map_err(|err| std::io::Error::other(format!("format memory summary: {err}")))?; } @@ -178,7 +162,7 @@ async fn remove_outdated_thread_raw_memories( async fn write_raw_memory_for_thread( root: &Path, - memory: &ThreadMemory, + memory: &Stage1Output, ) -> std::io::Result { let path = raw_memories_dir(root).join(format!("{}.md", memory.thread_id)); @@ -187,8 +171,12 @@ async fn write_raw_memory_for_thread( let mut body = String::new(); writeln!(body, "thread_id: {}", memory.thread_id) .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; - writeln!(body, "updated_at: {}", memory.updated_at.to_rfc3339()) - .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; + writeln!( + body, + "updated_at: {}", + memory.source_updated_at.to_rfc3339() + ) + .map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; writeln!(body).map_err(|err| std::io::Error::other(format!("format raw memory: {err}")))?; body.push_str(memory.raw_memory.trim()); body.push('\n'); @@ -197,10 +185,6 @@ async fn write_raw_memory_for_thread( Ok(path) } -fn compact_summary_for_index(summary: &str) -> String { - summary.split_whitespace().collect::>().join(" ") -} - fn extract_thread_id_from_summary_filename(file_name: &str) -> Option<&str> { let stem = file_name.strip_suffix(".md")?; if stem.is_empty() { diff --git a/codex-rs/core/src/memories/tests.rs b/codex-rs/core/src/memories/tests.rs index 1123a97bb..337b55317 100644 --- a/codex-rs/core/src/memories/tests.rs +++ b/codex-rs/core/src/memories/tests.rs @@ -1,17 +1,14 @@ -use super::MEMORY_SCOPE_KIND_CWD; -use super::PHASE_ONE_MAX_ROLLOUT_AGE_DAYS; -use super::StageOneResponseItemKinds; -use super::StageOneRolloutFilter; -use super::ensure_layout; -use super::memory_root_for_cwd; -use super::memory_scope_key_for_cwd; -use super::memory_summary_file; -use super::parse_stage_one_output; -use super::prune_to_recent_memories_and_rebuild_summary; -use super::raw_memories_dir; -use super::select_rollout_candidates_from_db; -use super::serialize_filtered_rollout_response_items; -use super::wipe_consolidation_outputs; +use super::rollout::StageOneResponseItemKinds; +use super::rollout::StageOneRolloutFilter; +use super::rollout::serialize_filtered_rollout_response_items; +use super::stage_one::parse_stage_one_output; +use super::storage::rebuild_memory_summary_from_memories; +use super::storage::sync_raw_memories_from_memories; +use super::storage::wipe_consolidation_outputs; +use crate::memories::layout::ensure_layout; +use crate::memories::layout::memory_root_for_cwd; +use crate::memories::layout::memory_summary_file; +use crate::memories::layout::raw_memories_dir; use chrono::TimeZone; use chrono::Utc; use codex_protocol::ThreadId; @@ -19,44 +16,10 @@ use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::CompactedItem; use codex_protocol::protocol::RolloutItem; -use codex_state::ThreadMemory; -use codex_state::ThreadMetadata; +use codex_state::Stage1Output; use pretty_assertions::assert_eq; -use std::path::PathBuf; use tempfile::tempdir; -fn thread_metadata( - thread_id: ThreadId, - path: PathBuf, - cwd: PathBuf, - title: &str, - updated_at_secs: i64, -) -> ThreadMetadata { - let updated_at = Utc - .timestamp_opt(updated_at_secs, 0) - .single() - .expect("timestamp"); - ThreadMetadata { - id: thread_id, - rollout_path: path, - created_at: updated_at, - updated_at, - source: "cli".to_string(), - model_provider: "openai".to_string(), - cwd, - cli_version: "test".to_string(), - title: title.to_string(), - sandbox_policy: "read_only".to_string(), - approval_mode: "on_request".to_string(), - tokens_used: 0, - first_user_message: None, - archived_at: None, - git_branch: None, - git_sha: None, - git_origin_url: None, - } -} - #[test] fn memory_root_varies_by_cwd() { let dir = tempdir().expect("tempdir"); @@ -100,22 +63,6 @@ fn memory_root_encoding_avoids_component_collisions() { assert!(!root_hash.display().to_string().contains("workspace")); } -#[test] -fn memory_scope_key_uses_normalized_cwd() { - let dir = tempdir().expect("tempdir"); - let workspace = dir.path().join("workspace"); - std::fs::create_dir_all(&workspace).expect("mkdir workspace"); - std::fs::create_dir_all(workspace.join("nested")).expect("mkdir nested"); - - let alias = workspace.join("nested").join(".."); - let normalized = workspace - .canonicalize() - .expect("canonical workspace path should resolve"); - let alias_key = memory_scope_key_for_cwd(&alias); - let normalized_key = memory_scope_key_for_cwd(&normalized); - assert_eq!(alias_key, normalized_key); -} - #[test] fn parse_stage_one_output_accepts_fenced_json() { let raw = "```json\n{\"rawMemory\":\"abc\",\"summary\":\"short\"}\n```"; @@ -224,61 +171,6 @@ fn serialize_filtered_rollout_response_items_filters_by_response_item_kind() { assert!(matches!(parsed[0], ResponseItem::Message { .. })); } -#[test] -fn select_rollout_candidates_filters_by_age_window() { - let dir = tempdir().expect("tempdir"); - let cwd_a = dir.path().join("workspace-a"); - let cwd_b = dir.path().join("workspace-b"); - std::fs::create_dir_all(&cwd_a).expect("mkdir cwd a"); - std::fs::create_dir_all(&cwd_b).expect("mkdir cwd b"); - - let now = Utc::now().timestamp(); - let current_thread_id = ThreadId::default(); - let recent_thread_id = ThreadId::default(); - let old_thread_id = ThreadId::default(); - let recent_two_thread_id = ThreadId::default(); - - let current = thread_metadata( - current_thread_id, - dir.path().join("current.jsonl"), - cwd_a.clone(), - "current", - now, - ); - let recent = thread_metadata( - recent_thread_id, - dir.path().join("recent.jsonl"), - cwd_a, - "recent", - now - 10, - ); - let old = thread_metadata( - old_thread_id, - dir.path().join("old.jsonl"), - cwd_b.clone(), - "old", - now - (PHASE_ONE_MAX_ROLLOUT_AGE_DAYS + 1) * 24 * 60 * 60, - ); - let recent_two = thread_metadata( - recent_two_thread_id, - dir.path().join("recent-two.jsonl"), - cwd_b, - "recent-two", - now - 20, - ); - - let candidates = select_rollout_candidates_from_db( - &[current, recent, old, recent_two], - current_thread_id, - 5, - PHASE_ONE_MAX_ROLLOUT_AGE_DAYS, - ); - - assert_eq!(candidates.len(), 2); - assert_eq!(candidates[0].thread_id, recent_thread_id); - assert_eq!(candidates[1].thread_id, recent_two_thread_id); -} - #[tokio::test] async fn prune_and_rebuild_summary_keeps_latest_memories_only() { let dir = tempdir().expect("tempdir"); @@ -296,22 +188,20 @@ async fn prune_and_rebuild_summary_keeps_latest_memories_only() { .await .expect("write drop"); - let memories = vec![ThreadMemory { + let memories = vec![Stage1Output { thread_id: ThreadId::try_from(keep_id.clone()).expect("thread id"), - scope_kind: MEMORY_SCOPE_KIND_CWD.to_string(), - scope_key: "scope".to_string(), + source_updated_at: Utc.timestamp_opt(100, 0).single().expect("timestamp"), raw_memory: "raw memory".to_string(), - memory_summary: "short summary".to_string(), - updated_at: Utc.timestamp_opt(100, 0).single().expect("timestamp"), - last_used_at: None, - used_count: 0, - invalidated_at: None, - invalid_reason: None, + summary: "short summary".to_string(), + generated_at: Utc.timestamp_opt(101, 0).single().expect("timestamp"), }]; - prune_to_recent_memories_and_rebuild_summary(&root, &memories) + sync_raw_memories_from_memories(&root, &memories) .await - .expect("prune and rebuild"); + .expect("sync raw memories"); + rebuild_memory_summary_from_memories(&root, &memories) + .await + .expect("rebuild memory summary"); assert!(keep_path.is_file()); assert!(!drop_path.exists()); diff --git a/codex-rs/core/src/memories/text.rs b/codex-rs/core/src/memories/text.rs new file mode 100644 index 000000000..213804ad3 --- /dev/null +++ b/codex-rs/core/src/memories/text.rs @@ -0,0 +1,50 @@ +pub(super) fn compact_whitespace(input: &str) -> String { + input.split_whitespace().collect::>().join(" ") +} + +pub(super) fn truncate_text_for_storage(input: &str, max_bytes: usize, marker: &str) -> String { + if input.len() <= max_bytes { + return input.to_string(); + } + + let budget_without_marker = max_bytes.saturating_sub(marker.len()); + let head_budget = budget_without_marker / 2; + let tail_budget = budget_without_marker.saturating_sub(head_budget); + let head = prefix_at_char_boundary(input, head_budget); + let tail = suffix_at_char_boundary(input, tail_budget); + + format!("{head}{marker}{tail}") +} + +pub(super) fn prefix_at_char_boundary(input: &str, max_bytes: usize) -> &str { + if max_bytes >= input.len() { + return input; + } + + let mut end = 0; + for (idx, _) in input.char_indices() { + if idx > max_bytes { + break; + } + end = idx; + } + + &input[..end] +} + +pub(super) fn suffix_at_char_boundary(input: &str, max_bytes: usize) -> &str { + if max_bytes >= input.len() { + return input; + } + + let start_limit = input.len().saturating_sub(max_bytes); + let mut start = input.len(); + for (idx, _) in input.char_indices().rev() { + if idx < start_limit { + break; + } + start = idx; + } + + &input[start..] +} diff --git a/codex-rs/core/src/memories/types.rs b/codex-rs/core/src/memories/types.rs index 7ba66ee81..5054ff541 100644 --- a/codex-rs/core/src/memories/types.rs +++ b/codex-rs/core/src/memories/types.rs @@ -1,26 +1,11 @@ -use codex_protocol::ThreadId; use serde::Deserialize; -use std::path::PathBuf; - -/// A rollout selected for stage-1 memory extraction during startup. -#[derive(Debug, Clone)] -pub(crate) struct RolloutCandidate { - /// Source thread identifier for this rollout. - pub(crate) thread_id: ThreadId, - /// Absolute path to the rollout file to summarize. - pub(crate) rollout_path: PathBuf, - /// Thread working directory used for per-project memory bucketing. - pub(crate) cwd: PathBuf, - /// Last observed thread update timestamp (RFC3339), if available. - pub(crate) updated_at: Option, -} /// Parsed stage-1 model output payload. #[derive(Debug, Clone, Deserialize)] -pub(crate) struct StageOneOutput { +pub(super) struct StageOneOutput { /// Detailed markdown raw memory for a single rollout. #[serde(rename = "rawMemory", alias = "traceMemory")] - pub(crate) raw_memory: String, + pub(super) raw_memory: String, /// Compact summary line used for routing and indexing. - pub(crate) summary: String, + pub(super) summary: String, } diff --git a/codex-rs/core/src/state_db.rs b/codex-rs/core/src/state_db.rs index f4d760cc2..06bd99724 100644 --- a/codex-rs/core/src/state_db.rs +++ b/codex-rs/core/src/state_db.rs @@ -315,64 +315,6 @@ pub async fn persist_dynamic_tools( } } -/// Get memory summaries for a thread id using SQLite. -pub async fn get_thread_memory( - context: Option<&codex_state::StateRuntime>, - thread_id: ThreadId, - stage: &str, -) -> Option { - let ctx = context?; - match ctx.get_thread_memory(thread_id).await { - Ok(memory) => memory, - Err(err) => { - warn!("state db get_thread_memory failed during {stage}: {err}"); - None - } - } -} - -/// Upsert memory summaries for a thread id using SQLite. -pub async fn upsert_thread_memory( - context: Option<&codex_state::StateRuntime>, - thread_id: ThreadId, - raw_memory: &str, - memory_summary: &str, - stage: &str, -) -> Option { - let ctx = context?; - match ctx - .upsert_thread_memory(thread_id, raw_memory, memory_summary) - .await - { - Ok(memory) => Some(memory), - Err(err) => { - warn!("state db upsert_thread_memory failed during {stage}: {err}"); - None - } - } -} - -/// Get the last N memories corresponding to a cwd using an exact path match. -pub async fn get_last_n_thread_memories_for_cwd( - context: Option<&codex_state::StateRuntime>, - cwd: &Path, - n: usize, - stage: &str, -) -> Option> { - let ctx = context?; - let normalized_cwd = normalize_cwd_for_state_db(cwd); - match ctx - .get_last_n_thread_memories_for_cwd(&normalized_cwd, n) - .await - { - Ok(memories) => Some(memories), - Err(err) => { - warn!("state db get_last_n_thread_memories_for_cwd failed during {stage}: {err}"); - None - } - } -} - /// Reconcile rollout items into SQLite, falling back to scanning the rollout file. pub async fn reconcile_rollout( context: Option<&codex_state::StateRuntime>, diff --git a/codex-rs/state/migrations/0011_generic_jobs_and_stage1_outputs.sql b/codex-rs/state/migrations/0011_generic_jobs_and_stage1_outputs.sql new file mode 100644 index 000000000..9095c8b8e --- /dev/null +++ b/codex-rs/state/migrations/0011_generic_jobs_and_stage1_outputs.sql @@ -0,0 +1,37 @@ +DROP TABLE IF EXISTS thread_memory; +DROP TABLE IF EXISTS memory_phase1_jobs; +DROP TABLE IF EXISTS memory_scope_dirty; +DROP TABLE IF EXISTS memory_phase2_jobs; +DROP TABLE IF EXISTS memory_consolidation_locks; + +CREATE TABLE IF NOT EXISTS stage1_outputs ( + thread_id TEXT PRIMARY KEY, + source_updated_at INTEGER NOT NULL, + raw_memory TEXT NOT NULL, + summary TEXT NOT NULL, + generated_at INTEGER NOT NULL, + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_stage1_outputs_source_updated_at + ON stage1_outputs(source_updated_at DESC, thread_id DESC); + +CREATE TABLE IF NOT EXISTS jobs ( + kind TEXT NOT NULL, + job_key TEXT NOT NULL, + status TEXT NOT NULL, + worker_id TEXT, + ownership_token TEXT, + started_at INTEGER, + finished_at INTEGER, + lease_until INTEGER, + retry_at INTEGER, + retry_remaining INTEGER NOT NULL, + last_error TEXT, + input_watermark INTEGER, + last_success_watermark INTEGER, + PRIMARY KEY (kind, job_key) +); + +CREATE INDEX IF NOT EXISTS idx_jobs_kind_status_retry_lease + ON jobs(kind, status, retry_at, lease_until); diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index 0219dbe49..a521d9997 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -27,15 +27,16 @@ pub use model::BackfillStats; pub use model::BackfillStatus; pub use model::ExtractionOutcome; pub use model::SortKey; -pub use model::ThreadMemory; +pub use model::Stage1Output; pub use model::ThreadMetadata; pub use model::ThreadMetadataBuilder; pub use model::ThreadsPage; -pub use runtime::DirtyMemoryScope; -pub use runtime::Phase1JobClaimOutcome; +pub use runtime::PendingScopeConsolidation; pub use runtime::Phase2JobClaimOutcome; pub use runtime::STATE_DB_FILENAME; pub use runtime::STATE_DB_VERSION; +pub use runtime::Stage1JobClaim; +pub use runtime::Stage1JobClaimOutcome; pub use runtime::state_db_filename; pub use runtime::state_db_path; diff --git a/codex-rs/state/src/model/mod.rs b/codex-rs/state/src/model/mod.rs index 6bec8875d..57fc80596 100644 --- a/codex-rs/state/src/model/mod.rs +++ b/codex-rs/state/src/model/mod.rs @@ -1,6 +1,6 @@ mod backfill_state; mod log; -mod thread_memory; +mod stage1_output; mod thread_metadata; pub use backfill_state::BackfillState; @@ -8,7 +8,7 @@ pub use backfill_state::BackfillStatus; pub use log::LogEntry; pub use log::LogQuery; pub use log::LogRow; -pub use thread_memory::ThreadMemory; +pub use stage1_output::Stage1Output; pub use thread_metadata::Anchor; pub use thread_metadata::BackfillStats; pub use thread_metadata::ExtractionOutcome; @@ -17,7 +17,7 @@ pub use thread_metadata::ThreadMetadata; pub use thread_metadata::ThreadMetadataBuilder; pub use thread_metadata::ThreadsPage; -pub(crate) use thread_memory::ThreadMemoryRow; +pub(crate) use stage1_output::Stage1OutputRow; pub(crate) use thread_metadata::ThreadRow; pub(crate) use thread_metadata::anchor_from_item; pub(crate) use thread_metadata::datetime_to_epoch_seconds; diff --git a/codex-rs/state/src/model/stage1_output.rs b/codex-rs/state/src/model/stage1_output.rs new file mode 100644 index 000000000..e69c6db62 --- /dev/null +++ b/codex-rs/state/src/model/stage1_output.rs @@ -0,0 +1,56 @@ +use anyhow::Result; +use chrono::DateTime; +use chrono::Utc; +use codex_protocol::ThreadId; +use sqlx::Row; +use sqlx::sqlite::SqliteRow; + +/// Stored stage-1 memory extraction output for a single thread. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Stage1Output { + pub thread_id: ThreadId, + pub source_updated_at: DateTime, + pub raw_memory: String, + pub summary: String, + pub generated_at: DateTime, +} + +#[derive(Debug)] +pub(crate) struct Stage1OutputRow { + thread_id: String, + source_updated_at: i64, + raw_memory: String, + summary: String, + generated_at: i64, +} + +impl Stage1OutputRow { + pub(crate) fn try_from_row(row: &SqliteRow) -> Result { + Ok(Self { + thread_id: row.try_get("thread_id")?, + source_updated_at: row.try_get("source_updated_at")?, + raw_memory: row.try_get("raw_memory")?, + summary: row.try_get("summary")?, + generated_at: row.try_get("generated_at")?, + }) + } +} + +impl TryFrom for Stage1Output { + type Error = anyhow::Error; + + fn try_from(row: Stage1OutputRow) -> std::result::Result { + Ok(Self { + thread_id: ThreadId::try_from(row.thread_id)?, + source_updated_at: epoch_seconds_to_datetime(row.source_updated_at)?, + raw_memory: row.raw_memory, + summary: row.summary, + generated_at: epoch_seconds_to_datetime(row.generated_at)?, + }) + } +} + +fn epoch_seconds_to_datetime(secs: i64) -> Result> { + DateTime::::from_timestamp(secs, 0) + .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) +} diff --git a/codex-rs/state/src/model/thread_memory.rs b/codex-rs/state/src/model/thread_memory.rs deleted file mode 100644 index b0b29ce7e..000000000 --- a/codex-rs/state/src/model/thread_memory.rs +++ /dev/null @@ -1,82 +0,0 @@ -use anyhow::Result; -use chrono::DateTime; -use chrono::Utc; -use codex_protocol::ThreadId; -use sqlx::Row; -use sqlx::sqlite::SqliteRow; - -/// Stored memory summaries for a single thread. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ThreadMemory { - pub thread_id: ThreadId, - pub scope_kind: String, - pub scope_key: String, - pub raw_memory: String, - pub memory_summary: String, - pub updated_at: DateTime, - pub last_used_at: Option>, - pub used_count: i64, - pub invalidated_at: Option>, - pub invalid_reason: Option, -} - -#[derive(Debug)] -pub(crate) struct ThreadMemoryRow { - thread_id: String, - scope_kind: String, - scope_key: String, - raw_memory: String, - memory_summary: String, - updated_at: i64, - last_used_at: Option, - used_count: i64, - invalidated_at: Option, - invalid_reason: Option, -} - -impl ThreadMemoryRow { - pub(crate) fn try_from_row(row: &SqliteRow) -> Result { - Ok(Self { - thread_id: row.try_get("thread_id")?, - scope_kind: row.try_get("scope_kind")?, - scope_key: row.try_get("scope_key")?, - raw_memory: row.try_get("raw_memory")?, - memory_summary: row.try_get("memory_summary")?, - updated_at: row.try_get("updated_at")?, - last_used_at: row.try_get("last_used_at")?, - used_count: row.try_get("used_count")?, - invalidated_at: row.try_get("invalidated_at")?, - invalid_reason: row.try_get("invalid_reason")?, - }) - } -} - -impl TryFrom for ThreadMemory { - type Error = anyhow::Error; - - fn try_from(row: ThreadMemoryRow) -> std::result::Result { - Ok(Self { - thread_id: ThreadId::try_from(row.thread_id)?, - scope_kind: row.scope_kind, - scope_key: row.scope_key, - raw_memory: row.raw_memory, - memory_summary: row.memory_summary, - updated_at: epoch_seconds_to_datetime(row.updated_at)?, - last_used_at: row - .last_used_at - .map(epoch_seconds_to_datetime) - .transpose()?, - used_count: row.used_count, - invalidated_at: row - .invalidated_at - .map(epoch_seconds_to_datetime) - .transpose()?, - invalid_reason: row.invalid_reason, - }) - } -} - -fn epoch_seconds_to_datetime(secs: i64) -> Result> { - DateTime::::from_timestamp(secs, 0) - .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) -} diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index b3c763c7c..e29b5aafb 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -3,13 +3,11 @@ use crate::LogEntry; use crate::LogQuery; use crate::LogRow; use crate::SortKey; -use crate::ThreadMemory; use crate::ThreadMetadata; use crate::ThreadMetadataBuilder; use crate::ThreadsPage; use crate::apply_rollout_item; use crate::migrations::MIGRATOR; -use crate::model::ThreadMemoryRow; use crate::model::ThreadRow; use crate::model::anchor_from_item; use crate::model::datetime_to_epoch_seconds; @@ -42,9 +40,14 @@ pub const STATE_DB_FILENAME: &str = "state"; pub const STATE_DB_VERSION: u32 = 4; const MEMORY_SCOPE_KIND_CWD: &str = "cwd"; +const MEMORY_SCOPE_KIND_USER: &str = "user"; +const MEMORY_SCOPE_KEY_USER: &str = "user"; const METRIC_DB_INIT: &str = "codex.db.init"; +mod memory; +// Memory-specific CRUD and phase job lifecycle methods live in `runtime/memory.rs`. + #[derive(Clone)] pub struct StateRuntime { codex_home: PathBuf, @@ -52,24 +55,49 @@ pub struct StateRuntime { pool: Arc, } +/// Result of trying to claim a stage-1 memory extraction job. #[derive(Debug, Clone, PartialEq, Eq)] -pub enum Phase1JobClaimOutcome { +pub enum Stage1JobClaimOutcome { + /// The caller owns the job and should continue with extraction. Claimed { ownership_token: String }, - SkippedTerminalFailure, + /// Existing output is already newer than or equal to the source rollout. SkippedUpToDate, + /// Another worker currently owns a fresh lease for this job. SkippedRunning, + /// The job is in backoff and should not be retried yet. + SkippedRetryBackoff, + /// The job has exhausted retries and should not be retried automatically. + SkippedRetryExhausted, } +/// Claimed stage-1 job with thread metadata. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct DirtyMemoryScope { +pub struct Stage1JobClaim { + pub thread: ThreadMetadata, + pub ownership_token: String, +} + +/// Scope row used to queue phase-2 consolidation work. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingScopeConsolidation { + /// Scope family (`cwd` or `user`). pub scope_kind: String, + /// Scope identifier keyed by `scope_kind`. pub scope_key: String, } +/// Result of trying to claim a phase-2 consolidation job. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Phase2JobClaimOutcome { - Claimed { ownership_token: String }, + /// The caller owns the scope and should spawn consolidation. + Claimed { + ownership_token: String, + /// Snapshot of `input_watermark` at claim time. + input_watermark: i64, + }, + /// The scope is not pending consolidation (or is already up to date). SkippedNotDirty, + /// Another worker currently owns a fresh lease for this scope. SkippedRunning, } @@ -254,38 +282,6 @@ ORDER BY position ASC Ok(Some(tools)) } - /// Get memory summaries for a thread, if present. - pub async fn get_thread_memory( - &self, - thread_id: ThreadId, - ) -> anyhow::Result> { - let row = sqlx::query( - r#" -SELECT - thread_id, - scope_kind, - scope_key, - raw_memory, - memory_summary, - updated_at, - last_used_at, - used_count, - invalidated_at, - invalid_reason -FROM thread_memory -WHERE thread_id = ? -ORDER BY updated_at DESC, scope_kind DESC, scope_key DESC -LIMIT 1 - "#, - ) - .bind(thread_id.to_string()) - .fetch_optional(self.pool.as_ref()) - .await?; - - row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) - .transpose() - } - /// Find a rollout path by thread id using the underlying database. pub async fn find_rollout_path_by_id( &self, @@ -542,855 +538,6 @@ ON CONFLICT(id) DO UPDATE SET Ok(()) } - /// Insert or update memory summaries for a thread in the cwd scope. - /// - /// This method always advances `updated_at`, even if summaries are unchanged. - pub async fn upsert_thread_memory( - &self, - thread_id: ThreadId, - raw_memory: &str, - memory_summary: &str, - ) -> anyhow::Result { - let Some(thread) = self.get_thread(thread_id).await? else { - return Err(anyhow::anyhow!("thread not found: {thread_id}")); - }; - let scope_key = thread.cwd.display().to_string(); - self.upsert_thread_memory_for_scope( - thread_id, - MEMORY_SCOPE_KIND_CWD, - scope_key.as_str(), - raw_memory, - memory_summary, - ) - .await - } - - /// Insert or update memory summaries for a thread in an explicit scope. - pub async fn upsert_thread_memory_for_scope( - &self, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - raw_memory: &str, - memory_summary: &str, - ) -> anyhow::Result { - if self.get_thread(thread_id).await?.is_none() { - return Err(anyhow::anyhow!("thread not found: {thread_id}")); - } - - let updated_at = Utc::now().timestamp(); - sqlx::query( - r#" -INSERT INTO thread_memory ( - thread_id, - scope_kind, - scope_key, - raw_memory, - memory_summary, - updated_at -) VALUES (?, ?, ?, ?, ?, ?) -ON CONFLICT(thread_id, scope_kind, scope_key) DO UPDATE SET - raw_memory = excluded.raw_memory, - memory_summary = excluded.memory_summary, - updated_at = CASE - WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1 - ELSE excluded.updated_at - END - "#, - ) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .bind(raw_memory) - .bind(memory_summary) - .bind(updated_at) - .execute(self.pool.as_ref()) - .await?; - - let row = sqlx::query( - r#" -SELECT - thread_id, - scope_kind, - scope_key, - raw_memory, - memory_summary, - updated_at, - last_used_at, - used_count, - invalidated_at, - invalid_reason -FROM thread_memory -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - "#, - ) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .fetch_optional(self.pool.as_ref()) - .await?; - - row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) - .transpose()? - .ok_or_else(|| anyhow::anyhow!("failed to load upserted thread memory: {thread_id}")) - } - - /// Insert or update memory summaries for a thread/scope only if the caller - /// still owns the corresponding phase-1 running job. - pub async fn upsert_thread_memory_for_scope_if_phase1_owner( - &self, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - raw_memory: &str, - memory_summary: &str, - ) -> anyhow::Result> { - if self.get_thread(thread_id).await?.is_none() { - return Err(anyhow::anyhow!("thread not found: {thread_id}")); - } - - let updated_at = Utc::now().timestamp(); - let rows_affected = sqlx::query( - r#" -INSERT INTO thread_memory ( - thread_id, - scope_kind, - scope_key, - raw_memory, - memory_summary, - updated_at -) -SELECT ?, ?, ?, ?, ?, ? -WHERE EXISTS ( - SELECT 1 - FROM memory_phase1_jobs - WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? -) -ON CONFLICT(thread_id, scope_kind, scope_key) DO UPDATE SET - raw_memory = excluded.raw_memory, - memory_summary = excluded.memory_summary, - updated_at = CASE - WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1 - ELSE excluded.updated_at - END - "#, - ) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .bind(raw_memory) - .bind(memory_summary) - .bind(updated_at) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(self.pool.as_ref()) - .await? - .rows_affected(); - - if rows_affected == 0 { - return Ok(None); - } - - let row = sqlx::query( - r#" -SELECT - thread_id, - scope_kind, - scope_key, - raw_memory, - memory_summary, - updated_at, - last_used_at, - used_count, - invalidated_at, - invalid_reason -FROM thread_memory -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - "#, - ) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .fetch_optional(self.pool.as_ref()) - .await?; - - row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) - .transpose() - } - - /// Get the last `n` memories for threads with an exact cwd match. - pub async fn get_last_n_thread_memories_for_cwd( - &self, - cwd: &Path, - n: usize, - ) -> anyhow::Result> { - self.get_last_n_thread_memories_for_scope( - MEMORY_SCOPE_KIND_CWD, - &cwd.display().to_string(), - n, - ) - .await - } - - /// Get the last `n` memories for a specific memory scope. - pub async fn get_last_n_thread_memories_for_scope( - &self, - scope_kind: &str, - scope_key: &str, - n: usize, - ) -> anyhow::Result> { - if n == 0 { - return Ok(Vec::new()); - } - - let rows = sqlx::query( - r#" -SELECT - thread_id, - scope_kind, - scope_key, - raw_memory, - memory_summary, - updated_at, - last_used_at, - used_count, - invalidated_at, - invalid_reason -FROM thread_memory -WHERE scope_kind = ? AND scope_key = ? AND invalidated_at IS NULL -ORDER BY updated_at DESC, thread_id DESC -LIMIT ? - "#, - ) - .bind(scope_kind) - .bind(scope_key) - .bind(n as i64) - .fetch_all(self.pool.as_ref()) - .await?; - - rows.into_iter() - .map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) - .collect() - } - - /// Try to claim a phase-1 memory extraction job for `(thread, scope)`. - pub async fn try_claim_phase1_job( - &self, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - owner_session_id: ThreadId, - source_updated_at: i64, - lease_seconds: i64, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let stale_cutoff = now.saturating_sub(lease_seconds.max(0)); - let ownership_token = Uuid::new_v4().to_string(); - let thread_id = thread_id.to_string(); - let owner_session_id = owner_session_id.to_string(); - - let mut tx = self.pool.begin().await?; - let existing = sqlx::query( - r#" -SELECT status, source_updated_at, started_at -FROM memory_phase1_jobs -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - "#, - ) - .bind(thread_id.as_str()) - .bind(scope_kind) - .bind(scope_key) - .fetch_optional(&mut *tx) - .await?; - - let Some(existing) = existing else { - sqlx::query( - r#" -INSERT INTO memory_phase1_jobs ( - thread_id, - scope_kind, - scope_key, - status, - owner_session_id, - started_at, - finished_at, - failure_reason, - source_updated_at, - raw_memory_path, - summary_hash, - ownership_token -) VALUES (?, ?, ?, 'running', ?, ?, NULL, NULL, ?, NULL, NULL, ?) - "#, - ) - .bind(thread_id.as_str()) - .bind(scope_kind) - .bind(scope_key) - .bind(owner_session_id.as_str()) - .bind(now) - .bind(source_updated_at) - .bind(ownership_token.as_str()) - .execute(&mut *tx) - .await?; - tx.commit().await?; - return Ok(Phase1JobClaimOutcome::Claimed { ownership_token }); - }; - - let status: String = existing.try_get("status")?; - let existing_source_updated_at: i64 = existing.try_get("source_updated_at")?; - let existing_started_at: Option = existing.try_get("started_at")?; - if status == "failed" { - tx.commit().await?; - return Ok(Phase1JobClaimOutcome::SkippedTerminalFailure); - } - if status == "succeeded" && existing_source_updated_at >= source_updated_at { - tx.commit().await?; - return Ok(Phase1JobClaimOutcome::SkippedUpToDate); - } - if status == "running" && existing_started_at.is_some_and(|started| started > stale_cutoff) - { - tx.commit().await?; - return Ok(Phase1JobClaimOutcome::SkippedRunning); - } - - let rows_affected = if let Some(existing_started_at) = existing_started_at { - sqlx::query( - r#" -UPDATE memory_phase1_jobs -SET - status = 'running', - owner_session_id = ?, - started_at = ?, - finished_at = NULL, - failure_reason = NULL, - source_updated_at = ?, - raw_memory_path = NULL, - summary_hash = NULL, - ownership_token = ? -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - AND status = ? AND source_updated_at = ? AND started_at = ? - "#, - ) - .bind(owner_session_id.as_str()) - .bind(now) - .bind(source_updated_at) - .bind(ownership_token.as_str()) - .bind(thread_id.as_str()) - .bind(scope_kind) - .bind(scope_key) - .bind(status.as_str()) - .bind(existing_source_updated_at) - .bind(existing_started_at) - .execute(&mut *tx) - .await? - .rows_affected() - } else { - sqlx::query( - r#" -UPDATE memory_phase1_jobs -SET - status = 'running', - owner_session_id = ?, - started_at = ?, - finished_at = NULL, - failure_reason = NULL, - source_updated_at = ?, - raw_memory_path = NULL, - summary_hash = NULL, - ownership_token = ? -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - AND status = ? AND source_updated_at = ? AND started_at IS NULL - "#, - ) - .bind(owner_session_id.as_str()) - .bind(now) - .bind(source_updated_at) - .bind(ownership_token.as_str()) - .bind(thread_id.as_str()) - .bind(scope_kind) - .bind(scope_key) - .bind(status.as_str()) - .bind(existing_source_updated_at) - .execute(&mut *tx) - .await? - .rows_affected() - }; - - tx.commit().await?; - if rows_affected == 0 { - Ok(Phase1JobClaimOutcome::SkippedRunning) - } else { - Ok(Phase1JobClaimOutcome::Claimed { ownership_token }) - } - } - - /// Finalize a claimed phase-1 job as succeeded. - pub async fn mark_phase1_job_succeeded( - &self, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - raw_memory_path: &str, - summary_hash: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let rows_affected = sqlx::query( - r#" -UPDATE memory_phase1_jobs -SET - status = 'succeeded', - finished_at = ?, - failure_reason = NULL, - raw_memory_path = ?, - summary_hash = ? -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? - "#, - ) - .bind(now) - .bind(raw_memory_path) - .bind(summary_hash) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(self.pool.as_ref()) - .await? - .rows_affected(); - Ok(rows_affected > 0) - } - - /// Finalize a claimed phase-1 job as failed. - pub async fn mark_phase1_job_failed( - &self, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - failure_reason: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let rows_affected = sqlx::query( - r#" -UPDATE memory_phase1_jobs -SET - status = 'failed', - finished_at = ?, - failure_reason = ? -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? - "#, - ) - .bind(now) - .bind(failure_reason) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(self.pool.as_ref()) - .await? - .rows_affected(); - Ok(rows_affected > 0) - } - - /// Refresh lease timestamp for a claimed phase-1 job. - /// - /// Returns `true` only when the current owner token still matches. - pub async fn renew_phase1_job_lease( - &self, - thread_id: ThreadId, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let rows_affected = sqlx::query( - r#" -UPDATE memory_phase1_jobs -SET started_at = ? -WHERE thread_id = ? AND scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? - "#, - ) - .bind(now) - .bind(thread_id.to_string()) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(self.pool.as_ref()) - .await? - .rows_affected(); - Ok(rows_affected > 0) - } - - /// Mark a memory scope as dirty/clean for phase-2 consolidation scheduling. - pub async fn mark_memory_scope_dirty( - &self, - scope_kind: &str, - scope_key: &str, - dirty: bool, - ) -> anyhow::Result<()> { - let now = Utc::now().timestamp(); - sqlx::query( - r#" -INSERT INTO memory_scope_dirty (scope_kind, scope_key, dirty, updated_at) -VALUES (?, ?, ?, ?) -ON CONFLICT(scope_kind, scope_key) DO UPDATE SET - dirty = excluded.dirty, - updated_at = excluded.updated_at - "#, - ) - .bind(scope_kind) - .bind(scope_key) - .bind(dirty) - .bind(now) - .execute(self.pool.as_ref()) - .await?; - Ok(()) - } - - /// List scopes that currently require phase-2 consolidation. - pub async fn list_dirty_memory_scopes( - &self, - limit: usize, - ) -> anyhow::Result> { - if limit == 0 { - return Ok(Vec::new()); - } - - let rows = sqlx::query( - r#" -SELECT scope_kind, scope_key -FROM memory_scope_dirty -WHERE dirty = 1 -ORDER BY updated_at DESC, scope_kind ASC, scope_key ASC -LIMIT ? - "#, - ) - .bind(limit as i64) - .fetch_all(self.pool.as_ref()) - .await?; - - rows.into_iter() - .map(|row| { - Ok(DirtyMemoryScope { - scope_kind: row.try_get("scope_kind")?, - scope_key: row.try_get("scope_key")?, - }) - }) - .collect() - } - - /// Try to claim a phase-2 consolidation job for `(scope_kind, scope_key)`. - pub async fn try_claim_phase2_job( - &self, - scope_kind: &str, - scope_key: &str, - owner_session_id: ThreadId, - lease_seconds: i64, - ) -> anyhow::Result { - const CAS_RETRY_LIMIT: usize = 3; - - for _ in 0..CAS_RETRY_LIMIT { - let now = Utc::now().timestamp(); - let stale_cutoff = now.saturating_sub(lease_seconds.max(0)); - let ownership_token = Uuid::new_v4().to_string(); - let owner_session_id = owner_session_id.to_string(); - - let mut tx = self.pool.begin().await?; - - let dirty_row = sqlx::query( - r#" -SELECT dirty -FROM memory_scope_dirty -WHERE scope_kind = ? AND scope_key = ? - "#, - ) - .bind(scope_kind) - .bind(scope_key) - .fetch_optional(&mut *tx) - .await?; - let Some(dirty_row) = dirty_row else { - tx.commit().await?; - return Ok(Phase2JobClaimOutcome::SkippedNotDirty); - }; - let dirty: bool = dirty_row.try_get("dirty")?; - if !dirty { - tx.commit().await?; - return Ok(Phase2JobClaimOutcome::SkippedNotDirty); - } - - let existing = sqlx::query( - r#" -SELECT status, last_heartbeat_at, attempt -FROM memory_phase2_jobs -WHERE scope_kind = ? AND scope_key = ? - "#, - ) - .bind(scope_kind) - .bind(scope_key) - .fetch_optional(&mut *tx) - .await?; - - let Some(existing) = existing else { - sqlx::query( - r#" -INSERT INTO memory_phase2_jobs ( - scope_kind, - scope_key, - status, - owner_session_id, - agent_thread_id, - started_at, - last_heartbeat_at, - finished_at, - attempt, - failure_reason, - ownership_token -) VALUES (?, ?, 'running', ?, NULL, ?, ?, NULL, 1, NULL, ?) - "#, - ) - .bind(scope_kind) - .bind(scope_key) - .bind(owner_session_id.as_str()) - .bind(now) - .bind(now) - .bind(ownership_token.as_str()) - .execute(&mut *tx) - .await?; - tx.commit().await?; - return Ok(Phase2JobClaimOutcome::Claimed { ownership_token }); - }; - - let status: String = existing.try_get("status")?; - let existing_last_heartbeat_at: Option = existing.try_get("last_heartbeat_at")?; - let existing_attempt: i64 = existing.try_get("attempt")?; - if status == "running" - && existing_last_heartbeat_at - .is_some_and(|last_heartbeat_at| last_heartbeat_at > stale_cutoff) - { - tx.commit().await?; - return Ok(Phase2JobClaimOutcome::SkippedRunning); - } - - let new_attempt = existing_attempt.saturating_add(1); - let rows_affected = if let Some(existing_last_heartbeat_at) = existing_last_heartbeat_at - { - sqlx::query( - r#" -UPDATE memory_phase2_jobs -SET - status = 'running', - owner_session_id = ?, - agent_thread_id = NULL, - started_at = ?, - last_heartbeat_at = ?, - finished_at = NULL, - attempt = ?, - failure_reason = NULL, - ownership_token = ? -WHERE scope_kind = ? AND scope_key = ? - AND status = ? AND attempt = ? AND last_heartbeat_at = ? - "#, - ) - .bind(owner_session_id.as_str()) - .bind(now) - .bind(now) - .bind(new_attempt) - .bind(ownership_token.as_str()) - .bind(scope_kind) - .bind(scope_key) - .bind(status.as_str()) - .bind(existing_attempt) - .bind(existing_last_heartbeat_at) - .execute(&mut *tx) - .await? - .rows_affected() - } else { - sqlx::query( - r#" -UPDATE memory_phase2_jobs -SET - status = 'running', - owner_session_id = ?, - agent_thread_id = NULL, - started_at = ?, - last_heartbeat_at = ?, - finished_at = NULL, - attempt = ?, - failure_reason = NULL, - ownership_token = ? -WHERE scope_kind = ? AND scope_key = ? - AND status = ? AND attempt = ? AND last_heartbeat_at IS NULL - "#, - ) - .bind(owner_session_id.as_str()) - .bind(now) - .bind(now) - .bind(new_attempt) - .bind(ownership_token.as_str()) - .bind(scope_kind) - .bind(scope_key) - .bind(status.as_str()) - .bind(existing_attempt) - .execute(&mut *tx) - .await? - .rows_affected() - }; - - if rows_affected == 0 { - tx.rollback().await?; - continue; - } - - tx.commit().await?; - return Ok(Phase2JobClaimOutcome::Claimed { ownership_token }); - } - - Ok(Phase2JobClaimOutcome::SkippedRunning) - } - - /// Persist the spawned phase-2 agent id for an owned running job. - pub async fn set_phase2_job_agent_thread_id( - &self, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - agent_thread_id: ThreadId, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let rows_affected = sqlx::query( - r#" -UPDATE memory_phase2_jobs -SET - agent_thread_id = ?, - last_heartbeat_at = ? -WHERE scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? - "#, - ) - .bind(agent_thread_id.to_string()) - .bind(now) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(self.pool.as_ref()) - .await? - .rows_affected(); - Ok(rows_affected > 0) - } - - /// Refresh heartbeat timestamp for an owned running phase-2 job. - pub async fn heartbeat_phase2_job( - &self, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let rows_affected = sqlx::query( - r#" -UPDATE memory_phase2_jobs -SET last_heartbeat_at = ? -WHERE scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? - "#, - ) - .bind(now) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(self.pool.as_ref()) - .await? - .rows_affected(); - Ok(rows_affected > 0) - } - - /// Finalize a claimed phase-2 job as succeeded and clear dirty state. - pub async fn mark_phase2_job_succeeded( - &self, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let mut tx = self.pool.begin().await?; - let rows_affected = sqlx::query( - r#" -UPDATE memory_phase2_jobs -SET - status = 'succeeded', - finished_at = ?, - failure_reason = NULL -WHERE scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? - "#, - ) - .bind(now) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(&mut *tx) - .await? - .rows_affected(); - - if rows_affected == 0 { - tx.commit().await?; - return Ok(false); - } - - sqlx::query( - r#" -UPDATE memory_scope_dirty -SET dirty = 0, updated_at = ? -WHERE scope_kind = ? AND scope_key = ? - "#, - ) - .bind(now) - .bind(scope_kind) - .bind(scope_key) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - Ok(true) - } - - /// Finalize a claimed phase-2 job as failed, leaving dirty scope set. - pub async fn mark_phase2_job_failed( - &self, - scope_kind: &str, - scope_key: &str, - ownership_token: &str, - failure_reason: &str, - ) -> anyhow::Result { - let now = Utc::now().timestamp(); - let rows_affected = sqlx::query( - r#" -UPDATE memory_phase2_jobs -SET - status = 'failed', - finished_at = ?, - failure_reason = ? -WHERE scope_kind = ? AND scope_key = ? - AND status = 'running' AND ownership_token = ? - "#, - ) - .bind(now) - .bind(failure_reason) - .bind(scope_kind) - .bind(scope_key) - .bind(ownership_token) - .execute(self.pool.as_ref()) - .await? - .rows_affected(); - Ok(rows_affected > 0) - } - /// Persist dynamic tools for a thread if none have been stored yet. /// /// Dynamic tools are defined at thread start and should not change afterward. @@ -1770,14 +917,16 @@ fn push_thread_order_and_limit( #[cfg(test)] mod tests { - use super::Phase1JobClaimOutcome; + use super::PendingScopeConsolidation; use super::Phase2JobClaimOutcome; use super::STATE_DB_FILENAME; use super::STATE_DB_VERSION; + use super::Stage1JobClaimOutcome; use super::StateRuntime; use super::ThreadMetadata; use super::state_db_filename; use chrono::DateTime; + use chrono::Duration; use chrono::Utc; use codex_protocol::ThreadId; use codex_protocol::protocol::AskForApproval; @@ -1929,7 +1078,7 @@ mod tests { } #[tokio::test] - async fn upsert_and_get_thread_memory() { + async fn stage1_claim_skips_when_up_to_date() { let codex_home = unique_temp_dir(); let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) .await @@ -1942,458 +1091,149 @@ mod tests { .await .expect("upsert thread"); - assert_eq!( + let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner_a, 100, 3600) + .await + .expect("claim stage1 job"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected claim outcome: {other:?}"), + }; + + assert!( runtime - .get_thread_memory(thread_id) + .mark_stage1_job_succeeded(thread_id, ownership_token.as_str(), 100, "raw", "sum") .await - .expect("get memory before insert"), - None + .expect("mark stage1 succeeded"), + "stage1 success should finalize for current token" ); - let inserted = runtime - .upsert_thread_memory(thread_id, "trace one", "memory one") + let up_to_date = runtime + .try_claim_stage1_job(thread_id, owner_b, 100, 3600) .await - .expect("upsert memory"); - assert_eq!(inserted.thread_id, thread_id); - assert_eq!(inserted.raw_memory, "trace one"); - assert_eq!(inserted.memory_summary, "memory one"); + .expect("claim stage1 up-to-date"); + assert_eq!(up_to_date, Stage1JobClaimOutcome::SkippedUpToDate); - let updated = runtime - .upsert_thread_memory(thread_id, "trace two", "memory two") + let needs_rerun = runtime + .try_claim_stage1_job(thread_id, owner_b, 101, 3600) .await - .expect("update memory"); - assert_eq!(updated.thread_id, thread_id); - assert_eq!(updated.raw_memory, "trace two"); - assert_eq!(updated.memory_summary, "memory two"); + .expect("claim stage1 newer source"); assert!( - updated.updated_at >= inserted.updated_at, - "updated_at should not move backward" + matches!(needs_rerun, Stage1JobClaimOutcome::Claimed { .. }), + "newer source_updated_at should be claimable" ); let _ = tokio::fs::remove_dir_all(codex_home).await; } #[tokio::test] - async fn get_last_n_thread_memories_for_cwd_matches_exactly() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let cwd_a = codex_home.join("workspace-a"); - let cwd_b = codex_home.join("workspace-b"); - let t1 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let t2 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let t3 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, t1, cwd_a.clone())) - .await - .expect("upsert thread t1"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, t2, cwd_a.clone())) - .await - .expect("upsert thread t2"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, t3, cwd_b.clone())) - .await - .expect("upsert thread t3"); - - let first = runtime - .upsert_thread_memory(t1, "trace-1", "memory-1") - .await - .expect("upsert t1 memory"); - runtime - .upsert_thread_memory(t2, "trace-2", "memory-2") - .await - .expect("upsert t2 memory"); - runtime - .upsert_thread_memory(t3, "trace-3", "memory-3") - .await - .expect("upsert t3 memory"); - // Ensure deterministic ordering even when updates happen in the same second. - runtime - .upsert_thread_memory(t1, "trace-1b", "memory-1b") - .await - .expect("upsert t1 memory again"); - - let cwd_a_memories = runtime - .get_last_n_thread_memories_for_cwd(cwd_a.as_path(), 2) - .await - .expect("list cwd a memories"); - assert_eq!(cwd_a_memories.len(), 2); - assert_eq!(cwd_a_memories[0].thread_id, t1); - assert_eq!(cwd_a_memories[0].raw_memory, "trace-1b"); - assert_eq!(cwd_a_memories[0].memory_summary, "memory-1b"); - assert_eq!(cwd_a_memories[1].thread_id, t2); - assert!(cwd_a_memories[0].updated_at >= first.updated_at); - - let cwd_b_memories = runtime - .get_last_n_thread_memories_for_cwd(cwd_b.as_path(), 10) - .await - .expect("list cwd b memories"); - assert_eq!(cwd_b_memories.len(), 1); - assert_eq!(cwd_b_memories[0].thread_id, t3); - - let none = runtime - .get_last_n_thread_memories_for_cwd(codex_home.join("missing").as_path(), 10) - .await - .expect("list missing cwd memories"); - assert_eq!(none, Vec::new()); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn upsert_thread_memory_errors_for_unknown_thread() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let unknown_thread_id = - ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let err = runtime - .upsert_thread_memory(unknown_thread_id, "trace", "memory") - .await - .expect_err("unknown thread should fail"); - assert!( - err.to_string().contains("thread not found"), - "error should mention missing thread: {err}" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn get_last_n_thread_memories_for_cwd_zero_returns_empty() { + async fn stage1_running_stale_can_be_stolen_but_fresh_running_is_skipped() { let codex_home = unique_temp_dir(); let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) .await .expect("initialize runtime"); let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let cwd = codex_home.join("workspace"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd.clone())) - .await - .expect("upsert thread"); - runtime - .upsert_thread_memory(thread_id, "trace", "memory") - .await - .expect("upsert memory"); - - let memories = runtime - .get_last_n_thread_memories_for_cwd(cwd.as_path(), 0) - .await - .expect("query memories"); - assert_eq!(memories, Vec::new()); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn get_last_n_thread_memories_for_cwd_does_not_prefix_match() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let cwd_exact = codex_home.join("workspace"); - let cwd_prefix = codex_home.join("workspace-child"); - let t_exact = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let t_prefix = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - t_exact, - cwd_exact.clone(), - )) - .await - .expect("upsert exact thread"); - runtime - .upsert_thread(&test_thread_metadata( - &codex_home, - t_prefix, - cwd_prefix.clone(), - )) - .await - .expect("upsert prefix thread"); - runtime - .upsert_thread_memory(t_exact, "trace-exact", "memory-exact") - .await - .expect("upsert exact memory"); - runtime - .upsert_thread_memory(t_prefix, "trace-prefix", "memory-prefix") - .await - .expect("upsert prefix memory"); - - let exact_only = runtime - .get_last_n_thread_memories_for_cwd(cwd_exact.as_path(), 10) - .await - .expect("query exact cwd"); - assert_eq!(exact_only.len(), 1); - assert_eq!(exact_only[0].thread_id, t_exact); - assert_eq!(exact_only[0].memory_summary, "memory-exact"); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase2_job_claim_requires_dirty_scope() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - let claim_without_dirty = runtime - .try_claim_phase2_job("cwd", "scope", owner, 3600) - .await - .expect("claim without dirty"); - assert_eq!(claim_without_dirty, Phase2JobClaimOutcome::SkippedNotDirty); - - runtime - .mark_memory_scope_dirty("cwd", "scope", false) - .await - .expect("mark dirty false"); - let claim_with_false_dirty = runtime - .try_claim_phase2_job("cwd", "scope", owner, 3600) - .await - .expect("claim with false dirty"); - assert_eq!( - claim_with_false_dirty, - Phase2JobClaimOutcome::SkippedNotDirty - ); - - runtime - .mark_memory_scope_dirty("cwd", "scope", true) - .await - .expect("mark dirty true"); - let claim_with_dirty = runtime - .try_claim_phase2_job("cwd", "scope", owner, 3600) - .await - .expect("claim with dirty"); - assert!( - matches!(claim_with_dirty, Phase2JobClaimOutcome::Claimed { .. }), - "dirty scope should be claimable" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase2_running_job_skips_fresh_claims_and_allows_stale_steal() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - + let cwd = codex_home.join("workspace"); runtime - .mark_memory_scope_dirty("cwd", "scope", true) + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) .await - .expect("mark dirty true"); + .expect("upsert thread"); let claim_a = runtime - .try_claim_phase2_job("cwd", "scope", owner_a, 3600) + .try_claim_stage1_job(thread_id, owner_a, 100, 3600) .await - .expect("claim owner_a"); - let owner_a_token = match claim_a { - Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; + .expect("claim a"); + assert!(matches!(claim_a, Stage1JobClaimOutcome::Claimed { .. })); - let fresh_claim_b = runtime - .try_claim_phase2_job("cwd", "scope", owner_b, 3600) + let claim_b_fresh = runtime + .try_claim_stage1_job(thread_id, owner_b, 100, 3600) .await - .expect("fresh claim owner_b"); - assert_eq!(fresh_claim_b, Phase2JobClaimOutcome::SkippedRunning); + .expect("claim b fresh"); + assert_eq!(claim_b_fresh, Stage1JobClaimOutcome::SkippedRunning); - assert!( - runtime - .heartbeat_phase2_job("cwd", "scope", owner_a_token.as_str()) - .await - .expect("owner_a heartbeat"), - "current owner should heartbeat" - ); - assert!( - !runtime - .heartbeat_phase2_job("cwd", "scope", "wrong-token") - .await - .expect("wrong token heartbeat"), - "wrong token should not heartbeat" - ); - - let stale_claim_b = runtime - .try_claim_phase2_job("cwd", "scope", owner_b, 0) + sqlx::query("UPDATE jobs SET lease_until = 0 WHERE kind = 'memory_stage1' AND job_key = ?") + .bind(thread_id.to_string()) + .execute(runtime.pool.as_ref()) .await - .expect("stale claim owner_b"); - let owner_b_token = match stale_claim_b { - Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected stale claim outcome: {other:?}"), - }; + .expect("force stale lease"); - assert!( - !runtime - .heartbeat_phase2_job("cwd", "scope", owner_a_token.as_str()) - .await - .expect("stale owner heartbeat"), - "stale owner should lose heartbeat ownership" - ); - assert!( - runtime - .heartbeat_phase2_job("cwd", "scope", owner_b_token.as_str()) - .await - .expect("new owner heartbeat"), - "new owner should heartbeat" - ); + let claim_b_stale = runtime + .try_claim_stage1_job(thread_id, owner_b, 100, 3600) + .await + .expect("claim b stale"); + assert!(matches!( + claim_b_stale, + Stage1JobClaimOutcome::Claimed { .. } + )); let _ = tokio::fs::remove_dir_all(codex_home).await; } #[tokio::test] - async fn phase2_success_requires_owner_and_clears_dirty_scope() { + async fn claim_stage1_jobs_filters_by_age_and_current_thread() { let codex_home = unique_temp_dir(); let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) .await .expect("initialize runtime"); - let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + let now = Utc::now(); + let recent_at = now - Duration::seconds(10); + let old_at = now - Duration::days(31); + + let current_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("current thread id"); + let recent_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("recent thread id"); + let old_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("old thread id"); + + let mut current = + test_thread_metadata(&codex_home, current_thread_id, codex_home.join("current")); + current.created_at = now; + current.updated_at = now; runtime - .mark_memory_scope_dirty("cwd", "scope", true) + .upsert_thread(¤t) .await - .expect("mark dirty true"); - let claim = runtime - .try_claim_phase2_job("cwd", "scope", owner, 3600) - .await - .expect("claim"); - let ownership_token = match claim { - Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; + .expect("upsert current"); - assert!( - !runtime - .mark_phase2_job_succeeded("cwd", "scope", "wrong-token") - .await - .expect("wrong token success should fail"), - "wrong token should not finalize phase2 job" - ); - let dirty_after_wrong_token = sqlx::query( - "SELECT dirty FROM memory_scope_dirty WHERE scope_kind = ? AND scope_key = ?", - ) - .bind("cwd") - .bind("scope") - .fetch_one(runtime.pool.as_ref()) - .await - .expect("fetch dirty after wrong token") - .try_get::("dirty") - .expect("dirty value"); - assert!(dirty_after_wrong_token, "dirty scope should remain dirty"); + let mut recent = + test_thread_metadata(&codex_home, recent_thread_id, codex_home.join("recent")); + recent.created_at = recent_at; + recent.updated_at = recent_at; + runtime.upsert_thread(&recent).await.expect("upsert recent"); - assert!( - runtime - .mark_phase2_job_succeeded("cwd", "scope", ownership_token.as_str()) - .await - .expect("owner success should pass"), - "owner token should finalize phase2 job" - ); - let dirty_after_success = sqlx::query( - "SELECT dirty FROM memory_scope_dirty WHERE scope_kind = ? AND scope_key = ?", - ) - .bind("cwd") - .bind("scope") - .fetch_one(runtime.pool.as_ref()) - .await - .expect("fetch dirty after success") - .try_get::("dirty") - .expect("dirty value"); - assert!( - !dirty_after_success, - "successful phase2 finalization should clear dirty scope" - ); - let dirty_scopes = runtime - .list_dirty_memory_scopes(10) + let mut old = test_thread_metadata(&codex_home, old_thread_id, codex_home.join("old")); + old.created_at = old_at; + old.updated_at = old_at; + runtime.upsert_thread(&old).await.expect("upsert old"); + + let allowed_sources = vec!["cli".to_string()]; + let claims = runtime + .claim_stage1_jobs_for_startup( + current_thread_id, + 10, + 5, + 30, + allowed_sources.as_slice(), + 3600, + ) .await - .expect("list dirty scopes"); - assert_eq!(dirty_scopes, Vec::new()); + .expect("claim stage1 jobs"); + + assert_eq!(claims.len(), 1); + assert_eq!(claims[0].thread.id, recent_thread_id); let _ = tokio::fs::remove_dir_all(codex_home).await; } #[tokio::test] - async fn phase2_failure_keeps_scope_dirty_and_allows_retry() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - - runtime - .mark_memory_scope_dirty("cwd", "scope", true) - .await - .expect("mark dirty true"); - let claim_a = runtime - .try_claim_phase2_job("cwd", "scope", owner_a, 3600) - .await - .expect("claim owner_a"); - let owner_a_token = match claim_a { - Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - assert!( - runtime - .mark_phase2_job_failed( - "cwd", - "scope", - owner_a_token.as_str(), - "consolidation failed", - ) - .await - .expect("mark phase2 failed"), - "owner token should fail phase2 job" - ); - let dirty_scopes = runtime - .list_dirty_memory_scopes(10) - .await - .expect("list dirty scopes"); - assert_eq!( - dirty_scopes, - vec![super::DirtyMemoryScope { - scope_kind: "cwd".to_string(), - scope_key: "scope".to_string(), - }] - ); - - let claim_b = runtime - .try_claim_phase2_job("cwd", "scope", owner_b, 3600) - .await - .expect("claim owner_b"); - assert!( - matches!(claim_b, Phase2JobClaimOutcome::Claimed { .. }), - "failed jobs should be retryable while dirty" - ); - - let attempt = sqlx::query( - "SELECT attempt FROM memory_phase2_jobs WHERE scope_kind = ? AND scope_key = ?", - ) - .bind("cwd") - .bind("scope") - .fetch_one(runtime.pool.as_ref()) - .await - .expect("fetch attempt") - .try_get::("attempt") - .expect("attempt value"); - assert_eq!(attempt, 2); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase1_job_claim_and_success_require_current_owner_token() { + async fn stage1_output_cascades_on_thread_delete() { let codex_home = unique_temp_dir(); let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) .await @@ -2408,276 +1248,23 @@ mod tests { .expect("upsert thread"); let claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner, 100, 3600) + .try_claim_stage1_job(thread_id, owner, 100, 3600) .await - .expect("claim phase1 job"); + .expect("claim stage1"); let ownership_token = match claim { - Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - assert!( - !runtime - .mark_phase1_job_succeeded( - thread_id, - "cwd", - "scope", - "wrong-token", - "/tmp/path", - "summary-hash" - ) - .await - .expect("mark succeeded wrong token should fail"), - "wrong token should not finalize the job" - ); - assert!( - runtime - .mark_phase1_job_succeeded( - thread_id, - "cwd", - "scope", - ownership_token.as_str(), - "/tmp/path", - "summary-hash" - ) - .await - .expect("mark succeeded with current token"), - "current token should finalize the job" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase1_job_running_stale_can_be_stolen_but_fresh_running_is_skipped() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let cwd = codex_home.join("workspace"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) - .await - .expect("upsert thread"); - - let first_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) - .await - .expect("first claim"); - assert!( - matches!(first_claim, Phase1JobClaimOutcome::Claimed { .. }), - "first claim should acquire" - ); - - let fresh_second_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 3600) - .await - .expect("fresh second claim"); - assert_eq!(fresh_second_claim, Phase1JobClaimOutcome::SkippedRunning); - - let stale_second_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0) - .await - .expect("stale second claim"); - assert!( - matches!(stale_second_claim, Phase1JobClaimOutcome::Claimed { .. }), - "stale running job should be stealable" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase1_job_lease_renewal_requires_current_owner_token() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let cwd = codex_home.join("workspace"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) - .await - .expect("upsert thread"); - - let first_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) - .await - .expect("first claim"); - let owner_a_token = match first_claim { - Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - let stolen_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0) - .await - .expect("stolen claim"); - let owner_b_token = match stolen_claim { - Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - assert!( - !runtime - .renew_phase1_job_lease(thread_id, "cwd", "scope", owner_a_token.as_str()) - .await - .expect("old owner lease renewal should fail"), - "stale owner token should not renew lease" - ); - assert!( - runtime - .renew_phase1_job_lease(thread_id, "cwd", "scope", owner_b_token.as_str()) - .await - .expect("current owner lease renewal should succeed"), - "current owner token should renew lease" - ); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase1_owner_guarded_upsert_rejects_stale_owner() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let cwd = codex_home.join("workspace"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) - .await - .expect("upsert thread"); - - let first_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) - .await - .expect("first claim"); - let owner_a_token = match first_claim { - Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - let stolen_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0) - .await - .expect("stolen claim"); - let owner_b_token = match stolen_claim { - Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, - other => panic!("unexpected claim outcome: {other:?}"), - }; - - let stale_upsert = runtime - .upsert_thread_memory_for_scope_if_phase1_owner( - thread_id, - "cwd", - "scope", - owner_a_token.as_str(), - "stale raw memory", - "stale summary", - ) - .await - .expect("stale owner upsert"); - assert!( - stale_upsert.is_none(), - "stale owner token should not upsert thread memory" - ); - - let current_upsert = runtime - .upsert_thread_memory_for_scope_if_phase1_owner( - thread_id, - "cwd", - "scope", - owner_b_token.as_str(), - "fresh raw memory", - "fresh summary", - ) - .await - .expect("current owner upsert"); - let current_upsert = current_upsert.expect("current owner should upsert"); - assert_eq!(current_upsert.raw_memory, "fresh raw memory"); - assert_eq!(current_upsert.memory_summary, "fresh summary"); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn phase1_job_failed_is_terminal() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); - let cwd = codex_home.join("workspace"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) - .await - .expect("upsert thread"); - - let claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600) - .await - .expect("claim"); - let ownership_token = match claim { - Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, other => panic!("unexpected claim outcome: {other:?}"), }; assert!( runtime - .mark_phase1_job_failed( - thread_id, - "cwd", - "scope", - ownership_token.as_str(), - "prompt failed" - ) + .mark_stage1_job_succeeded(thread_id, ownership_token.as_str(), 100, "raw", "sum") .await - .expect("mark failed"), - "owner token should be able to fail job" + .expect("mark stage1 succeeded"), + "mark stage1 succeeded should write stage1_outputs" ); - let second_claim = runtime - .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 101, 3600) - .await - .expect("second claim"); - assert_eq!(second_claim, Phase1JobClaimOutcome::SkippedTerminalFailure); - - let _ = tokio::fs::remove_dir_all(codex_home).await; - } - - #[tokio::test] - async fn deleting_thread_cascades_thread_memory() { - let codex_home = unique_temp_dir(); - let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) - .await - .expect("initialize runtime"); - - let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); - let cwd = codex_home.join("workspace"); - runtime - .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) - .await - .expect("upsert thread"); - runtime - .upsert_thread_memory(thread_id, "trace", "memory") - .await - .expect("upsert memory"); - let count_before = - sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?") + sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") .bind(thread_id.to_string()) .fetch_one(runtime.pool.as_ref()) .await @@ -2693,7 +1280,7 @@ mod tests { .expect("delete thread"); let count_after = - sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?") + sqlx::query("SELECT COUNT(*) AS count FROM stage1_outputs WHERE thread_id = ?") .bind(thread_id.to_string()) .fetch_one(runtime.pool.as_ref()) .await @@ -2701,12 +1288,308 @@ mod tests { .try_get::("count") .expect("count value"); assert_eq!(count_after, 0); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn phase2_consolidation_jobs_rerun_when_watermark_advances() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + runtime + .enqueue_scope_consolidation("cwd", "/tmp/project-a", 100) + .await + .expect("enqueue scope"); + + let scopes = runtime + .list_pending_scope_consolidations(10) + .await + .expect("list pending"); assert_eq!( + scopes, + vec![PendingScopeConsolidation { + scope_kind: "cwd".to_string(), + scope_key: "/tmp/project-a".to_string(), + }] + ); + + let claim = runtime + .try_claim_phase2_job("cwd", "/tmp/project-a", owner, 3600) + .await + .expect("claim phase2"); + let (ownership_token, input_watermark) = match claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + assert!( runtime - .get_thread_memory(thread_id) + .mark_phase2_job_succeeded( + "cwd", + "/tmp/project-a", + ownership_token.as_str(), + input_watermark, + ) .await - .expect("get memory after delete"), - None + .expect("mark phase2 succeeded"), + "phase2 success should finalize for current token" + ); + + let claim_up_to_date = runtime + .try_claim_phase2_job("cwd", "/tmp/project-a", owner, 3600) + .await + .expect("claim phase2 up-to-date"); + assert_eq!(claim_up_to_date, Phase2JobClaimOutcome::SkippedNotDirty); + + runtime + .enqueue_scope_consolidation("cwd", "/tmp/project-a", 101) + .await + .expect("enqueue scope again"); + + let claim_rerun = runtime + .try_claim_phase2_job("cwd", "/tmp/project-a", owner, 3600) + .await + .expect("claim phase2 rerun"); + assert!( + matches!(claim_rerun, Phase2JobClaimOutcome::Claimed { .. }), + "advanced watermark should be claimable" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn list_stage1_outputs_for_cwd_scope_matches_canonical_equivalent_paths() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let workspace = codex_home.join("workspace"); + tokio::fs::create_dir_all(&workspace) + .await + .expect("create workspace"); + let non_normalized_cwd = workspace.join("..").join("workspace"); + let canonical_scope_key = workspace + .canonicalize() + .expect("canonicalize workspace") + .display() + .to_string(); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + non_normalized_cwd, + )) + .await + .expect("upsert thread"); + + let claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + 100, + "raw memory", + "summary", + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + + let outputs = runtime + .list_stage1_outputs_for_scope("cwd", canonical_scope_key.as_str(), 10) + .await + .expect("list stage1 outputs for canonical cwd scope"); + assert_eq!(outputs.len(), 1); + assert_eq!(outputs[0].thread_id, thread_id); + assert_eq!(outputs[0].summary, "summary"); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_stage1_job_succeeded_normalizes_cwd_scope_job_key() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let workspace = codex_home.join("workspace"); + tokio::fs::create_dir_all(&workspace) + .await + .expect("create workspace"); + let canonical_scope_key = workspace + .canonicalize() + .expect("canonicalize workspace") + .display() + .to_string(); + let cwd_alias = workspace.join("."); + + let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); + let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_a, workspace)) + .await + .expect("upsert thread a"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_b, cwd_alias)) + .await + .expect("upsert thread b"); + + let claim_a = runtime + .try_claim_stage1_job(thread_a, owner, 100, 3600) + .await + .expect("claim stage1 a"); + let token_a = match claim_a { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for thread a: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded(thread_a, token_a.as_str(), 100, "raw-a", "summary-a") + .await + .expect("mark stage1 succeeded a"), + "stage1 success should persist output for thread a" + ); + + let claim_b = runtime + .try_claim_stage1_job(thread_b, owner, 101, 3600) + .await + .expect("claim stage1 b"); + let token_b = match claim_b { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for thread b: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded(thread_b, token_b.as_str(), 101, "raw-b", "summary-b") + .await + .expect("mark stage1 succeeded b"), + "stage1 success should persist output for thread b" + ); + + let pending_scopes = runtime + .list_pending_scope_consolidations(10) + .await + .expect("list pending scopes"); + let cwd_scopes = pending_scopes + .iter() + .filter(|scope| scope.scope_kind == "cwd") + .cloned() + .collect::>(); + assert_eq!(cwd_scopes.len(), 1); + assert_eq!(cwd_scopes[0].scope_key, canonical_scope_key); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn list_pending_scope_consolidations_omits_unclaimable_jobs() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + runtime + .enqueue_scope_consolidation("cwd", "scope-running", 200) + .await + .expect("enqueue running scope"); + runtime + .enqueue_scope_consolidation("cwd", "scope-backoff", 199) + .await + .expect("enqueue backoff scope"); + runtime + .enqueue_scope_consolidation("cwd", "scope-exhausted", 198) + .await + .expect("enqueue exhausted scope"); + runtime + .enqueue_scope_consolidation("cwd", "scope-claimable-a", 90) + .await + .expect("enqueue claimable scope a"); + runtime + .enqueue_scope_consolidation("cwd", "scope-claimable-b", 89) + .await + .expect("enqueue claimable scope b"); + + let running_claim = runtime + .try_claim_phase2_job("cwd", "scope-running", owner, 3600) + .await + .expect("claim running scope"); + assert!( + matches!(running_claim, Phase2JobClaimOutcome::Claimed { .. }), + "scope-running should be claimed" + ); + + let backoff_claim = runtime + .try_claim_phase2_job("cwd", "scope-backoff", owner, 3600) + .await + .expect("claim backoff scope"); + let backoff_token = match backoff_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, .. + } => ownership_token, + other => panic!("unexpected backoff claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_phase2_job_failed( + "cwd", + "scope-backoff", + backoff_token.as_str(), + "temporary failure", + 3600, + ) + .await + .expect("mark backoff scope failed"), + "backoff scope should transition to retry backoff" + ); + + sqlx::query("UPDATE jobs SET retry_remaining = 0 WHERE kind = ? AND job_key = ?") + .bind("memory_consolidate_cwd") + .bind("scope-exhausted") + .execute(runtime.pool.as_ref()) + .await + .expect("set exhausted scope retries to zero"); + + let pending = runtime + .list_pending_scope_consolidations(2) + .await + .expect("list pending scopes"); + assert_eq!( + pending, + vec![ + PendingScopeConsolidation { + scope_kind: "cwd".to_string(), + scope_key: "scope-claimable-a".to_string(), + }, + PendingScopeConsolidation { + scope_kind: "cwd".to_string(), + scope_key: "scope-claimable-b".to_string(), + }, + ] ); let _ = tokio::fs::remove_dir_all(codex_home).await; diff --git a/codex-rs/state/src/runtime/memory.rs b/codex-rs/state/src/runtime/memory.rs new file mode 100644 index 000000000..0f706d926 --- /dev/null +++ b/codex-rs/state/src/runtime/memory.rs @@ -0,0 +1,800 @@ +use super::*; +use crate::Stage1Output; +use crate::model::Stage1OutputRow; +use chrono::Duration; +use sqlx::Executor; +use sqlx::Sqlite; +use std::collections::HashSet; +use std::path::Path; +use std::path::PathBuf; + +const JOB_KIND_MEMORY_STAGE1: &str = "memory_stage1"; +const JOB_KIND_MEMORY_CONSOLIDATE_CWD: &str = "memory_consolidate_cwd"; +const JOB_KIND_MEMORY_CONSOLIDATE_USER: &str = "memory_consolidate_user"; + +const DEFAULT_RETRY_REMAINING: i64 = 3; + +fn job_kind_for_scope(scope_kind: &str) -> Option<&'static str> { + match scope_kind { + MEMORY_SCOPE_KIND_CWD => Some(JOB_KIND_MEMORY_CONSOLIDATE_CWD), + MEMORY_SCOPE_KIND_USER => Some(JOB_KIND_MEMORY_CONSOLIDATE_USER), + _ => None, + } +} + +fn scope_kind_for_job_kind(job_kind: &str) -> Option<&'static str> { + match job_kind { + JOB_KIND_MEMORY_CONSOLIDATE_CWD => Some(MEMORY_SCOPE_KIND_CWD), + JOB_KIND_MEMORY_CONSOLIDATE_USER => Some(MEMORY_SCOPE_KIND_USER), + _ => None, + } +} + +fn normalize_cwd_for_scope_matching(cwd: &str) -> Option { + Path::new(cwd).canonicalize().ok() +} + +impl StateRuntime { + pub async fn claim_stage1_jobs_for_startup( + &self, + current_thread_id: ThreadId, + scan_limit: usize, + max_claimed: usize, + max_age_days: i64, + allowed_sources: &[String], + lease_seconds: i64, + ) -> anyhow::Result> { + if scan_limit == 0 || max_claimed == 0 { + return Ok(Vec::new()); + } + + let page = self + .list_threads( + scan_limit, + None, + SortKey::UpdatedAt, + allowed_sources, + None, + false, + ) + .await?; + + let cutoff = Utc::now() - Duration::days(max_age_days.max(0)); + let mut claimed = Vec::new(); + + for item in page.items { + if claimed.len() >= max_claimed { + break; + } + if item.id == current_thread_id { + continue; + } + if item.updated_at < cutoff { + continue; + } + + if let Stage1JobClaimOutcome::Claimed { ownership_token } = self + .try_claim_stage1_job( + item.id, + current_thread_id, + item.updated_at.timestamp(), + lease_seconds, + ) + .await? + { + claimed.push(Stage1JobClaim { + thread: item, + ownership_token, + }); + } + } + + Ok(claimed) + } + + pub async fn get_stage1_output( + &self, + thread_id: ThreadId, + ) -> anyhow::Result> { + let row = sqlx::query( + r#" +SELECT thread_id, source_updated_at, raw_memory, summary, generated_at +FROM stage1_outputs +WHERE thread_id = ? + "#, + ) + .bind(thread_id.to_string()) + .fetch_optional(self.pool.as_ref()) + .await?; + + row.map(|row| Stage1OutputRow::try_from_row(&row).and_then(Stage1Output::try_from)) + .transpose() + } + + pub async fn list_stage1_outputs_for_scope( + &self, + scope_kind: &str, + scope_key: &str, + n: usize, + ) -> anyhow::Result> { + if n == 0 { + return Ok(Vec::new()); + } + + let rows = match scope_kind { + MEMORY_SCOPE_KIND_CWD => { + let exact_rows = sqlx::query( + r#" +SELECT so.thread_id, so.source_updated_at, so.raw_memory, so.summary, so.generated_at +FROM stage1_outputs AS so +JOIN threads AS t ON t.id = so.thread_id +WHERE t.cwd = ? +ORDER BY so.source_updated_at DESC, so.thread_id DESC +LIMIT ? + "#, + ) + .bind(scope_key) + .bind(n as i64) + .fetch_all(self.pool.as_ref()) + .await?; + + if let Some(normalized_scope_key) = normalize_cwd_for_scope_matching(scope_key) { + let mut rows = Vec::new(); + let mut selected_thread_ids = HashSet::new(); + let candidate_rows = sqlx::query( + r#" +SELECT so.thread_id, so.source_updated_at, so.raw_memory, so.summary, so.generated_at, t.cwd AS thread_cwd +FROM stage1_outputs AS so +JOIN threads AS t ON t.id = so.thread_id +ORDER BY so.source_updated_at DESC, so.thread_id DESC + "#, + ) + .fetch_all(self.pool.as_ref()) + .await?; + + for row in candidate_rows { + if rows.len() >= n { + break; + } + let thread_id: String = row.try_get("thread_id")?; + if selected_thread_ids.contains(&thread_id) { + continue; + } + let thread_cwd: String = row.try_get("thread_cwd")?; + if let Some(normalized_thread_cwd) = + normalize_cwd_for_scope_matching(&thread_cwd) + && normalized_thread_cwd == normalized_scope_key + { + selected_thread_ids.insert(thread_id); + rows.push(row); + } + } + if rows.is_empty() { exact_rows } else { rows } + } else { + exact_rows + } + } + MEMORY_SCOPE_KIND_USER => { + sqlx::query( + r#" +SELECT so.thread_id, so.source_updated_at, so.raw_memory, so.summary, so.generated_at +FROM stage1_outputs AS so +JOIN threads AS t ON t.id = so.thread_id +ORDER BY so.source_updated_at DESC, so.thread_id DESC +LIMIT ? + "#, + ) + .bind(n as i64) + .fetch_all(self.pool.as_ref()) + .await? + } + _ => return Ok(Vec::new()), + }; + + rows.into_iter() + .map(|row| Stage1OutputRow::try_from_row(&row).and_then(Stage1Output::try_from)) + .collect::, _>>() + } + + pub async fn try_claim_stage1_job( + &self, + thread_id: ThreadId, + worker_id: ThreadId, + source_updated_at: i64, + lease_seconds: i64, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let lease_until = now.saturating_add(lease_seconds.max(0)); + let ownership_token = Uuid::new_v4().to_string(); + let thread_id = thread_id.to_string(); + let worker_id = worker_id.to_string(); + + let mut tx = self.pool.begin().await?; + + let existing_output = sqlx::query( + r#" +SELECT source_updated_at +FROM stage1_outputs +WHERE thread_id = ? + "#, + ) + .bind(thread_id.as_str()) + .fetch_optional(&mut *tx) + .await?; + if let Some(existing_output) = existing_output { + let existing_source_updated_at: i64 = existing_output.try_get("source_updated_at")?; + if existing_source_updated_at >= source_updated_at { + tx.commit().await?; + return Ok(Stage1JobClaimOutcome::SkippedUpToDate); + } + } + + let existing_job = sqlx::query( + r#" +SELECT status, lease_until, retry_at, retry_remaining +FROM jobs +WHERE kind = ? AND job_key = ? + "#, + ) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(thread_id.as_str()) + .fetch_optional(&mut *tx) + .await?; + + let Some(existing_job) = existing_job else { + sqlx::query( + r#" +INSERT INTO jobs ( + kind, + job_key, + status, + worker_id, + ownership_token, + started_at, + finished_at, + lease_until, + retry_at, + retry_remaining, + last_error, + input_watermark, + last_success_watermark +) VALUES (?, ?, 'running', ?, ?, ?, NULL, ?, NULL, ?, NULL, ?, NULL) + "#, + ) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(thread_id.as_str()) + .bind(worker_id.as_str()) + .bind(ownership_token.as_str()) + .bind(now) + .bind(lease_until) + .bind(DEFAULT_RETRY_REMAINING) + .bind(source_updated_at) + .execute(&mut *tx) + .await?; + tx.commit().await?; + return Ok(Stage1JobClaimOutcome::Claimed { ownership_token }); + }; + + let status: String = existing_job.try_get("status")?; + let existing_lease_until: Option = existing_job.try_get("lease_until")?; + let retry_at: Option = existing_job.try_get("retry_at")?; + let retry_remaining: i64 = existing_job.try_get("retry_remaining")?; + + if retry_remaining <= 0 { + tx.commit().await?; + return Ok(Stage1JobClaimOutcome::SkippedRetryExhausted); + } + if retry_at.is_some_and(|retry_at| retry_at > now) { + tx.commit().await?; + return Ok(Stage1JobClaimOutcome::SkippedRetryBackoff); + } + if status == "running" && existing_lease_until.is_some_and(|lease_until| lease_until > now) + { + tx.commit().await?; + return Ok(Stage1JobClaimOutcome::SkippedRunning); + } + + let rows_affected = sqlx::query( + r#" +UPDATE jobs +SET + status = 'running', + worker_id = ?, + ownership_token = ?, + started_at = ?, + finished_at = NULL, + lease_until = ?, + retry_at = NULL, + last_error = NULL, + input_watermark = ? +WHERE kind = ? AND job_key = ? + AND (status != 'running' OR lease_until IS NULL OR lease_until <= ?) + AND (retry_at IS NULL OR retry_at <= ?) + AND retry_remaining > 0 + "#, + ) + .bind(worker_id.as_str()) + .bind(ownership_token.as_str()) + .bind(now) + .bind(lease_until) + .bind(source_updated_at) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(thread_id.as_str()) + .bind(now) + .bind(now) + .execute(&mut *tx) + .await? + .rows_affected(); + + tx.commit().await?; + if rows_affected == 0 { + Ok(Stage1JobClaimOutcome::SkippedRunning) + } else { + Ok(Stage1JobClaimOutcome::Claimed { ownership_token }) + } + } + + pub async fn mark_stage1_job_succeeded( + &self, + thread_id: ThreadId, + ownership_token: &str, + source_updated_at: i64, + raw_memory: &str, + summary: &str, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let thread_id = thread_id.to_string(); + + let mut tx = self.pool.begin().await?; + let rows_affected = sqlx::query( + r#" +UPDATE jobs +SET + status = 'done', + finished_at = ?, + lease_until = NULL, + last_error = NULL, + last_success_watermark = input_watermark +WHERE kind = ? AND job_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(now) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(thread_id.as_str()) + .bind(ownership_token) + .execute(&mut *tx) + .await? + .rows_affected(); + + if rows_affected == 0 { + tx.commit().await?; + return Ok(false); + } + + sqlx::query( + r#" +INSERT INTO stage1_outputs ( + thread_id, + source_updated_at, + raw_memory, + summary, + generated_at +) VALUES (?, ?, ?, ?, ?) +ON CONFLICT(thread_id) DO UPDATE SET + source_updated_at = excluded.source_updated_at, + raw_memory = excluded.raw_memory, + summary = excluded.summary, + generated_at = excluded.generated_at +WHERE excluded.source_updated_at >= stage1_outputs.source_updated_at + "#, + ) + .bind(thread_id.as_str()) + .bind(source_updated_at) + .bind(raw_memory) + .bind(summary) + .bind(now) + .execute(&mut *tx) + .await?; + + if let Some(thread_row) = sqlx::query( + r#" +SELECT cwd +FROM threads +WHERE id = ? + "#, + ) + .bind(thread_id.as_str()) + .fetch_optional(&mut *tx) + .await? + { + let cwd: String = thread_row.try_get("cwd")?; + let normalized_cwd = normalize_cwd_for_scope_matching(&cwd) + .unwrap_or_else(|| PathBuf::from(&cwd)) + .display() + .to_string(); + enqueue_scope_consolidation_with_executor( + &mut *tx, + MEMORY_SCOPE_KIND_CWD, + &normalized_cwd, + source_updated_at, + ) + .await?; + enqueue_scope_consolidation_with_executor( + &mut *tx, + MEMORY_SCOPE_KIND_USER, + MEMORY_SCOPE_KEY_USER, + source_updated_at, + ) + .await?; + } + + tx.commit().await?; + Ok(true) + } + + pub async fn mark_stage1_job_failed( + &self, + thread_id: ThreadId, + ownership_token: &str, + failure_reason: &str, + retry_delay_seconds: i64, + ) -> anyhow::Result { + let now = Utc::now().timestamp(); + let retry_at = now.saturating_add(retry_delay_seconds.max(0)); + let thread_id = thread_id.to_string(); + + let rows_affected = sqlx::query( + r#" +UPDATE jobs +SET + status = 'error', + finished_at = ?, + lease_until = NULL, + retry_at = ?, + retry_remaining = retry_remaining - 1, + last_error = ? +WHERE kind = ? AND job_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(now) + .bind(retry_at) + .bind(failure_reason) + .bind(JOB_KIND_MEMORY_STAGE1) + .bind(thread_id.as_str()) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? + .rows_affected(); + + Ok(rows_affected > 0) + } + + pub async fn enqueue_scope_consolidation( + &self, + scope_kind: &str, + scope_key: &str, + input_watermark: i64, + ) -> anyhow::Result<()> { + enqueue_scope_consolidation_with_executor( + self.pool.as_ref(), + scope_kind, + scope_key, + input_watermark, + ) + .await + } + + pub async fn list_pending_scope_consolidations( + &self, + limit: usize, + ) -> anyhow::Result> { + if limit == 0 { + return Ok(Vec::new()); + } + let now = Utc::now().timestamp(); + + let rows = sqlx::query( + r#" +SELECT kind, job_key +FROM jobs +WHERE kind IN (?, ?) + AND input_watermark IS NOT NULL + AND input_watermark > COALESCE(last_success_watermark, 0) + AND retry_remaining > 0 + AND (retry_at IS NULL OR retry_at <= ?) + AND (status != 'running' OR lease_until IS NULL OR lease_until <= ?) +ORDER BY input_watermark DESC, kind ASC, job_key ASC +LIMIT ? + "#, + ) + .bind(JOB_KIND_MEMORY_CONSOLIDATE_CWD) + .bind(JOB_KIND_MEMORY_CONSOLIDATE_USER) + .bind(now) + .bind(now) + .bind(limit as i64) + .fetch_all(self.pool.as_ref()) + .await?; + + Ok(rows + .into_iter() + .filter_map(|row| { + let kind: String = row.try_get("kind").ok()?; + let scope_kind = scope_kind_for_job_kind(&kind)?; + let scope_key: String = row.try_get("job_key").ok()?; + Some(PendingScopeConsolidation { + scope_kind: scope_kind.to_string(), + scope_key, + }) + }) + .collect::>()) + } + + /// Try to claim a phase-2 consolidation job for `(scope_kind, scope_key)`. + pub async fn try_claim_phase2_job( + &self, + scope_kind: &str, + scope_key: &str, + worker_id: ThreadId, + lease_seconds: i64, + ) -> anyhow::Result { + let Some(job_kind) = job_kind_for_scope(scope_kind) else { + return Ok(Phase2JobClaimOutcome::SkippedNotDirty); + }; + + let now = Utc::now().timestamp(); + let lease_until = now.saturating_add(lease_seconds.max(0)); + let ownership_token = Uuid::new_v4().to_string(); + let worker_id = worker_id.to_string(); + + let mut tx = self.pool.begin().await?; + + let existing_job = sqlx::query( + r#" +SELECT status, lease_until, retry_at, retry_remaining, input_watermark, last_success_watermark +FROM jobs +WHERE kind = ? AND job_key = ? + "#, + ) + .bind(job_kind) + .bind(scope_key) + .fetch_optional(&mut *tx) + .await?; + + let Some(existing_job) = existing_job else { + tx.commit().await?; + return Ok(Phase2JobClaimOutcome::SkippedNotDirty); + }; + + let input_watermark: Option = existing_job.try_get("input_watermark")?; + let input_watermark_value = input_watermark.unwrap_or(0); + let last_success_watermark: Option = existing_job.try_get("last_success_watermark")?; + if input_watermark_value <= last_success_watermark.unwrap_or(0) { + tx.commit().await?; + return Ok(Phase2JobClaimOutcome::SkippedNotDirty); + } + + let status: String = existing_job.try_get("status")?; + let existing_lease_until: Option = existing_job.try_get("lease_until")?; + let retry_at: Option = existing_job.try_get("retry_at")?; + let retry_remaining: i64 = existing_job.try_get("retry_remaining")?; + + if retry_remaining <= 0 { + tx.commit().await?; + return Ok(Phase2JobClaimOutcome::SkippedNotDirty); + } + if retry_at.is_some_and(|retry_at| retry_at > now) { + tx.commit().await?; + return Ok(Phase2JobClaimOutcome::SkippedNotDirty); + } + if status == "running" && existing_lease_until.is_some_and(|lease_until| lease_until > now) + { + tx.commit().await?; + return Ok(Phase2JobClaimOutcome::SkippedRunning); + } + + let rows_affected = sqlx::query( + r#" +UPDATE jobs +SET + status = 'running', + worker_id = ?, + ownership_token = ?, + started_at = ?, + finished_at = NULL, + lease_until = ?, + retry_at = NULL, + last_error = NULL +WHERE kind = ? AND job_key = ? + AND (status != 'running' OR lease_until IS NULL OR lease_until <= ?) + AND (retry_at IS NULL OR retry_at <= ?) + AND retry_remaining > 0 + "#, + ) + .bind(worker_id.as_str()) + .bind(ownership_token.as_str()) + .bind(now) + .bind(lease_until) + .bind(job_kind) + .bind(scope_key) + .bind(now) + .bind(now) + .execute(&mut *tx) + .await? + .rows_affected(); + + tx.commit().await?; + if rows_affected == 0 { + Ok(Phase2JobClaimOutcome::SkippedRunning) + } else { + Ok(Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark: input_watermark_value, + }) + } + } + + pub async fn heartbeat_phase2_job( + &self, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + lease_seconds: i64, + ) -> anyhow::Result { + let Some(job_kind) = job_kind_for_scope(scope_kind) else { + return Ok(false); + }; + + let now = Utc::now().timestamp(); + let lease_until = now.saturating_add(lease_seconds.max(0)); + let rows_affected = sqlx::query( + r#" +UPDATE jobs +SET lease_until = ? +WHERE kind = ? AND job_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(lease_until) + .bind(job_kind) + .bind(scope_key) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? + .rows_affected(); + + Ok(rows_affected > 0) + } + + pub async fn mark_phase2_job_succeeded( + &self, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + completed_watermark: i64, + ) -> anyhow::Result { + let Some(job_kind) = job_kind_for_scope(scope_kind) else { + return Ok(false); + }; + + let now = Utc::now().timestamp(); + let rows_affected = sqlx::query( + r#" +UPDATE jobs +SET + status = 'done', + finished_at = ?, + lease_until = NULL, + last_error = NULL, + last_success_watermark = max(COALESCE(last_success_watermark, 0), ?) +WHERE kind = ? AND job_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(now) + .bind(completed_watermark) + .bind(job_kind) + .bind(scope_key) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? + .rows_affected(); + + Ok(rows_affected > 0) + } + + pub async fn mark_phase2_job_failed( + &self, + scope_kind: &str, + scope_key: &str, + ownership_token: &str, + failure_reason: &str, + retry_delay_seconds: i64, + ) -> anyhow::Result { + let Some(job_kind) = job_kind_for_scope(scope_kind) else { + return Ok(false); + }; + + let now = Utc::now().timestamp(); + let retry_at = now.saturating_add(retry_delay_seconds.max(0)); + let rows_affected = sqlx::query( + r#" +UPDATE jobs +SET + status = 'error', + finished_at = ?, + lease_until = NULL, + retry_at = ?, + retry_remaining = retry_remaining - 1, + last_error = ? +WHERE kind = ? AND job_key = ? + AND status = 'running' AND ownership_token = ? + "#, + ) + .bind(now) + .bind(retry_at) + .bind(failure_reason) + .bind(job_kind) + .bind(scope_key) + .bind(ownership_token) + .execute(self.pool.as_ref()) + .await? + .rows_affected(); + + Ok(rows_affected > 0) + } +} + +async fn enqueue_scope_consolidation_with_executor<'e, E>( + executor: E, + scope_kind: &str, + scope_key: &str, + input_watermark: i64, +) -> anyhow::Result<()> +where + E: Executor<'e, Database = Sqlite>, +{ + let Some(job_kind) = job_kind_for_scope(scope_kind) else { + return Ok(()); + }; + + sqlx::query( + r#" +INSERT INTO jobs ( + kind, + job_key, + status, + worker_id, + ownership_token, + started_at, + finished_at, + lease_until, + retry_at, + retry_remaining, + last_error, + input_watermark, + last_success_watermark +) VALUES (?, ?, 'pending', NULL, NULL, NULL, NULL, NULL, NULL, ?, NULL, ?, 0) +ON CONFLICT(kind, job_key) DO UPDATE SET + status = CASE + WHEN jobs.status = 'running' THEN 'running' + ELSE 'pending' + END, + retry_at = CASE + WHEN jobs.status = 'running' THEN jobs.retry_at + ELSE NULL + END, + retry_remaining = max(jobs.retry_remaining, excluded.retry_remaining), + input_watermark = max(COALESCE(jobs.input_watermark, 0), excluded.input_watermark) + "#, + ) + .bind(job_kind) + .bind(scope_key) + .bind(DEFAULT_RETRY_REMAINING) + .bind(input_watermark) + .execute(executor) + .await?; + + Ok(()) +}