From 382fa338b3f1823711c30b89a790bb6b32f66c8b Mon Sep 17 00:00:00 2001 From: jif-oai Date: Thu, 26 Feb 2026 13:19:57 +0000 Subject: [PATCH] feat: memories forgetting (#12900) Add diff based memory forgetting --- codex-rs/core/src/memories/README.md | 20 + codex-rs/core/src/memories/phase2.rs | 65 +- codex-rs/core/src/memories/prompts.rs | 83 ++- .../core/templates/memories/consolidation.md | 66 +- codex-rs/core/tests/suite/memories.rs | 282 ++++++++ codex-rs/core/tests/suite/mod.rs | 1 + .../0018_phase2_selection_snapshot.sql | 3 + codex-rs/state/src/lib.rs | 2 + codex-rs/state/src/model/memories.rs | 27 + codex-rs/state/src/model/mod.rs | 3 + codex-rs/state/src/runtime.rs | 660 +++++++++++++++++- codex-rs/state/src/runtime/memories.rs | 162 ++++- 12 files changed, 1335 insertions(+), 39 deletions(-) create mode 100644 codex-rs/core/tests/suite/memories.rs create mode 100644 codex-rs/state/migrations/0018_phase2_selection_snapshot.sql diff --git a/codex-rs/core/src/memories/README.md b/codex-rs/core/src/memories/README.md index c19a0c680..afbc94e4d 100644 --- a/codex-rs/core/src/memories/README.md +++ b/codex-rs/core/src/memories/README.md @@ -70,11 +70,31 @@ What it does: If there is input, it then: - spawns an internal consolidation sub-agent +- builds the Phase 2 prompt with a diff of the current Phase 1 input + selection versus the last successful Phase 2 selection (`added`, + `retained`, `removed`) - runs it with no approvals, no network, and local write access only - disables collab for that agent (to prevent recursive delegation) - watches the agent status and heartbeats the global job lease while it runs - marks the phase-2 job success/failure in the state DB when the agent finishes +Selection diff behavior: + +- successful Phase 2 runs mark the exact stage-1 snapshots they consumed with + `selected_for_phase2 = 1` and persist the matching + `selected_for_phase2_source_updated_at` +- Phase 1 upserts preserve the previous `selected_for_phase2` baseline until 
+ the next successful Phase 2 run rewrites it +- the next Phase 2 run compares the current top-N stage-1 inputs against that + prior snapshot selection to label inputs as `added` or `retained`; a + refreshed thread stays `added` until Phase 2 successfully selects its newer + snapshot +- rows that were previously selected but still exist outside the current top-N + selection are surfaced as `removed` +- before the agent starts, local `rollout_summaries/` and `raw_memories.md` + keep the union of the current selection and the previous successful + selection, so removed-thread evidence stays available during forgetting + Watermark behavior: - The global phase-2 job claim includes an input watermark representing the latest input timestamp known when the job was claimed. diff --git a/codex-rs/core/src/memories/phase2.rs b/codex-rs/core/src/memories/phase2.rs index 6794a4574..f86b3bbde 100644 --- a/codex-rs/core/src/memories/phase2.rs +++ b/codex-rs/core/src/memories/phase2.rs @@ -8,6 +8,7 @@ use crate::memories::metrics; use crate::memories::phase_two; use crate::memories::prompts::build_consolidation_prompt; use crate::memories::storage::rebuild_raw_memories_file_from_memories; +use crate::memories::storage::rollout_summary_file_stem; use crate::memories::storage::sync_rollout_summaries_from_memories; use codex_config::Constrained; use codex_protocol::ThreadId; @@ -17,8 +18,10 @@ use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_protocol::protocol::TokenUsage; use codex_protocol::user_input::UserInput; +use codex_state::Stage1Output; use codex_state::StateRuntime; use codex_utils_absolute_path::AbsolutePathBuf; +use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; use tokio::sync::watch; @@ -73,21 +76,24 @@ pub(super) async fn run(session: &Arc, config: Arc) { }; // 3. 
Query the memories - let raw_memories = match db.list_stage1_outputs_for_global(max_raw_memories).await { - Ok(memories) => memories, + let selection = match db.get_phase2_input_selection(max_raw_memories).await { + Ok(selection) => selection, Err(err) => { tracing::error!("failed to list stage1 outputs from global: {}", err); job::failed(session, db, &claim, "failed_load_stage1_outputs").await; return; } }; + let raw_memories = selection.selected.to_vec(); + let artifact_memories = artifact_memories_for_phase2(&selection); let new_watermark = get_watermark(claim.watermark, &raw_memories); // 4. Update the file system by syncing the raw memories with the one extracted from DB at // step 3 // [`rollout_summaries/`] if let Err(err) = - sync_rollout_summaries_from_memories(&root, &raw_memories, max_raw_memories).await + sync_rollout_summaries_from_memories(&root, &artifact_memories, artifact_memories.len()) + .await { tracing::error!("failed syncing local memory artifacts for global consolidation: {err}"); job::failed(session, db, &claim, "failed_sync_artifacts").await; @@ -95,7 +101,8 @@ pub(super) async fn run(session: &Arc, config: Arc) { } // [`raw_memories.md`] if let Err(err) = - rebuild_raw_memories_file_from_memories(&root, &raw_memories, max_raw_memories).await + rebuild_raw_memories_file_from_memories(&root, &artifact_memories, artifact_memories.len()) + .await { tracing::error!("failed syncing local memory artifacts for global consolidation: {err}"); job::failed(session, db, &claim, "failed_rebuild_raw_memories").await; @@ -103,12 +110,20 @@ pub(super) async fn run(session: &Arc, config: Arc) { } if raw_memories.is_empty() { // We check only after sync of the file system. - job::succeed(session, db, &claim, new_watermark, "succeeded_no_input").await; + job::succeed( + session, + db, + &claim, + new_watermark, + &[], + "succeeded_no_input", + ) + .await; return; } // 5. 
Spawn the agent - let prompt = agent::get_prompt(config); + let prompt = agent::get_prompt(config, &selection); let source = SessionSource::SubAgent(SubAgentSource::MemoryConsolidation); let thread_id = match session .services @@ -129,6 +144,7 @@ pub(super) async fn run(session: &Arc, config: Arc) { session, claim, new_watermark, + raw_memories.clone(), thread_id, phase_two_e2e_timer, ); @@ -140,6 +156,22 @@ pub(super) async fn run(session: &Arc, config: Arc) { emit_metrics(session, counters); } +fn artifact_memories_for_phase2( + selection: &codex_state::Phase2InputSelection, +) -> Vec { + let mut seen = HashSet::new(); + let mut memories = selection.selected.clone(); + for memory in &selection.selected { + seen.insert(rollout_summary_file_stem(memory)); + } + for memory in &selection.previous_selected { + if seen.insert(rollout_summary_file_stem(memory)) { + memories.push(memory.clone()); + } + } + memories +} + mod job { use super::*; @@ -205,6 +237,7 @@ mod job { db: &StateRuntime, claim: &Claim, completion_watermark: i64, + selected_outputs: &[codex_state::Stage1Output], reason: &'static str, ) { session.services.otel_manager.counter( @@ -213,7 +246,7 @@ mod job { &[("status", reason)], ); let _ = db - .mark_global_phase2_job_succeeded(&claim.token, completion_watermark) + .mark_global_phase2_job_succeeded(&claim.token, completion_watermark, selected_outputs) .await; } } @@ -266,9 +299,12 @@ mod agent { Some(agent_config) } - pub(super) fn get_prompt(config: Arc) -> Vec { + pub(super) fn get_prompt( + config: Arc, + selection: &codex_state::Phase2InputSelection, + ) -> Vec { let root = memory_root(&config.codex_home); - let prompt = build_consolidation_prompt(&root); + let prompt = build_consolidation_prompt(&root, selection); vec![UserInput::Text { text: prompt, text_elements: vec![], @@ -280,6 +316,7 @@ mod agent { session: &Arc, claim: Claim, new_watermark: i64, + selected_outputs: Vec, thread_id: ThreadId, phase_two_e2e_timer: Option, ) { @@ -316,7 +353,15 
@@ mod agent { if let Some(token_usage) = agent_control.get_total_token_usage(thread_id).await { emit_token_usage_metrics(&session, &token_usage); } - job::succeed(&session, &db, &claim, new_watermark, "succeeded").await; + job::succeed( + &session, + &db, + &claim, + new_watermark, + &selected_outputs, + "succeeded", + ) + .await; } else { job::failed(&session, &db, &claim, "failed_agent").await; } diff --git a/codex-rs/core/src/memories/prompts.rs b/codex-rs/core/src/memories/prompts.rs index 9b341c1ff..35cfe1edf 100644 --- a/codex-rs/core/src/memories/prompts.rs +++ b/codex-rs/core/src/memories/prompts.rs @@ -1,9 +1,13 @@ use crate::memories::memory_root; use crate::memories::phase_one; +use crate::memories::storage::rollout_summary_file_stem_from_parts; use crate::truncate::TruncationPolicy; use crate::truncate::truncate_text; use askama::Template; use codex_protocol::openai_models::ModelInfo; +use codex_state::Phase2InputSelection; +use codex_state::Stage1Output; +use codex_state::Stage1OutputRef; use std::path::Path; use tokio::fs; use tracing::warn; @@ -12,6 +16,7 @@ use tracing::warn; #[template(path = "memories/consolidation.md", escape = "none")] struct ConsolidationPromptTemplate<'a> { memory_root: &'a str, + phase2_input_selection: &'a str, } #[derive(Template)] @@ -30,17 +35,91 @@ struct MemoryToolDeveloperInstructionsTemplate<'a> { } /// Builds the consolidation subagent prompt for a specific memory root. 
-pub(super) fn build_consolidation_prompt(memory_root: &Path) -> String { +pub(super) fn build_consolidation_prompt( + memory_root: &Path, + selection: &Phase2InputSelection, +) -> String { let memory_root = memory_root.display().to_string(); + let phase2_input_selection = render_phase2_input_selection(selection); let template = ConsolidationPromptTemplate { memory_root: &memory_root, + phase2_input_selection: &phase2_input_selection, }; template.render().unwrap_or_else(|err| { warn!("failed to render memories consolidation prompt template: {err}"); - format!("## Memory Phase 2 (Consolidation)\nConsolidate Codex memories in: {memory_root}") + format!( + "## Memory Phase 2 (Consolidation)\nConsolidate Codex memories in: {memory_root}\n\n{phase2_input_selection}" + ) }) } +fn render_phase2_input_selection(selection: &Phase2InputSelection) -> String { + let retained = selection.retained_thread_ids.len(); + let added = selection.selected.len().saturating_sub(retained); + let selected = if selection.selected.is_empty() { + "- none".to_string() + } else { + selection + .selected + .iter() + .map(|item| { + render_selected_input_line( + item, + selection.retained_thread_ids.contains(&item.thread_id), + ) + }) + .collect::>() + .join("\n") + }; + let removed = if selection.removed.is_empty() { + "- none".to_string() + } else { + selection + .removed + .iter() + .map(render_removed_input_line) + .collect::>() + .join("\n") + }; + + format!( + "- selected inputs this run: {}\n- newly added since the last successful Phase 2 run: {added}\n- retained from the last successful Phase 2 run: {retained}\n- removed from the last successful Phase 2 run: {}\n\nCurrent selected Phase 1 inputs:\n{selected}\n\nRemoved from the last successful Phase 2 selection:\n{removed}\n", + selection.selected.len(), + selection.removed.len(), + ) +} + +fn render_selected_input_line(item: &Stage1Output, retained: bool) -> String { + let status = if retained { "retained" } else { "added" }; + let 
rollout_summary_file = format!( + "rollout_summaries/{}.md", + rollout_summary_file_stem_from_parts( + item.thread_id, + item.source_updated_at, + item.rollout_slug.as_deref(), + ) + ); + format!( + "- [{status}] thread_id={}, rollout_summary_file={rollout_summary_file}", + item.thread_id + ) +} + +fn render_removed_input_line(item: &Stage1OutputRef) -> String { + let rollout_summary_file = format!( + "rollout_summaries/{}.md", + rollout_summary_file_stem_from_parts( + item.thread_id, + item.source_updated_at, + item.rollout_slug.as_deref(), + ) + ); + format!( + "- thread_id={}, rollout_summary_file={rollout_summary_file}", + item.thread_id + ) +} + /// Builds the stage-1 user message containing rollout metadata and content. /// /// Large rollout payloads are truncated to 70% of the active model's effective diff --git a/codex-rs/core/templates/memories/consolidation.md b/codex-rs/core/templates/memories/consolidation.md index 000fddedf..085895a69 100644 --- a/codex-rs/core/templates/memories/consolidation.md +++ b/codex-rs/core/templates/memories/consolidation.md @@ -121,6 +121,27 @@ Mode selection: - INCREMENTAL UPDATE: existing artifacts already exist and `raw_memories.md` mostly contains new additions. +Incremental thread diff snapshot (computed before the current artifact sync rewrites local files): + +**Diff since last consolidation:** +{{ phase2_input_selection }} + +Incremental update and forgetting mechanism: +- Use the diff provided +- Do not open raw sessions / original rollout transcripts. +- For each added thread id, search it in `raw_memories.md`, read that raw-memory section, and + read the corresponding `rollout_summaries/*.md` file only when needed for stronger evidence, + task placement, or conflict resolution. +- For each removed thread id, search it in `MEMORY.md` and delete only the memory supported by + that thread. 
Use `thread_id=` in `### rollout_summary_files` when available; if not, + fall back to rollout summary filenames plus the corresponding `rollout_summaries/*.md` files. +- If a `MEMORY.md` block contains both removed and undeleted threads, do not delete the whole + block. Remove only the removed thread's references and thread-local learnings, preserve shared + or still-supported content, and split or rewrite the block only if needed to keep the undeleted + threads intact. +- After `MEMORY.md` cleanup is done, revisit `memory_summary.md` and remove or rewrite stale + summary/index content that was only supported by removed thread ids. + Outputs: Under `{{ memory_root }}/`: A) `MEMORY.md` @@ -498,27 +519,42 @@ WORKFLOW conflicting task families until MEMORY blocks are richer and more useful than raw memories 3) INCREMENTAL UPDATE behavior: - - Treat `raw_memories.md` as the primary source of NEW signal. - - Read existing memory files first for continuity. + - Read existing `MEMORY.md` and `memory_summary.md` first for continuity and to locate + existing references that may need surgical cleanup. + - Use the injected thread-diff snapshot as the first routing pass: + - added thread ids = ingestion queue + - removed thread ids = forgetting / stale-cleanup queue - Build an index of rollout references already present in existing `MEMORY.md` before scanning raw memories so you can route net-new evidence into the right blocks. - - Compute net-new candidates from the raw-memory inventory (threads / rollout summaries / - updated evidence not already represented in `MEMORY.md`). + - Work in this order: + 1. For newly added thread ids, search them in `raw_memories.md`, read those sections, and + open the corresponding `rollout_summaries/*.md` files when necessary. + 2. Route the new signal into existing `MEMORY.md` blocks or create new ones when needed. + 3. For removed thread ids, search `MEMORY.md` and surgically delete or rewrite only the + unsupported thread-local memory. 
+ 4. If a block mixes removed and undeleted threads, preserve the undeleted-thread content; + split or rewrite the block if that is the cleanest way to delete only the removed part. + 5. After `MEMORY.md` is correct, revisit `memory_summary.md` and remove or rewrite stale + summary/index content that no longer has undeleted support. - Integrate new signal into existing artifacts by: - - scanning new raw memories in recency order and identifying which existing blocks they should update + - scanning the newly added raw-memory entries in recency order and identifying which existing blocks they should update - updating existing knowledge with better/newer evidence - updating stale or contradicting guidance + - pruning or downgrading memory whose only provenance comes from removed thread ids - expanding terse old blocks when new summaries/raw memories make the task family clearer - doing light clustering and merging if needed - refreshing `MEMORY.md` top-of-file ordering so recent high-utility task families stay easy to find - rebuilding the `memory_summary.md` recent active window (last 3 memory days) from current `updated_at` coverage - updating existing skills or adding new skills only when there is clear new reusable procedure - - update `memory_summary.md` last to reflect the final state of the memory folder + - updating `memory_summary.md` last to reflect the final state of the memory folder - Minimize churn in incremental mode: if an existing `MEMORY.md` block or `## What's in Memory` topic still reflects the current evidence and points to the same task family / retrieval target, keep its wording, label, and relative order mostly stable. Rewrite/reorder/rename/ split/merge only when fixing a real problem (staleness, ambiguity, schema drift, wrong boundaries) or when meaningful new evidence materially improves retrieval clarity/searchability. + - Spend most of your deep-dive budget on newly added thread ids and on mixed blocks touched by + removed thread ids. 
Do not re-read unchanged older threads unless you need them for + conflict resolution, clustering, or provenance repair. 4) Evidence deep-dive rule (both modes): - `raw_memories.md` is the routing layer, not always the final authority for detail. @@ -529,6 +565,9 @@ WORKFLOW - When a task family is important, ambiguous, or duplicated across multiple rollouts, open the relevant `rollout_summaries/*.md` files and extract richer procedural detail, validation signals, and user feedback before finalizing `MEMORY.md`. + - When deleting stale memory from a mixed block, use the relevant rollout summaries to decide + which details are uniquely supported by removed threads versus still supported by undeleted + threads. - Use `updated_at` and validation strength together to resolve stale/conflicting notes. 5) For both modes, update `MEMORY.md` after skill updates: @@ -542,6 +581,8 @@ WORKFLOW 7) Final pass: - remove duplication in memory_summary, skills/, and MEMORY.md - remove stale or low-signal blocks that are less likely to be useful in the future + - remove or rewrite blocks/task sections whose supporting rollout references point only to + removed thread ids or missing rollout summary files - run a global rollout-reference audit on final `MEMORY.md` and fix accidental duplicate entries / redundant repetition, while preserving intentional multi-task or multi-block reuse when it adds distinct task-local value @@ -560,16 +601,3 @@ WORKFLOW You should dive deep and make sure you didn't miss any important information that might be useful for future agents; do not be superficial. 
- -============================================================ -SEARCH / REVIEW COMMANDS (RG-FIRST) -============================================================ - -Use `rg` for fast retrieval while consolidating: - -- Search durable notes: - `rg -n -i "" "{{ memory_root }}/MEMORY.md"` -- Search across memory tree: - `rg -n -i "" "{{ memory_root }}" | head -n 100` -- Locate rollout summary files: - `rg --files "{{ memory_root }}/rollout_summaries" | head -n 400` diff --git a/codex-rs/core/tests/suite/memories.rs b/codex-rs/core/tests/suite/memories.rs new file mode 100644 index 000000000..fc46c9df1 --- /dev/null +++ b/codex-rs/core/tests/suite/memories.rs @@ -0,0 +1,282 @@ +use anyhow::Result; +use chrono::Duration as ChronoDuration; +use chrono::Utc; +use codex_core::features::Feature; +use codex_protocol::ThreadId; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::Op; +use codex_protocol::protocol::SessionSource; +use core_test_support::responses::ResponseMock; +use core_test_support::responses::ResponsesRequest; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_sse_once; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use pretty_assertions::assert_eq; +use std::path::Path; +use std::sync::Arc; +use tempfile::TempDir; +use tokio::time::Duration; +use tokio::time::Instant; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() -> Result<()> { + let server = start_mock_server().await; + let home = Arc::new(TempDir::new()?); + let db = init_state_db(&home).await?; + + let now = Utc::now(); + let thread_a = seed_stage1_output( + 
db.as_ref(), + home.path(), + now - ChronoDuration::hours(2), + "raw memory A", + "rollout summary A", + "rollout-a", + ) + .await?; + + let first_phase2 = mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-phase2-1"), + ev_assistant_message("msg-phase2-1", "phase2 complete"), + ev_completed("resp-phase2-1"), + ]), + ) + .await; + + let first = build_test_codex(&server, home.clone()).await?; + let first_request = wait_for_single_request(&first_phase2).await; + let first_prompt = phase2_prompt_text(&first_request); + assert!( + first_prompt.contains("- selected inputs this run: 1"), + "expected selected count in first prompt: {first_prompt}" + ); + assert!( + first_prompt.contains("- newly added since the last successful Phase 2 run: 1"), + "expected added count in first prompt: {first_prompt}" + ); + assert!( + first_prompt.contains("- removed from the last successful Phase 2 run: 0"), + "expected removed count in first prompt: {first_prompt}" + ); + assert!( + first_prompt.contains(&format!("- [added] thread_id={thread_a},")), + "expected thread A to be marked added: {first_prompt}" + ); + assert!( + first_prompt.contains("Removed from the last successful Phase 2 selection:\n- none"), + "expected no removed items in first prompt: {first_prompt}" + ); + + wait_for_phase2_success(db.as_ref(), thread_a).await?; + let memory_root = home.path().join("memories"); + let raw_memories = tokio::fs::read_to_string(memory_root.join("raw_memories.md")).await?; + assert!(raw_memories.contains("raw memory A")); + assert!(!raw_memories.contains("raw memory B")); + let rollout_summaries = read_rollout_summary_bodies(&memory_root).await?; + assert_eq!(rollout_summaries.len(), 1); + assert!(rollout_summaries[0].contains("rollout summary A")); + + shutdown_test_codex(&first).await?; + + let thread_b = seed_stage1_output( + db.as_ref(), + home.path(), + now - ChronoDuration::hours(1), + "raw memory B", + "rollout summary B", + "rollout-b", + ) + .await?; + + let 
second_phase2 = mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-phase2-2"), + ev_assistant_message("msg-phase2-2", "phase2 complete"), + ev_completed("resp-phase2-2"), + ]), + ) + .await; + + let second = build_test_codex(&server, home.clone()).await?; + let second_request = wait_for_single_request(&second_phase2).await; + let second_prompt = phase2_prompt_text(&second_request); + assert!( + second_prompt.contains("- selected inputs this run: 1"), + "expected selected count in second prompt: {second_prompt}" + ); + assert!( + second_prompt.contains("- newly added since the last successful Phase 2 run: 1"), + "expected added count in second prompt: {second_prompt}" + ); + assert!( + second_prompt.contains("- removed from the last successful Phase 2 run: 1"), + "expected removed count in second prompt: {second_prompt}" + ); + assert!( + second_prompt.contains(&format!("- [added] thread_id={thread_b},")), + "expected thread B to be marked added: {second_prompt}" + ); + assert!( + second_prompt.contains(&format!("- thread_id={thread_a},")), + "expected thread A to be marked removed: {second_prompt}" + ); + + wait_for_phase2_success(db.as_ref(), thread_b).await?; + let raw_memories = tokio::fs::read_to_string(memory_root.join("raw_memories.md")).await?; + assert!(raw_memories.contains("raw memory B")); + assert!(raw_memories.contains("raw memory A")); + let rollout_summaries = read_rollout_summary_bodies(&memory_root).await?; + assert_eq!(rollout_summaries.len(), 2); + assert!( + rollout_summaries + .iter() + .any(|summary| summary.contains("rollout summary B")) + ); + assert!( + rollout_summaries + .iter() + .any(|summary| summary.contains("rollout summary A")) + ); + + shutdown_test_codex(&second).await?; + Ok(()) +} + +async fn build_test_codex(server: &wiremock::MockServer, home: Arc) -> Result { + let mut builder = test_codex().with_home(home).with_config(|config| { + config.features.enable(Feature::Sqlite); + 
config.features.enable(Feature::MemoryTool); + config.memories.max_raw_memories_for_global = 1; + }); + builder.build(server).await +} + +async fn init_state_db(home: &Arc) -> Result> { + let db = + codex_state::StateRuntime::init(home.path().to_path_buf(), "test-provider".into(), None) + .await?; + db.mark_backfill_complete(None).await?; + Ok(db) +} + +async fn seed_stage1_output( + db: &codex_state::StateRuntime, + codex_home: &Path, + updated_at: chrono::DateTime, + raw_memory: &str, + rollout_summary: &str, + rollout_slug: &str, +) -> Result { + let thread_id = ThreadId::new(); + let mut metadata_builder = codex_state::ThreadMetadataBuilder::new( + thread_id, + codex_home.join(format!("rollout-{thread_id}.jsonl")), + updated_at, + SessionSource::Cli, + ); + metadata_builder.cwd = codex_home.join(format!("workspace-{rollout_slug}")); + metadata_builder.model_provider = Some("test-provider".to_string()); + let metadata = metadata_builder.build("test-provider"); + db.upsert_thread(&metadata).await?; + + let claim = db + .try_claim_stage1_job( + thread_id, + ThreadId::new(), + updated_at.timestamp(), + 3_600, + 64, + ) + .await?; + let ownership_token = match claim { + codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage-1 claim outcome: {other:?}"), + }; + + assert!( + db.mark_stage1_job_succeeded( + thread_id, + &ownership_token, + updated_at.timestamp(), + raw_memory, + rollout_summary, + Some(rollout_slug), + ) + .await?, + "stage-1 success should enqueue global consolidation" + ); + + Ok(thread_id) +} + +async fn wait_for_single_request(mock: &ResponseMock) -> ResponsesRequest { + let deadline = Instant::now() + Duration::from_secs(10); + loop { + let requests = mock.requests(); + if let Some(request) = requests.into_iter().next() { + return request; + } + assert!( + Instant::now() < deadline, + "timed out waiting for phase2 request" + ); + tokio::time::sleep(Duration::from_millis(50)).await; + 
} +} + +#[allow(clippy::expect_used)] +fn phase2_prompt_text(request: &ResponsesRequest) -> String { + request + .message_input_texts("user") + .into_iter() + .find(|text| text.contains("Current selected Phase 1 inputs:")) + .expect("phase2 prompt text") +} + +async fn wait_for_phase2_success( + db: &codex_state::StateRuntime, + expected_thread_id: ThreadId, +) -> Result<()> { + let deadline = Instant::now() + Duration::from_secs(10); + loop { + let selection = db.get_phase2_input_selection(1).await?; + if selection.selected.len() == 1 + && selection.selected[0].thread_id == expected_thread_id + && selection.retained_thread_ids == vec![expected_thread_id] + && selection.removed.is_empty() + { + return Ok(()); + } + + assert!( + Instant::now() < deadline, + "timed out waiting for phase2 success for {expected_thread_id}" + ); + tokio::time::sleep(Duration::from_millis(50)).await; + } +} + +async fn read_rollout_summary_bodies(memory_root: &Path) -> Result> { + let mut dir = tokio::fs::read_dir(memory_root.join("rollout_summaries")).await?; + let mut summaries = Vec::new(); + while let Some(entry) = dir.next_entry().await? 
{ + summaries.push(tokio::fs::read_to_string(entry.path()).await?); + } + summaries.sort(); + Ok(summaries) +} + +async fn shutdown_test_codex(test: &TestCodex) -> Result<()> { + test.codex.submit(Op::Shutdown {}).await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await; + Ok(()) +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 1509428f3..fd20fc6f8 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -83,6 +83,7 @@ mod json_result; mod list_dir; mod live_cli; mod live_reload; +mod memories; mod model_info_overrides; mod model_overrides; mod model_switching; diff --git a/codex-rs/state/migrations/0018_phase2_selection_snapshot.sql b/codex-rs/state/migrations/0018_phase2_selection_snapshot.sql new file mode 100644 index 000000000..f175980fc --- /dev/null +++ b/codex-rs/state/migrations/0018_phase2_selection_snapshot.sql @@ -0,0 +1,3 @@ +ALTER TABLE stage1_outputs +ADD COLUMN selected_for_phase2_source_updated_at INTEGER; +ALTER TABLE threads ADD COLUMN memory_mode TEXT NOT NULL DEFAULT 'enabled'; diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index 59607f3d9..a6d1deeb8 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -14,6 +14,7 @@ mod runtime; pub use model::LogEntry; pub use model::LogQuery; pub use model::LogRow; +pub use model::Phase2InputSelection; pub use model::Phase2JobClaimOutcome; /// Preferred entrypoint: owns configuration and metrics. 
pub use runtime::StateRuntime; @@ -38,6 +39,7 @@ pub use model::SortKey; pub use model::Stage1JobClaim; pub use model::Stage1JobClaimOutcome; pub use model::Stage1Output; +pub use model::Stage1OutputRef; pub use model::Stage1StartupClaimParams; pub use model::ThreadMetadata; pub use model::ThreadMetadataBuilder; diff --git a/codex-rs/state/src/model/memories.rs b/codex-rs/state/src/model/memories.rs index fc2468d83..6c88d7360 100644 --- a/codex-rs/state/src/model/memories.rs +++ b/codex-rs/state/src/model/memories.rs @@ -21,6 +21,21 @@ pub struct Stage1Output { pub generated_at: DateTime, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Stage1OutputRef { + pub thread_id: ThreadId, + pub source_updated_at: DateTime, + pub rollout_slug: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Phase2InputSelection { + pub selected: Vec, + pub previous_selected: Vec, + pub retained_thread_ids: Vec, + pub removed: Vec, +} + #[derive(Debug)] pub(crate) struct Stage1OutputRow { thread_id: String, @@ -70,6 +85,18 @@ fn epoch_seconds_to_datetime(secs: i64) -> Result> { .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) } +pub(crate) fn stage1_output_ref_from_parts( + thread_id: String, + source_updated_at: i64, + rollout_slug: Option, +) -> Result { + Ok(Stage1OutputRef { + thread_id: ThreadId::try_from(thread_id)?, + source_updated_at: epoch_seconds_to_datetime(source_updated_at)?, + rollout_slug, + }) +} + /// Result of trying to claim a stage-1 memory extraction job. 
#[derive(Debug, Clone, PartialEq, Eq)] pub enum Stage1JobClaimOutcome { diff --git a/codex-rs/state/src/model/mod.rs b/codex-rs/state/src/model/mod.rs index 816c036f8..efaf3f787 100644 --- a/codex-rs/state/src/model/mod.rs +++ b/codex-rs/state/src/model/mod.rs @@ -16,10 +16,12 @@ pub use backfill_state::BackfillStatus; pub use log::LogEntry; pub use log::LogQuery; pub use log::LogRow; +pub use memories::Phase2InputSelection; pub use memories::Phase2JobClaimOutcome; pub use memories::Stage1JobClaim; pub use memories::Stage1JobClaimOutcome; pub use memories::Stage1Output; +pub use memories::Stage1OutputRef; pub use memories::Stage1StartupClaimParams; pub use thread_metadata::Anchor; pub use thread_metadata::BackfillStats; @@ -32,6 +34,7 @@ pub use thread_metadata::ThreadsPage; pub(crate) use agent_job::AgentJobItemRow; pub(crate) use agent_job::AgentJobRow; pub(crate) use memories::Stage1OutputRow; +pub(crate) use memories::stage1_output_ref_from_parts; pub(crate) use thread_metadata::ThreadRow; pub(crate) use thread_metadata::anchor_from_item; pub(crate) use thread_metadata::datetime_to_epoch_seconds; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index f0c623050..b7ee2ac88 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -2773,7 +2773,11 @@ WHERE kind = 'memory_stage1' assert_eq!(phase2_input_watermark, 100); assert!( runtime - .mark_global_phase2_job_succeeded(phase2_token.as_str(), phase2_input_watermark) + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + phase2_input_watermark, + &[], + ) .await .expect("mark initial phase2 succeeded"), "initial phase2 success should clear global dirty state" @@ -2819,7 +2823,11 @@ WHERE kind = 'memory_stage1' assert_eq!(phase2_input_watermark, 101); assert!( runtime - .mark_global_phase2_job_succeeded(phase2_token.as_str(), phase2_input_watermark) + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + phase2_input_watermark, + &[], + ) .await 
.expect("mark phase2 succeeded after no-output delete") ); @@ -2936,7 +2944,7 @@ WHERE kind = 'memory_stage1' }; assert!( runtime - .mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark) + .mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark, &[],) .await .expect("mark phase2 succeeded"), "phase2 success should finalize for current token" @@ -3124,6 +3132,646 @@ VALUES (?, ?, ?, ?, ?) let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + for (thread_id, workspace) in [ + (thread_id_a, "workspace-a"), + (thread_id_b, "workspace-b"), + (thread_id_c, "workspace-c"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + } + + for (thread_id, updated_at, slug) in [ + (thread_id_a, 100, Some("rollout-a")), + (thread_id_b, 101, Some("rollout-b")), + (thread_id_c, 102, Some("rollout-c")), + ] { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) + .await + .expect("claim stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + updated_at, + 
&format!("raw-{updated_at}"), + &format!("summary-{updated_at}"), + slug, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + } + + let claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (ownership_token, input_watermark) = match claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + assert_eq!(input_watermark, 102); + let selected_outputs = runtime + .list_stage1_outputs_for_global(10) + .await + .expect("list stage1 outputs for global") + .into_iter() + .filter(|output| output.thread_id == thread_id_c || output.thread_id == thread_id_a) + .collect::>(); + assert!( + runtime + .mark_global_phase2_job_succeeded( + ownership_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success with selection"), + "phase2 success should persist selected rows" + ); + + let selection = runtime + .get_phase2_input_selection(2) + .await + .expect("load phase2 input selection"); + + assert_eq!(selection.selected.len(), 2); + assert_eq!(selection.previous_selected.len(), 2); + assert_eq!(selection.selected[0].thread_id, thread_id_c); + assert_eq!( + selection.selected[0].rollout_path, + codex_home.join(format!("rollout-{thread_id_c}.jsonl")) + ); + assert_eq!(selection.selected[1].thread_id, thread_id_b); + assert_eq!(selection.retained_thread_ids, vec![thread_id_c]); + + assert_eq!(selection.removed.len(), 1); + assert_eq!(selection.removed[0].thread_id, thread_id_a); + assert_eq!( + selection.removed[0].rollout_slug.as_deref(), + Some("rollout-a") + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), 
"test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let first_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim initial stage1"); + let first_token = match first_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + first_token.as_str(), + 100, + "raw-100", + "summary-100", + Some("rollout-100"), + ) + .await + .expect("mark initial stage1 success"), + "initial stage1 success should persist output" + ); + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list selected outputs"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should persist selected rows" + ); + + let refreshed_claim = runtime + .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) + .await + .expect("claim refreshed stage1"); + let refreshed_token = match refreshed_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + 
}; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + refreshed_token.as_str(), + 101, + "raw-101", + "summary-101", + Some("rollout-101"), + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + + let selection = runtime + .get_phase2_input_selection(1) + .await + .expect("load phase2 input selection"); + assert_eq!(selection.selected.len(), 1); + assert_eq!(selection.previous_selected.len(), 1); + assert_eq!(selection.selected[0].thread_id, thread_id); + assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101); + assert!(selection.retained_thread_ids.is_empty()); + assert!(selection.removed.is_empty()); + + let (selected_for_phase2, selected_for_phase2_source_updated_at) = + sqlx::query_as::<_, (i64, Option)>( + "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", + ) + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load selected_for_phase2"); + assert_eq!(selected_for_phase2, 1); + assert_eq!(selected_for_phase2_source_updated_at, Some(100)); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_phase2_input_selection_reports_regenerated_previous_selection_as_removed() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread a"); + let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread b"); + let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread c"); + let thread_id_d = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread d"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + for (thread_id, workspace) in [ + (thread_id_a, 
"workspace-a"), + (thread_id_b, "workspace-b"), + (thread_id_c, "workspace-c"), + (thread_id_d, "workspace-d"), + ] { + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join(workspace), + )) + .await + .expect("upsert thread"); + } + + for (thread_id, updated_at, slug) in [ + (thread_id_a, 100, Some("rollout-a-100")), + (thread_id_b, 101, Some("rollout-b-101")), + (thread_id_c, 99, Some("rollout-c-99")), + (thread_id_d, 98, Some("rollout-d-98")), + ] { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) + .await + .expect("claim initial stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + updated_at, + &format!("raw-{updated_at}"), + &format!("summary-{updated_at}"), + slug, + ) + .await + .expect("mark stage1 succeeded"), + "stage1 success should persist output" + ); + } + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(2) + .await + .expect("list selected outputs"); + assert_eq!( + selected_outputs + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_id_b, thread_id_a] + ); + assert!( + runtime + .mark_global_phase2_job_succeeded( + phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should persist selected rows" + ); + + for (thread_id, updated_at, slug) in [ + (thread_id_a, 102, 
Some("rollout-a-102")), + (thread_id_c, 103, Some("rollout-c-103")), + (thread_id_d, 104, Some("rollout-d-104")), + ] { + let claim = runtime + .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64) + .await + .expect("claim refreshed stage1"); + let ownership_token = match claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + ownership_token.as_str(), + updated_at, + &format!("raw-{updated_at}"), + &format!("summary-{updated_at}"), + slug, + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + } + + let selection = runtime + .get_phase2_input_selection(2) + .await + .expect("load phase2 input selection"); + assert_eq!( + selection + .selected + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_id_d, thread_id_c] + ); + assert_eq!( + selection + .previous_selected + .iter() + .map(|output| output.thread_id) + .collect::>(), + vec![thread_id_a, thread_id_b] + ); + assert!(selection.retained_thread_ids.is_empty()); + assert_eq!( + selection + .removed + .iter() + .map(|output| (output.thread_id, output.source_updated_at.timestamp())) + .collect::>(), + vec![(thread_id_a, 102), (thread_id_b, 101)] + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_global_phase2_job_succeeded_updates_selected_snapshot_timestamp() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + 
.await + .expect("upsert thread"); + + let initial_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim initial stage1"); + let initial_token = match initial_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + initial_token.as_str(), + 100, + "raw-100", + "summary-100", + Some("rollout-100"), + ) + .await + .expect("mark initial stage1 success"), + "initial stage1 success should persist output" + ); + + let first_phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim first phase2"); + let (first_phase2_token, first_input_watermark) = match first_phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected first phase2 claim outcome: {other:?}"), + }; + let first_selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list first selected outputs"); + assert!( + runtime + .mark_global_phase2_job_succeeded( + first_phase2_token.as_str(), + first_input_watermark, + &first_selected_outputs, + ) + .await + .expect("mark first phase2 success"), + "first phase2 success should persist selected rows" + ); + + let refreshed_claim = runtime + .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) + .await + .expect("claim refreshed stage1"); + let refreshed_token = match refreshed_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected refreshed stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + refreshed_token.as_str(), + 101, + "raw-101", + "summary-101", + Some("rollout-101"), + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + + let 
second_phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim second phase2"); + let (second_phase2_token, second_input_watermark) = match second_phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected second phase2 claim outcome: {other:?}"), + }; + let second_selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list second selected outputs"); + assert_eq!( + second_selected_outputs[0].source_updated_at.timestamp(), + 101 + ); + assert!( + runtime + .mark_global_phase2_job_succeeded( + second_phase2_token.as_str(), + second_input_watermark, + &second_selected_outputs, + ) + .await + .expect("mark second phase2 success"), + "second phase2 success should persist selected rows" + ); + + let selection = runtime + .get_phase2_input_selection(1) + .await + .expect("load phase2 input selection after refresh"); + assert_eq!(selection.retained_thread_ids, vec![thread_id]); + + let (selected_for_phase2, selected_for_phase2_source_updated_at) = + sqlx::query_as::<_, (i64, Option)>( + "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", + ) + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load selected snapshot after phase2"); + assert_eq!(selected_for_phase2, 1); + assert_eq!(selected_for_phase2_source_updated_at, Some(101)); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn mark_global_phase2_job_succeeded_only_marks_exact_selected_snapshots() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let owner = 
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_id, + codex_home.join("workspace"), + )) + .await + .expect("upsert thread"); + + let initial_claim = runtime + .try_claim_stage1_job(thread_id, owner, 100, 3600, 64) + .await + .expect("claim initial stage1"); + let initial_token = match initial_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + initial_token.as_str(), + 100, + "raw-100", + "summary-100", + Some("rollout-100"), + ) + .await + .expect("mark initial stage1 success"), + "initial stage1 success should persist output" + ); + + let phase2_claim = runtime + .try_claim_global_phase2_job(owner, 3600) + .await + .expect("claim phase2"); + let (phase2_token, input_watermark) = match phase2_claim { + Phase2JobClaimOutcome::Claimed { + ownership_token, + input_watermark, + } => (ownership_token, input_watermark), + other => panic!("unexpected phase2 claim outcome: {other:?}"), + }; + let selected_outputs = runtime + .list_stage1_outputs_for_global(1) + .await + .expect("list selected outputs"); + assert_eq!(selected_outputs[0].source_updated_at.timestamp(), 100); + + let refreshed_claim = runtime + .try_claim_stage1_job(thread_id, owner, 101, 3600, 64) + .await + .expect("claim refreshed stage1"); + let refreshed_token = match refreshed_claim { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded( + thread_id, + refreshed_token.as_str(), + 101, + "raw-101", + "summary-101", + Some("rollout-101"), + ) + .await + .expect("mark refreshed stage1 success"), + "refreshed stage1 success should persist output" + ); + + assert!( + runtime + .mark_global_phase2_job_succeeded( + 
phase2_token.as_str(), + input_watermark, + &selected_outputs, + ) + .await + .expect("mark phase2 success"), + "phase2 success should still complete" + ); + + let (selected_for_phase2, selected_for_phase2_source_updated_at) = + sqlx::query_as::<_, (i64, Option)>( + "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?", + ) + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load selected_for_phase2"); + assert_eq!(selected_for_phase2, 0); + assert_eq!(selected_for_phase2_source_updated_at, None); + + let selection = runtime + .get_phase2_input_selection(1) + .await + .expect("load phase2 input selection"); + assert_eq!(selection.selected.len(), 1); + assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101); + assert!(selection.retained_thread_ids.is_empty()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn record_stage1_output_usage_updates_usage_metadata() { let codex_home = unique_temp_dir(); @@ -3395,7 +4043,7 @@ VALUES (?, ?, ?, ?, ?) assert_eq!( runtime - .mark_global_phase2_job_succeeded(token_a.as_str(), 300) + .mark_global_phase2_job_succeeded(token_a.as_str(), 300, &[]) .await .expect("mark stale owner success result"), false, @@ -3403,7 +4051,7 @@ VALUES (?, ?, ?, ?, ?) ); assert!( runtime - .mark_global_phase2_job_succeeded(token_b.as_str(), 300) + .mark_global_phase2_job_succeeded(token_b.as_str(), 300, &[]) .await .expect("mark takeover owner success"), "takeover owner should finalize consolidation" @@ -3440,7 +4088,7 @@ VALUES (?, ?, ?, ?, ?) 
}; assert!( runtime - .mark_global_phase2_job_succeeded(token_a.as_str(), 500) + .mark_global_phase2_job_succeeded(token_a.as_str(), 500, &[]) .await .expect("mark initial phase2 success"), "initial phase2 success should finalize" diff --git a/codex-rs/state/src/runtime/memories.rs b/codex-rs/state/src/runtime/memories.rs index c6c95aa1c..859ad0d14 100644 --- a/codex-rs/state/src/runtime/memories.rs +++ b/codex-rs/state/src/runtime/memories.rs @@ -1,4 +1,5 @@ use super::*; +use crate::model::Phase2InputSelection; use crate::model::Phase2JobClaimOutcome; use crate::model::Stage1JobClaim; use crate::model::Stage1JobClaimOutcome; @@ -6,10 +7,12 @@ use crate::model::Stage1Output; use crate::model::Stage1OutputRow; use crate::model::Stage1StartupClaimParams; use crate::model::ThreadRow; +use crate::model::stage1_output_ref_from_parts; use chrono::Duration; use sqlx::Executor; use sqlx::QueryBuilder; use sqlx::Sqlite; +use std::collections::HashSet; const JOB_KIND_MEMORY_STAGE1: &str = "memory_stage1"; const JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL: &str = "memory_consolidate_global"; @@ -257,6 +260,117 @@ LIMIT ? .collect::<Result<Vec<_>, _>>() } + /// Returns the current phase-2 input set along with its diff against the + /// last successful phase-2 selection. 
+ /// + /// Query behavior: + /// - current selection is the latest `n` non-empty stage-1 outputs ordered + /// by `source_updated_at DESC, thread_id DESC` + /// - previously selected rows are identified by `selected_for_phase2 = 1` + /// - `previous_selected` contains the current persisted rows that belonged + /// to the last successful phase-2 baseline + /// - `retained_thread_ids` records which current rows still match the exact + /// snapshot selected in the last successful phase-2 run + /// - removed rows are previously selected rows that are still present in + /// `stage1_outputs` but fall outside the current top-`n` selection + pub async fn get_phase2_input_selection( + &self, + n: usize, + ) -> anyhow::Result<Phase2InputSelection> { + if n == 0 { + return Ok(Phase2InputSelection::default()); + } + + let current_rows = sqlx::query( + r#" +SELECT + so.thread_id, + COALESCE(t.rollout_path, '') AS rollout_path, + so.source_updated_at, + so.raw_memory, + so.rollout_summary, + so.rollout_slug, + so.generated_at, + so.selected_for_phase2, + so.selected_for_phase2_source_updated_at, + COALESCE(t.cwd, '') AS cwd +FROM stage1_outputs AS so +LEFT JOIN threads AS t + ON t.id = so.thread_id +WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0 +ORDER BY so.source_updated_at DESC, so.thread_id DESC +LIMIT ? + "#, + ) + .bind(n as i64) + .fetch_all(self.pool.as_ref()) + .await?; + + let mut current_thread_ids = HashSet::with_capacity(current_rows.len()); + let mut selected = Vec::with_capacity(current_rows.len()); + let mut retained_thread_ids = Vec::new(); + for row in current_rows { + let thread_id = row.try_get::<String, _>("thread_id")?; + current_thread_ids.insert(thread_id.clone()); + let source_updated_at = row.try_get::<i64, _>("source_updated_at")?; + if row.try_get::<i64, _>("selected_for_phase2")? != 0 + && row.try_get::<Option<i64>, _>("selected_for_phase2_source_updated_at")? 
+ == Some(source_updated_at) + { + retained_thread_ids.push(ThreadId::try_from(thread_id.clone())?); + } + selected.push(Stage1Output::try_from(Stage1OutputRow::try_from_row( + &row, + )?)?); + } + + let previous_rows = sqlx::query( + r#" +SELECT + so.thread_id, + COALESCE(t.rollout_path, '') AS rollout_path, + so.source_updated_at, + so.raw_memory, + so.rollout_summary, + so.rollout_slug + , so.generated_at + , COALESCE(t.cwd, '') AS cwd +FROM stage1_outputs AS so +LEFT JOIN threads AS t + ON t.id = so.thread_id +WHERE so.selected_for_phase2 = 1 +ORDER BY so.source_updated_at DESC, so.thread_id DESC + "#, + ) + .fetch_all(self.pool.as_ref()) + .await?; + + let previous_selected = previous_rows + .iter() + .map(Stage1OutputRow::try_from_row) + .map(|row| row.and_then(Stage1Output::try_from)) + .collect::<Result<Vec<_>, _>>()?; + let mut removed = Vec::new(); + for row in previous_rows { + let thread_id = row.try_get::<String, _>("thread_id")?; + if current_thread_ids.contains(thread_id.as_str()) { + continue; + } + removed.push(stage1_output_ref_from_parts( + thread_id, + row.try_get("source_updated_at")?, + row.try_get("rollout_slug")?, + )?); + } + + Ok(Phase2InputSelection { + selected, + previous_selected, + retained_thread_ids, + removed, + }) + } + /// Attempts to claim a stage-1 job for a thread at `source_updated_at`. /// /// Claim semantics: @@ -454,6 +568,9 @@ WHERE kind = ? AND job_key = ? /// - sets `status='done'` and `last_success_watermark = input_watermark` /// - upserts `stage1_outputs` for the thread, replacing existing output only /// when `source_updated_at` is newer or equal + /// - preserves any existing `selected_for_phase2` baseline until the next + /// successful phase-2 run rewrites the baseline selection, including the + /// snapshot timestamp chosen during that run /// - persists optional `rollout_slug` for rollout summary artifact naming /// - enqueues/advances the global phase-2 job watermark using /// `source_updated_at` @@ -806,12 +923,18 @@ WHERE kind = ? 
AND job_key = ? /// - sets `status='done'`, clears lease/errors /// - advances `last_success_watermark` to /// `max(existing_last_success_watermark, completed_watermark)` + /// - rewrites `selected_for_phase2` so only the exact selected stage-1 + /// snapshots remain marked as part of the latest successful phase-2 + /// selection, and persists each selected snapshot's + /// `source_updated_at` for future retained-vs-added diffing pub async fn mark_global_phase2_job_succeeded( &self, ownership_token: &str, completed_watermark: i64, + selected_outputs: &[Stage1Output], ) -> anyhow::Result { let now = Utc::now().timestamp(); + let mut tx = self.pool.begin().await?; let rows_affected = sqlx::query( r#" UPDATE jobs @@ -830,11 +953,46 @@ WHERE kind = ? AND job_key = ? .bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL) .bind(MEMORY_CONSOLIDATION_JOB_KEY) .bind(ownership_token) - .execute(self.pool.as_ref()) + .execute(&mut *tx) .await? .rows_affected(); - Ok(rows_affected > 0) + if rows_affected == 0 { + tx.commit().await?; + return Ok(false); + } + + sqlx::query( + r#" +UPDATE stage1_outputs +SET + selected_for_phase2 = 0, + selected_for_phase2_source_updated_at = NULL +WHERE selected_for_phase2 != 0 OR selected_for_phase2_source_updated_at IS NOT NULL + "#, + ) + .execute(&mut *tx) + .await?; + + for output in selected_outputs { + sqlx::query( + r#" +UPDATE stage1_outputs +SET + selected_for_phase2 = 1, + selected_for_phase2_source_updated_at = ? +WHERE thread_id = ? AND source_updated_at = ? + "#, + ) + .bind(output.source_updated_at.timestamp()) + .bind(output.thread_id.to_string()) + .bind(output.source_updated_at.timestamp()) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + Ok(true) } /// Marks the owned running global phase-2 job as failed and schedules retry.