feat: memories forgetting (#12900)

Add diff based memory forgetting
This commit is contained in:
jif-oai 2026-02-26 13:19:57 +00:00 committed by GitHub
parent 81ce645733
commit 382fa338b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1335 additions and 39 deletions

View file

@ -70,11 +70,31 @@ What it does:
If there is input, it then:
- spawns an internal consolidation sub-agent
- builds the Phase 2 prompt with a diff of the current Phase 1 input
selection versus the last successful Phase 2 selection (`added`,
`retained`, `removed`)
- runs it with no approvals, no network, and local write access only
- disables collab for that agent (to prevent recursive delegation)
- watches the agent status and heartbeats the global job lease while it runs
- marks the phase-2 job success/failure in the state DB when the agent finishes
Selection diff behavior:
- successful Phase 2 runs mark the exact stage-1 snapshots they consumed with
`selected_for_phase2 = 1` and persist the matching
`selected_for_phase2_source_updated_at`
- Phase 1 upserts preserve the previous `selected_for_phase2` baseline until
the next successful Phase 2 run rewrites it
- the next Phase 2 run compares the current top-N stage-1 inputs against that
prior snapshot selection to label inputs as `added` or `retained`; a
refreshed thread stays `added` until Phase 2 successfully selects its newer
snapshot
- rows that were previously selected but still exist outside the current top-N
selection are surfaced as `removed`
- before the agent starts, local `rollout_summaries/` and `raw_memories.md`
keep the union of the current selection and the previous successful
selection, so removed-thread evidence stays available during forgetting
Watermark behavior:
- The global phase-2 job claim includes an input watermark representing the latest input timestamp known when the job was claimed.

View file

@ -8,6 +8,7 @@ use crate::memories::metrics;
use crate::memories::phase_two;
use crate::memories::prompts::build_consolidation_prompt;
use crate::memories::storage::rebuild_raw_memories_file_from_memories;
use crate::memories::storage::rollout_summary_file_stem;
use crate::memories::storage::sync_rollout_summaries_from_memories;
use codex_config::Constrained;
use codex_protocol::ThreadId;
@ -17,8 +18,10 @@ use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::user_input::UserInput;
use codex_state::Stage1Output;
use codex_state::StateRuntime;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
@ -73,21 +76,24 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
};
// 3. Query the memories
let raw_memories = match db.list_stage1_outputs_for_global(max_raw_memories).await {
Ok(memories) => memories,
let selection = match db.get_phase2_input_selection(max_raw_memories).await {
Ok(selection) => selection,
Err(err) => {
tracing::error!("failed to list stage1 outputs from global: {}", err);
job::failed(session, db, &claim, "failed_load_stage1_outputs").await;
return;
}
};
let raw_memories = selection.selected.to_vec();
let artifact_memories = artifact_memories_for_phase2(&selection);
let new_watermark = get_watermark(claim.watermark, &raw_memories);
// 4. Update the file system by syncing the raw memories with the one extracted from DB at
// step 3
// [`rollout_summaries/`]
if let Err(err) =
sync_rollout_summaries_from_memories(&root, &raw_memories, max_raw_memories).await
sync_rollout_summaries_from_memories(&root, &artifact_memories, artifact_memories.len())
.await
{
tracing::error!("failed syncing local memory artifacts for global consolidation: {err}");
job::failed(session, db, &claim, "failed_sync_artifacts").await;
@ -95,7 +101,8 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
}
// [`raw_memories.md`]
if let Err(err) =
rebuild_raw_memories_file_from_memories(&root, &raw_memories, max_raw_memories).await
rebuild_raw_memories_file_from_memories(&root, &artifact_memories, artifact_memories.len())
.await
{
tracing::error!("failed syncing local memory artifacts for global consolidation: {err}");
job::failed(session, db, &claim, "failed_rebuild_raw_memories").await;
@ -103,12 +110,20 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
}
if raw_memories.is_empty() {
// We check only after sync of the file system.
job::succeed(session, db, &claim, new_watermark, "succeeded_no_input").await;
job::succeed(
session,
db,
&claim,
new_watermark,
&[],
"succeeded_no_input",
)
.await;
return;
}
// 5. Spawn the agent
let prompt = agent::get_prompt(config);
let prompt = agent::get_prompt(config, &selection);
let source = SessionSource::SubAgent(SubAgentSource::MemoryConsolidation);
let thread_id = match session
.services
@ -129,6 +144,7 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
session,
claim,
new_watermark,
raw_memories.clone(),
thread_id,
phase_two_e2e_timer,
);
@ -140,6 +156,22 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
emit_metrics(session, counters);
}
fn artifact_memories_for_phase2(
selection: &codex_state::Phase2InputSelection,
) -> Vec<Stage1Output> {
let mut seen = HashSet::new();
let mut memories = selection.selected.clone();
for memory in &selection.selected {
seen.insert(rollout_summary_file_stem(memory));
}
for memory in &selection.previous_selected {
if seen.insert(rollout_summary_file_stem(memory)) {
memories.push(memory.clone());
}
}
memories
}
mod job {
use super::*;
@ -205,6 +237,7 @@ mod job {
db: &StateRuntime,
claim: &Claim,
completion_watermark: i64,
selected_outputs: &[codex_state::Stage1Output],
reason: &'static str,
) {
session.services.otel_manager.counter(
@ -213,7 +246,7 @@ mod job {
&[("status", reason)],
);
let _ = db
.mark_global_phase2_job_succeeded(&claim.token, completion_watermark)
.mark_global_phase2_job_succeeded(&claim.token, completion_watermark, selected_outputs)
.await;
}
}
@ -266,9 +299,12 @@ mod agent {
Some(agent_config)
}
pub(super) fn get_prompt(config: Arc<Config>) -> Vec<UserInput> {
pub(super) fn get_prompt(
config: Arc<Config>,
selection: &codex_state::Phase2InputSelection,
) -> Vec<UserInput> {
let root = memory_root(&config.codex_home);
let prompt = build_consolidation_prompt(&root);
let prompt = build_consolidation_prompt(&root, selection);
vec![UserInput::Text {
text: prompt,
text_elements: vec![],
@ -280,6 +316,7 @@ mod agent {
session: &Arc<Session>,
claim: Claim,
new_watermark: i64,
selected_outputs: Vec<codex_state::Stage1Output>,
thread_id: ThreadId,
phase_two_e2e_timer: Option<codex_otel::Timer>,
) {
@ -316,7 +353,15 @@ mod agent {
if let Some(token_usage) = agent_control.get_total_token_usage(thread_id).await {
emit_token_usage_metrics(&session, &token_usage);
}
job::succeed(&session, &db, &claim, new_watermark, "succeeded").await;
job::succeed(
&session,
&db,
&claim,
new_watermark,
&selected_outputs,
"succeeded",
)
.await;
} else {
job::failed(&session, &db, &claim, "failed_agent").await;
}

View file

@ -1,9 +1,13 @@
use crate::memories::memory_root;
use crate::memories::phase_one;
use crate::memories::storage::rollout_summary_file_stem_from_parts;
use crate::truncate::TruncationPolicy;
use crate::truncate::truncate_text;
use askama::Template;
use codex_protocol::openai_models::ModelInfo;
use codex_state::Phase2InputSelection;
use codex_state::Stage1Output;
use codex_state::Stage1OutputRef;
use std::path::Path;
use tokio::fs;
use tracing::warn;
@ -12,6 +16,7 @@ use tracing::warn;
#[template(path = "memories/consolidation.md", escape = "none")]
struct ConsolidationPromptTemplate<'a> {
memory_root: &'a str,
phase2_input_selection: &'a str,
}
#[derive(Template)]
@ -30,17 +35,91 @@ struct MemoryToolDeveloperInstructionsTemplate<'a> {
}
/// Builds the consolidation subagent prompt for a specific memory root.
pub(super) fn build_consolidation_prompt(memory_root: &Path) -> String {
pub(super) fn build_consolidation_prompt(
memory_root: &Path,
selection: &Phase2InputSelection,
) -> String {
let memory_root = memory_root.display().to_string();
let phase2_input_selection = render_phase2_input_selection(selection);
let template = ConsolidationPromptTemplate {
memory_root: &memory_root,
phase2_input_selection: &phase2_input_selection,
};
template.render().unwrap_or_else(|err| {
warn!("failed to render memories consolidation prompt template: {err}");
format!("## Memory Phase 2 (Consolidation)\nConsolidate Codex memories in: {memory_root}")
format!(
"## Memory Phase 2 (Consolidation)\nConsolidate Codex memories in: {memory_root}\n\n{phase2_input_selection}"
)
})
}
/// Renders the selection-diff summary (counts plus one line per input) that
/// is injected into the Phase 2 consolidation prompt.
fn render_phase2_input_selection(selection: &Phase2InputSelection) -> String {
    let retained = selection.retained_thread_ids.len();
    // Anything selected but not retained from the previous run is "added".
    let added = selection.selected.len().saturating_sub(retained);

    // Shared fallback: an empty section renders as a single "- none" line.
    let render_list = |lines: Vec<String>| -> String {
        if lines.is_empty() {
            "- none".to_string()
        } else {
            lines.join("\n")
        }
    };

    let selected_lines: Vec<String> = selection
        .selected
        .iter()
        .map(|item| {
            let is_retained = selection.retained_thread_ids.contains(&item.thread_id);
            render_selected_input_line(item, is_retained)
        })
        .collect();
    let removed_lines: Vec<String> = selection
        .removed
        .iter()
        .map(render_removed_input_line)
        .collect();

    let selected = render_list(selected_lines);
    let removed = render_list(removed_lines);
    format!(
        "- selected inputs this run: {}\n- newly added since the last successful Phase 2 run: {added}\n- retained from the last successful Phase 2 run: {retained}\n- removed from the last successful Phase 2 run: {}\n\nCurrent selected Phase 1 inputs:\n{selected}\n\nRemoved from the last successful Phase 2 selection:\n{removed}\n",
        selection.selected.len(),
        selection.removed.len(),
    )
}
/// Renders one prompt line for a currently selected Phase 1 input, labelled
/// `[retained]` or `[added]` relative to the last successful Phase 2 run.
fn render_selected_input_line(item: &Stage1Output, retained: bool) -> String {
    let status = match retained {
        true => "retained",
        false => "added",
    };
    let stem = rollout_summary_file_stem_from_parts(
        item.thread_id,
        item.source_updated_at,
        item.rollout_slug.as_deref(),
    );
    format!(
        "- [{status}] thread_id={}, rollout_summary_file=rollout_summaries/{stem}.md",
        item.thread_id
    )
}
/// Renders one prompt line for an input that was in the previous successful
/// Phase 2 selection but dropped out of the current one.
fn render_removed_input_line(item: &Stage1OutputRef) -> String {
    let stem = rollout_summary_file_stem_from_parts(
        item.thread_id,
        item.source_updated_at,
        item.rollout_slug.as_deref(),
    );
    format!(
        "- thread_id={}, rollout_summary_file=rollout_summaries/{stem}.md",
        item.thread_id
    )
}
/// Builds the stage-1 user message containing rollout metadata and content.
///
/// Large rollout payloads are truncated to 70% of the active model's effective

View file

@ -121,6 +121,27 @@ Mode selection:
- INCREMENTAL UPDATE: existing artifacts already exist and `raw_memories.md`
mostly contains new additions.
Incremental thread diff snapshot (computed before the current artifact sync rewrites local files):
**Diff since last consolidation:**
{{ phase2_input_selection }}
Incremental update and forgetting mechanism:
- Use the diff provided above as the authoritative list of added and removed thread ids.
- Do not open raw sessions / original rollout transcripts.
- For each added thread id, search it in `raw_memories.md`, read that raw-memory section, and
read the corresponding `rollout_summaries/*.md` file only when needed for stronger evidence,
task placement, or conflict resolution.
- For each removed thread id, search it in `MEMORY.md` and delete only the memory supported by
that thread. Use `thread_id=<thread_id>` in `### rollout_summary_files` when available; if not,
fall back to rollout summary filenames plus the corresponding `rollout_summaries/*.md` files.
- If a `MEMORY.md` block contains both removed and undeleted threads, do not delete the whole
block. Remove only the removed thread's references and thread-local learnings, preserve shared
or still-supported content, and split or rewrite the block only if needed to keep the undeleted
threads intact.
- After `MEMORY.md` cleanup is done, revisit `memory_summary.md` and remove or rewrite stale
summary/index content that was only supported by removed thread ids.
Outputs:
Under `{{ memory_root }}/`:
A) `MEMORY.md`
@ -498,27 +519,42 @@ WORKFLOW
conflicting task families until MEMORY blocks are richer and more useful than raw memories
3) INCREMENTAL UPDATE behavior:
- Treat `raw_memories.md` as the primary source of NEW signal.
- Read existing memory files first for continuity.
- Read existing `MEMORY.md` and `memory_summary.md` first for continuity and to locate
existing references that may need surgical cleanup.
- Use the injected thread-diff snapshot as the first routing pass:
- added thread ids = ingestion queue
- removed thread ids = forgetting / stale-cleanup queue
- Build an index of rollout references already present in existing `MEMORY.md` before
scanning raw memories so you can route net-new evidence into the right blocks.
- Compute net-new candidates from the raw-memory inventory (threads / rollout summaries /
updated evidence not already represented in `MEMORY.md`).
- Work in this order:
1. For newly added thread ids, search them in `raw_memories.md`, read those sections, and
open the corresponding `rollout_summaries/*.md` files when necessary.
2. Route the new signal into existing `MEMORY.md` blocks or create new ones when needed.
3. For removed thread ids, search `MEMORY.md` and surgically delete or rewrite only the
unsupported thread-local memory.
4. If a block mixes removed and undeleted threads, preserve the undeleted-thread content;
split or rewrite the block if that is the cleanest way to delete only the removed part.
5. After `MEMORY.md` is correct, revisit `memory_summary.md` and remove or rewrite stale
summary/index content that no longer has undeleted support.
- Integrate new signal into existing artifacts by:
- scanning new raw memories in recency order and identifying which existing blocks they should update
- scanning the newly added raw-memory entries in recency order and identifying which existing blocks they should update
- updating existing knowledge with better/newer evidence
- updating stale or contradicting guidance
- pruning or downgrading memory whose only provenance comes from removed thread ids
- expanding terse old blocks when new summaries/raw memories make the task family clearer
- doing light clustering and merging if needed
- refreshing `MEMORY.md` top-of-file ordering so recent high-utility task families stay easy to find
- rebuilding the `memory_summary.md` recent active window (last 3 memory days) from current `updated_at` coverage
- updating existing skills or adding new skills only when there is clear new reusable procedure
- update `memory_summary.md` last to reflect the final state of the memory folder
- updating `memory_summary.md` last to reflect the final state of the memory folder
- Minimize churn in incremental mode: if an existing `MEMORY.md` block or `## What's in Memory`
topic still reflects the current evidence and points to the same task family / retrieval
target, keep its wording, label, and relative order mostly stable. Rewrite/reorder/rename/
split/merge only when fixing a real problem (staleness, ambiguity, schema drift, wrong
boundaries) or when meaningful new evidence materially improves retrieval clarity/searchability.
- Spend most of your deep-dive budget on newly added thread ids and on mixed blocks touched by
removed thread ids. Do not re-read unchanged older threads unless you need them for
conflict resolution, clustering, or provenance repair.
4) Evidence deep-dive rule (both modes):
- `raw_memories.md` is the routing layer, not always the final authority for detail.
@ -529,6 +565,9 @@ WORKFLOW
- When a task family is important, ambiguous, or duplicated across multiple rollouts,
open the relevant `rollout_summaries/*.md` files and extract richer procedural detail,
validation signals, and user feedback before finalizing `MEMORY.md`.
- When deleting stale memory from a mixed block, use the relevant rollout summaries to decide
which details are uniquely supported by removed threads versus still supported by undeleted
threads.
- Use `updated_at` and validation strength together to resolve stale/conflicting notes.
5) For both modes, update `MEMORY.md` after skill updates:
@ -542,6 +581,8 @@ WORKFLOW
7) Final pass:
- remove duplication in memory_summary, skills/, and MEMORY.md
- remove stale or low-signal blocks that are less likely to be useful in the future
- remove or rewrite blocks/task sections whose supporting rollout references point only to
removed thread ids or missing rollout summary files
- run a global rollout-reference audit on final `MEMORY.md` and fix accidental duplicate
entries / redundant repetition, while preserving intentional multi-task or multi-block
reuse when it adds distinct task-local value
@ -560,16 +601,3 @@ WORKFLOW
You should dive deep and make sure you didn't miss any important information that might
be useful for future agents; do not be superficial.
============================================================
SEARCH / REVIEW COMMANDS (RG-FIRST)
============================================================
Use `rg` for fast retrieval while consolidating:
- Search durable notes:
`rg -n -i "<pattern>" "{{ memory_root }}/MEMORY.md"`
- Search across memory tree:
`rg -n -i "<pattern>" "{{ memory_root }}" | head -n 100`
- Locate rollout summary files:
`rg --files "{{ memory_root }}/rollout_summaries" | head -n 400`

View file

@ -0,0 +1,282 @@
use anyhow::Result;
use chrono::Duration as ChronoDuration;
use chrono::Utc;
use codex_core::features::Feature;
use codex_protocol::ThreadId;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::SessionSource;
use core_test_support::responses::ResponseMock;
use core_test_support::responses::ResponsesRequest;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use std::path::Path;
use std::sync::Arc;
use tempfile::TempDir;
use tokio::time::Duration;
use tokio::time::Instant;
/// End-to-end check of selection-diff tracking across two startup Phase 2
/// runs: the first run reports thread A as `added`; after thread B supersedes
/// it (the global cap is one input, see `build_test_codex`), the second run
/// reports B as `added` and A as `removed`, while A's local artifacts
/// (raw memory section and rollout summary) are still kept on disk.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn memories_startup_phase2_tracks_added_and_removed_inputs_across_runs() -> Result<()> {
    let server = start_mock_server().await;
    let home = Arc::new(TempDir::new()?);
    let db = init_state_db(&home).await?;
    let now = Utc::now();
    // Seed one completed stage-1 output (thread A) before the first startup.
    let thread_a = seed_stage1_output(
        db.as_ref(),
        home.path(),
        now - ChronoDuration::hours(2),
        "raw memory A",
        "rollout summary A",
        "rollout-a",
    )
    .await?;
    // Mock the single agent turn the first Phase 2 run is expected to make.
    let first_phase2 = mount_sse_once(
        &server,
        sse(vec![
            ev_response_created("resp-phase2-1"),
            ev_assistant_message("msg-phase2-1", "phase2 complete"),
            ev_completed("resp-phase2-1"),
        ]),
    )
    .await;
    let first = build_test_codex(&server, home.clone()).await?;
    let first_request = wait_for_single_request(&first_phase2).await;
    let first_prompt = phase2_prompt_text(&first_request);
    // First run: thread A is brand new, so it must be counted and labelled as added.
    assert!(
        first_prompt.contains("- selected inputs this run: 1"),
        "expected selected count in first prompt: {first_prompt}"
    );
    assert!(
        first_prompt.contains("- newly added since the last successful Phase 2 run: 1"),
        "expected added count in first prompt: {first_prompt}"
    );
    assert!(
        first_prompt.contains("- removed from the last successful Phase 2 run: 0"),
        "expected removed count in first prompt: {first_prompt}"
    );
    assert!(
        first_prompt.contains(&format!("- [added] thread_id={thread_a},")),
        "expected thread A to be marked added: {first_prompt}"
    );
    assert!(
        first_prompt.contains("Removed from the last successful Phase 2 selection:\n- none"),
        "expected no removed items in first prompt: {first_prompt}"
    );
    // Wait until the DB records A as the last successful Phase 2 selection.
    wait_for_phase2_success(db.as_ref(), thread_a).await?;
    let memory_root = home.path().join("memories");
    let raw_memories = tokio::fs::read_to_string(memory_root.join("raw_memories.md")).await?;
    assert!(raw_memories.contains("raw memory A"));
    assert!(!raw_memories.contains("raw memory B"));
    let rollout_summaries = read_rollout_summary_bodies(&memory_root).await?;
    assert_eq!(rollout_summaries.len(), 1);
    assert!(rollout_summaries[0].contains("rollout summary A"));
    shutdown_test_codex(&first).await?;
    // Seed a newer thread B; with the cap of one it displaces A from the top-N.
    let thread_b = seed_stage1_output(
        db.as_ref(),
        home.path(),
        now - ChronoDuration::hours(1),
        "raw memory B",
        "rollout summary B",
        "rollout-b",
    )
    .await?;
    let second_phase2 = mount_sse_once(
        &server,
        sse(vec![
            ev_response_created("resp-phase2-2"),
            ev_assistant_message("msg-phase2-2", "phase2 complete"),
            ev_completed("resp-phase2-2"),
        ]),
    )
    .await;
    let second = build_test_codex(&server, home.clone()).await?;
    let second_request = wait_for_single_request(&second_phase2).await;
    let second_prompt = phase2_prompt_text(&second_request);
    // Second run: B is added; A (previously selected, now outside top-N) is removed.
    assert!(
        second_prompt.contains("- selected inputs this run: 1"),
        "expected selected count in second prompt: {second_prompt}"
    );
    assert!(
        second_prompt.contains("- newly added since the last successful Phase 2 run: 1"),
        "expected added count in second prompt: {second_prompt}"
    );
    assert!(
        second_prompt.contains("- removed from the last successful Phase 2 run: 1"),
        "expected removed count in second prompt: {second_prompt}"
    );
    assert!(
        second_prompt.contains(&format!("- [added] thread_id={thread_b},")),
        "expected thread B to be marked added: {second_prompt}"
    );
    assert!(
        second_prompt.contains(&format!("- thread_id={thread_a},")),
        "expected thread A to be marked removed: {second_prompt}"
    );
    wait_for_phase2_success(db.as_ref(), thread_b).await?;
    // Removed-thread artifacts must survive: the files keep the union of the
    // current and previous selections (A's evidence stays alongside B's).
    let raw_memories = tokio::fs::read_to_string(memory_root.join("raw_memories.md")).await?;
    assert!(raw_memories.contains("raw memory B"));
    assert!(raw_memories.contains("raw memory A"));
    let rollout_summaries = read_rollout_summary_bodies(&memory_root).await?;
    assert_eq!(rollout_summaries.len(), 2);
    assert!(
        rollout_summaries
            .iter()
            .any(|summary| summary.contains("rollout summary B"))
    );
    assert!(
        rollout_summaries
            .iter()
            .any(|summary| summary.contains("rollout summary A"))
    );
    shutdown_test_codex(&second).await?;
    Ok(())
}
/// Builds a `TestCodex` against `server` with sqlite and the memory tool
/// enabled, and the global raw-memory cap pinned to one so selection diffs
/// are easy to force in tests.
async fn build_test_codex(server: &wiremock::MockServer, home: Arc<TempDir>) -> Result<TestCodex> {
    test_codex()
        .with_home(home)
        .with_config(|config| {
            config.features.enable(Feature::Sqlite);
            config.features.enable(Feature::MemoryTool);
            config.memories.max_raw_memories_for_global = 1;
        })
        .build(server)
        .await
}
/// Initializes the state DB under `home` and marks backfill complete so
/// memory jobs are eligible to run immediately.
async fn init_state_db(home: &Arc<TempDir>) -> Result<Arc<codex_state::StateRuntime>> {
    let home_path = home.path().to_path_buf();
    let runtime = codex_state::StateRuntime::init(home_path, "test-provider".into(), None).await?;
    runtime.mark_backfill_complete(None).await?;
    Ok(runtime)
}
/// Seeds one completed stage-1 output row for a fresh thread: upserts thread
/// metadata, claims the stage-1 job, and marks it succeeded with the given
/// raw-memory / rollout-summary payloads at `updated_at`.
///
/// Returns the new thread id. Panics if the stage-1 claim is not granted, and
/// asserts that marking success reports the global consolidation as enqueued.
async fn seed_stage1_output(
    db: &codex_state::StateRuntime,
    codex_home: &Path,
    updated_at: chrono::DateTime<Utc>,
    raw_memory: &str,
    rollout_summary: &str,
    rollout_slug: &str,
) -> Result<ThreadId> {
    let thread_id = ThreadId::new();
    let mut metadata_builder = codex_state::ThreadMetadataBuilder::new(
        thread_id,
        codex_home.join(format!("rollout-{thread_id}.jsonl")),
        updated_at,
        SessionSource::Cli,
    );
    // Give each seeded thread a distinct cwd keyed by its rollout slug.
    metadata_builder.cwd = codex_home.join(format!("workspace-{rollout_slug}"));
    metadata_builder.model_provider = Some("test-provider".to_string());
    let metadata = metadata_builder.build("test-provider");
    db.upsert_thread(&metadata).await?;
    // Claim the stage-1 job (fresh owner id, 1h lease) so it can be completed below.
    let claim = db
        .try_claim_stage1_job(
            thread_id,
            ThreadId::new(),
            updated_at.timestamp(),
            3_600,
            64,
        )
        .await?;
    let ownership_token = match claim {
        codex_state::Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected stage-1 claim outcome: {other:?}"),
    };
    assert!(
        db.mark_stage1_job_succeeded(
            thread_id,
            &ownership_token,
            updated_at.timestamp(),
            raw_memory,
            rollout_summary,
            Some(rollout_slug),
        )
        .await?,
        "stage-1 success should enqueue global consolidation"
    );
    Ok(thread_id)
}
/// Polls `mock` until it has recorded at least one request and returns the
/// first one; panics if none arrives within ten seconds.
async fn wait_for_single_request(mock: &ResponseMock) -> ResponsesRequest {
    let deadline = Instant::now() + Duration::from_secs(10);
    loop {
        match mock.requests().into_iter().next() {
            Some(request) => return request,
            None => {
                assert!(
                    Instant::now() < deadline,
                    "timed out waiting for phase2 request"
                );
                tokio::time::sleep(Duration::from_millis(50)).await;
            }
        }
    }
}
/// Extracts the Phase 2 consolidation prompt from a captured request by
/// finding the user message that carries the selection header.
#[allow(clippy::expect_used)]
fn phase2_prompt_text(request: &ResponsesRequest) -> String {
    let user_texts = request.message_input_texts("user");
    for text in user_texts {
        if text.contains("Current selected Phase 1 inputs:") {
            return text;
        }
    }
    panic!("phase2 prompt text");
}
/// Polls the state DB until the last successful Phase 2 selection is exactly
/// `expected_thread_id` (selected, retained, and nothing removed); panics
/// after ten seconds without reaching that state.
async fn wait_for_phase2_success(
    db: &codex_state::StateRuntime,
    expected_thread_id: ThreadId,
) -> Result<()> {
    let deadline = Instant::now() + Duration::from_secs(10);
    loop {
        let selection = db.get_phase2_input_selection(1).await?;
        let only_expected_selected = selection.selected.len() == 1
            && selection.selected[0].thread_id == expected_thread_id;
        let retained_matches = selection.retained_thread_ids == vec![expected_thread_id];
        if only_expected_selected && retained_matches && selection.removed.is_empty() {
            return Ok(());
        }
        assert!(
            Instant::now() < deadline,
            "timed out waiting for phase2 success for {expected_thread_id}"
        );
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
}
/// Reads every file under `rollout_summaries/` and returns the bodies in
/// sorted order (so assertions are independent of directory iteration order).
async fn read_rollout_summary_bodies(memory_root: &Path) -> Result<Vec<String>> {
    let summaries_dir = memory_root.join("rollout_summaries");
    let mut entries = tokio::fs::read_dir(summaries_dir).await?;
    let mut bodies = Vec::new();
    while let Some(entry) = entries.next_entry().await? {
        let body = tokio::fs::read_to_string(entry.path()).await?;
        bodies.push(body);
    }
    bodies.sort();
    Ok(bodies)
}
/// Submits a shutdown op and blocks until the shutdown-complete event fires.
async fn shutdown_test_codex(test: &TestCodex) -> Result<()> {
    let codex = &test.codex;
    codex.submit(Op::Shutdown {}).await?;
    wait_for_event(codex, |event| matches!(event, EventMsg::ShutdownComplete)).await;
    Ok(())
}

View file

@ -83,6 +83,7 @@ mod json_result;
mod list_dir;
mod live_cli;
mod live_reload;
mod memories;
mod model_info_overrides;
mod model_overrides;
mod model_switching;

View file

@ -0,0 +1,3 @@
-- Per stage1_outputs row: the source_updated_at of the snapshot consumed by
-- the last successful Phase 2 run (NULL until the row is selected).
ALTER TABLE stage1_outputs
ADD COLUMN selected_for_phase2_source_updated_at INTEGER;
-- Per-thread memory participation mode; existing rows default to 'enabled'.
ALTER TABLE threads ADD COLUMN memory_mode TEXT NOT NULL DEFAULT 'enabled';

View file

@ -14,6 +14,7 @@ mod runtime;
pub use model::LogEntry;
pub use model::LogQuery;
pub use model::LogRow;
pub use model::Phase2InputSelection;
pub use model::Phase2JobClaimOutcome;
/// Preferred entrypoint: owns configuration and metrics.
pub use runtime::StateRuntime;
@ -38,6 +39,7 @@ pub use model::SortKey;
pub use model::Stage1JobClaim;
pub use model::Stage1JobClaimOutcome;
pub use model::Stage1Output;
pub use model::Stage1OutputRef;
pub use model::Stage1StartupClaimParams;
pub use model::ThreadMetadata;
pub use model::ThreadMetadataBuilder;

View file

@ -21,6 +21,21 @@ pub struct Stage1Output {
pub generated_at: DateTime<Utc>,
}
/// Lightweight reference to a stage-1 output row: just enough identity to
/// locate its rollout summary file without carrying the full payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Stage1OutputRef {
    /// Thread the stage-1 output was generated from.
    pub thread_id: ThreadId,
    /// `source_updated_at` timestamp of the snapshot this reference points at.
    pub source_updated_at: DateTime<Utc>,
    /// Optional rollout slug used when deriving the summary file name.
    pub rollout_slug: Option<String>,
}

/// Input set for a global Phase 2 (consolidation) run, including the diff
/// against the selection consumed by the last successful run.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Phase2InputSelection {
    /// Stage-1 outputs chosen for this run.
    pub selected: Vec<Stage1Output>,
    /// Stage-1 outputs consumed by the last successful Phase 2 run.
    pub previous_selected: Vec<Stage1Output>,
    /// Thread ids in `selected` that were also part of the previous selection.
    pub retained_thread_ids: Vec<ThreadId>,
    /// Previously selected rows that fell out of the current selection.
    pub removed: Vec<Stage1OutputRef>,
}
#[derive(Debug)]
pub(crate) struct Stage1OutputRow {
thread_id: String,
@ -70,6 +85,18 @@ fn epoch_seconds_to_datetime(secs: i64) -> Result<DateTime<Utc>> {
.ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}"))
}
/// Converts raw DB column values into a [`Stage1OutputRef`].
///
/// Fails when the stored thread id does not parse or the epoch timestamp is
/// out of range for a UTC datetime.
pub(crate) fn stage1_output_ref_from_parts(
    thread_id: String,
    source_updated_at: i64,
    rollout_slug: Option<String>,
) -> Result<Stage1OutputRef> {
    let thread_id = ThreadId::try_from(thread_id)?;
    let source_updated_at = epoch_seconds_to_datetime(source_updated_at)?;
    Ok(Stage1OutputRef {
        thread_id,
        source_updated_at,
        rollout_slug,
    })
}
/// Result of trying to claim a stage-1 memory extraction job.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Stage1JobClaimOutcome {

View file

@ -16,10 +16,12 @@ pub use backfill_state::BackfillStatus;
pub use log::LogEntry;
pub use log::LogQuery;
pub use log::LogRow;
pub use memories::Phase2InputSelection;
pub use memories::Phase2JobClaimOutcome;
pub use memories::Stage1JobClaim;
pub use memories::Stage1JobClaimOutcome;
pub use memories::Stage1Output;
pub use memories::Stage1OutputRef;
pub use memories::Stage1StartupClaimParams;
pub use thread_metadata::Anchor;
pub use thread_metadata::BackfillStats;
@ -32,6 +34,7 @@ pub use thread_metadata::ThreadsPage;
pub(crate) use agent_job::AgentJobItemRow;
pub(crate) use agent_job::AgentJobRow;
pub(crate) use memories::Stage1OutputRow;
pub(crate) use memories::stage1_output_ref_from_parts;
pub(crate) use thread_metadata::ThreadRow;
pub(crate) use thread_metadata::anchor_from_item;
pub(crate) use thread_metadata::datetime_to_epoch_seconds;

View file

@ -2773,7 +2773,11 @@ WHERE kind = 'memory_stage1'
assert_eq!(phase2_input_watermark, 100);
assert!(
runtime
.mark_global_phase2_job_succeeded(phase2_token.as_str(), phase2_input_watermark)
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
phase2_input_watermark,
&[],
)
.await
.expect("mark initial phase2 succeeded"),
"initial phase2 success should clear global dirty state"
@ -2819,7 +2823,11 @@ WHERE kind = 'memory_stage1'
assert_eq!(phase2_input_watermark, 101);
assert!(
runtime
.mark_global_phase2_job_succeeded(phase2_token.as_str(), phase2_input_watermark)
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
phase2_input_watermark,
&[],
)
.await
.expect("mark phase2 succeeded after no-output delete")
);
@ -2936,7 +2944,7 @@ WHERE kind = 'memory_stage1'
};
assert!(
runtime
.mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark)
.mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark, &[],)
.await
.expect("mark phase2 succeeded"),
"phase2 success should finalize for current token"
@ -3124,6 +3132,646 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
/// Verifies `get_phase2_input_selection` labels rows relative to the last
/// successful Phase 2 run: with threads A(100), B(101), C(102) seeded and a
/// prior successful run that consumed only {A, C}, a top-2 query must select
/// {C, B}, retain only C, and report A as removed (with its rollout slug).
#[tokio::test]
async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() {
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");
    let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
    // Register all three threads before seeding stage-1 outputs for them.
    for (thread_id, workspace) in [
        (thread_id_a, "workspace-a"),
        (thread_id_b, "workspace-b"),
        (thread_id_c, "workspace-c"),
    ] {
        runtime
            .upsert_thread(&test_thread_metadata(
                &codex_home,
                thread_id,
                codex_home.join(workspace),
            ))
            .await
            .expect("upsert thread");
    }
    // Seed one successful stage-1 output per thread at increasing timestamps.
    for (thread_id, updated_at, slug) in [
        (thread_id_a, 100, Some("rollout-a")),
        (thread_id_b, 101, Some("rollout-b")),
        (thread_id_c, 102, Some("rollout-c")),
    ] {
        let claim = runtime
            .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
            .await
            .expect("claim stage1");
        let ownership_token = match claim {
            Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected stage1 claim outcome: {other:?}"),
        };
        assert!(
            runtime
                .mark_stage1_job_succeeded(
                    thread_id,
                    ownership_token.as_str(),
                    updated_at,
                    &format!("raw-{updated_at}"),
                    &format!("summary-{updated_at}"),
                    slug,
                )
                .await
                .expect("mark stage1 succeeded"),
            "stage1 success should persist output"
        );
    }
    // Claim the global Phase 2 job; the watermark is the newest input (C, 102).
    let claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim phase2");
    let (ownership_token, input_watermark) = match claim {
        Phase2JobClaimOutcome::Claimed {
            ownership_token,
            input_watermark,
        } => (ownership_token, input_watermark),
        other => panic!("unexpected phase2 claim outcome: {other:?}"),
    };
    assert_eq!(input_watermark, 102);
    // Simulate a previous successful run that consumed only threads A and C.
    let selected_outputs = runtime
        .list_stage1_outputs_for_global(10)
        .await
        .expect("list stage1 outputs for global")
        .into_iter()
        .filter(|output| output.thread_id == thread_id_c || output.thread_id == thread_id_a)
        .collect::<Vec<_>>();
    assert!(
        runtime
            .mark_global_phase2_job_succeeded(
                ownership_token.as_str(),
                input_watermark,
                &selected_outputs,
            )
            .await
            .expect("mark phase2 success with selection"),
        "phase2 success should persist selected rows"
    );
    // Top-2 selection is {C, B}; diff vs previous {A, C}:
    // C is retained, B is added, A is removed.
    let selection = runtime
        .get_phase2_input_selection(2)
        .await
        .expect("load phase2 input selection");
    assert_eq!(selection.selected.len(), 2);
    assert_eq!(selection.previous_selected.len(), 2);
    assert_eq!(selection.selected[0].thread_id, thread_id_c);
    assert_eq!(
        selection.selected[0].rollout_path,
        codex_home.join(format!("rollout-{thread_id_c}.jsonl"))
    );
    assert_eq!(selection.selected[1].thread_id, thread_id_b);
    assert_eq!(selection.retained_thread_ids, vec![thread_id_c]);
    assert_eq!(selection.removed.len(), 1);
    assert_eq!(selection.removed[0].thread_id, thread_id_a);
    assert_eq!(
        selection.removed[0].rollout_slug.as_deref(),
        Some("rollout-a")
    );
    let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() {
    // Scenario: phase 2 selects a thread's snapshot at ts=100, then stage 1
    // regenerates that thread at ts=101. The newer snapshot no longer matches
    // the recorded baseline timestamp, so the thread must show up as added
    // (empty `retained_thread_ids`) while the persisted baseline still points
    // at the ts=100 snapshot phase 2 actually consumed.
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");
    let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
    runtime
        .upsert_thread(&test_thread_metadata(
            &codex_home,
            thread_id,
            codex_home.join("workspace"),
        ))
        .await
        .expect("upsert thread");
    // Initial stage-1 snapshot at ts=100.
    let first_claim = runtime
        .try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
        .await
        .expect("claim initial stage1");
    let first_token = match first_claim {
        Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected stage1 claim outcome: {other:?}"),
    };
    assert!(
        runtime
            .mark_stage1_job_succeeded(
                thread_id,
                first_token.as_str(),
                100,
                "raw-100",
                "summary-100",
                Some("rollout-100"),
            )
            .await
            .expect("mark initial stage1 success"),
        "initial stage1 success should persist output"
    );
    // Phase 2 consumes the ts=100 snapshot and records it as the baseline.
    let phase2_claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim phase2");
    let (phase2_token, input_watermark) = match phase2_claim {
        Phase2JobClaimOutcome::Claimed {
            ownership_token,
            input_watermark,
        } => (ownership_token, input_watermark),
        other => panic!("unexpected phase2 claim outcome: {other:?}"),
    };
    let selected_outputs = runtime
        .list_stage1_outputs_for_global(1)
        .await
        .expect("list selected outputs");
    assert!(
        runtime
            .mark_global_phase2_job_succeeded(
                phase2_token.as_str(),
                input_watermark,
                &selected_outputs,
            )
            .await
            .expect("mark phase2 success"),
        "phase2 success should persist selected rows"
    );
    // Regenerate the same thread's snapshot at ts=101.
    let refreshed_claim = runtime
        .try_claim_stage1_job(thread_id, owner, 101, 3600, 64)
        .await
        .expect("claim refreshed stage1");
    let refreshed_token = match refreshed_claim {
        Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected stage1 claim outcome: {other:?}"),
    };
    assert!(
        runtime
            .mark_stage1_job_succeeded(
                thread_id,
                refreshed_token.as_str(),
                101,
                "raw-101",
                "summary-101",
                Some("rollout-101"),
            )
            .await
            .expect("mark refreshed stage1 success"),
        "refreshed stage1 success should persist output"
    );
    let selection = runtime
        .get_phase2_input_selection(1)
        .await
        .expect("load phase2 input selection");
    // The refreshed snapshot is selected but must not count as retained.
    assert_eq!(selection.selected.len(), 1);
    assert_eq!(selection.previous_selected.len(), 1);
    assert_eq!(selection.selected[0].thread_id, thread_id);
    assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101);
    assert!(selection.retained_thread_ids.is_empty());
    assert!(selection.removed.is_empty());
    // The stored baseline still references the snapshot phase 2 actually used.
    let (selected_for_phase2, selected_for_phase2_source_updated_at) =
        sqlx::query_as::<_, (i64, Option<i64>)>(
            "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
        )
        .bind(thread_id.to_string())
        .fetch_one(runtime.pool.as_ref())
        .await
        .expect("load selected_for_phase2");
    assert_eq!(selected_for_phase2, 1);
    assert_eq!(selected_for_phase2_source_updated_at, Some(100));
    let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_reports_regenerated_previous_selection_as_removed() {
    // Scenario: the first phase-2 run selects the newest two snapshots
    // {b=101, a=100}. Threads a, c, and d then regenerate at 102/103/104, so
    // the new top-2 is {d, c}. Both previously selected threads now fall
    // outside the current selection and must be reported as removed — a with
    // its regenerated timestamp (102), b with its original one (101).
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");
    let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread a");
    let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread b");
    let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread c");
    let thread_id_d = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread d");
    let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
    // Register all four threads up front.
    for (thread_id, workspace) in [
        (thread_id_a, "workspace-a"),
        (thread_id_b, "workspace-b"),
        (thread_id_c, "workspace-c"),
        (thread_id_d, "workspace-d"),
    ] {
        runtime
            .upsert_thread(&test_thread_metadata(
                &codex_home,
                thread_id,
                codex_home.join(workspace),
            ))
            .await
            .expect("upsert thread");
    }
    // Initial snapshots: b=101 and a=100 are the two newest; c=99, d=98.
    for (thread_id, updated_at, slug) in [
        (thread_id_a, 100, Some("rollout-a-100")),
        (thread_id_b, 101, Some("rollout-b-101")),
        (thread_id_c, 99, Some("rollout-c-99")),
        (thread_id_d, 98, Some("rollout-d-98")),
    ] {
        let claim = runtime
            .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
            .await
            .expect("claim initial stage1");
        let ownership_token = match claim {
            Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected stage1 claim outcome: {other:?}"),
        };
        assert!(
            runtime
                .mark_stage1_job_succeeded(
                    thread_id,
                    ownership_token.as_str(),
                    updated_at,
                    &format!("raw-{updated_at}"),
                    &format!("summary-{updated_at}"),
                    slug,
                )
                .await
                .expect("mark stage1 succeeded"),
            "stage1 success should persist output"
        );
    }
    let phase2_claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim phase2");
    let (phase2_token, input_watermark) = match phase2_claim {
        Phase2JobClaimOutcome::Claimed {
            ownership_token,
            input_watermark,
        } => (ownership_token, input_watermark),
        other => panic!("unexpected phase2 claim outcome: {other:?}"),
    };
    // First phase-2 run consumes the top-2 snapshots: b (101) then a (100).
    let selected_outputs = runtime
        .list_stage1_outputs_for_global(2)
        .await
        .expect("list selected outputs");
    assert_eq!(
        selected_outputs
            .iter()
            .map(|output| output.thread_id)
            .collect::<Vec<_>>(),
        vec![thread_id_b, thread_id_a]
    );
    assert!(
        runtime
            .mark_global_phase2_job_succeeded(
                phase2_token.as_str(),
                input_watermark,
                &selected_outputs,
            )
            .await
            .expect("mark phase2 success"),
        "phase2 success should persist selected rows"
    );
    // Regenerate a, c, and d so the recency order becomes d=104, c=103, a=102.
    for (thread_id, updated_at, slug) in [
        (thread_id_a, 102, Some("rollout-a-102")),
        (thread_id_c, 103, Some("rollout-c-103")),
        (thread_id_d, 104, Some("rollout-d-104")),
    ] {
        let claim = runtime
            .try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
            .await
            .expect("claim refreshed stage1");
        let ownership_token = match claim {
            Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected stage1 claim outcome: {other:?}"),
        };
        assert!(
            runtime
                .mark_stage1_job_succeeded(
                    thread_id,
                    ownership_token.as_str(),
                    updated_at,
                    &format!("raw-{updated_at}"),
                    &format!("summary-{updated_at}"),
                    slug,
                )
                .await
                .expect("mark refreshed stage1 success"),
            "refreshed stage1 success should persist output"
        );
    }
    let selection = runtime
        .get_phase2_input_selection(2)
        .await
        .expect("load phase2 input selection");
    // New top-2 selection is {d, c}; neither belonged to the old baseline.
    assert_eq!(
        selection
            .selected
            .iter()
            .map(|output| output.thread_id)
            .collect::<Vec<_>>(),
        vec![thread_id_d, thread_id_c]
    );
    // The baseline rows (a, b) are still flagged and surface newest-first.
    assert_eq!(
        selection
            .previous_selected
            .iter()
            .map(|output| output.thread_id)
            .collect::<Vec<_>>(),
        vec![thread_id_a, thread_id_b]
    );
    assert!(selection.retained_thread_ids.is_empty());
    // Removed rows report their CURRENT source_updated_at: a was regenerated
    // to 102 while b keeps its original 101.
    assert_eq!(
        selection
            .removed
            .iter()
            .map(|output| (output.thread_id, output.source_updated_at.timestamp()))
            .collect::<Vec<_>>(),
        vec![(thread_id_a, 102), (thread_id_b, 101)]
    );
    let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_global_phase2_job_succeeded_updates_selected_snapshot_timestamp() {
    // Scenario: phase 2 succeeds against the ts=100 snapshot, the thread is
    // regenerated at ts=101, and a second phase-2 run succeeds against the
    // refreshed snapshot. The persisted baseline must be rewritten to ts=101
    // so the refreshed row counts as retained in the next selection diff.
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");
    let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
    runtime
        .upsert_thread(&test_thread_metadata(
            &codex_home,
            thread_id,
            codex_home.join("workspace"),
        ))
        .await
        .expect("upsert thread");
    // Initial stage-1 snapshot at ts=100.
    let initial_claim = runtime
        .try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
        .await
        .expect("claim initial stage1");
    let initial_token = match initial_claim {
        Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected stage1 claim outcome: {other:?}"),
    };
    assert!(
        runtime
            .mark_stage1_job_succeeded(
                thread_id,
                initial_token.as_str(),
                100,
                "raw-100",
                "summary-100",
                Some("rollout-100"),
            )
            .await
            .expect("mark initial stage1 success"),
        "initial stage1 success should persist output"
    );
    // First phase-2 run consumes and records the ts=100 snapshot.
    let first_phase2_claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim first phase2");
    let (first_phase2_token, first_input_watermark) = match first_phase2_claim {
        Phase2JobClaimOutcome::Claimed {
            ownership_token,
            input_watermark,
        } => (ownership_token, input_watermark),
        other => panic!("unexpected first phase2 claim outcome: {other:?}"),
    };
    let first_selected_outputs = runtime
        .list_stage1_outputs_for_global(1)
        .await
        .expect("list first selected outputs");
    assert!(
        runtime
            .mark_global_phase2_job_succeeded(
                first_phase2_token.as_str(),
                first_input_watermark,
                &first_selected_outputs,
            )
            .await
            .expect("mark first phase2 success"),
        "first phase2 success should persist selected rows"
    );
    // Regenerate the thread's snapshot at ts=101.
    let refreshed_claim = runtime
        .try_claim_stage1_job(thread_id, owner, 101, 3600, 64)
        .await
        .expect("claim refreshed stage1");
    let refreshed_token = match refreshed_claim {
        Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected refreshed stage1 claim outcome: {other:?}"),
    };
    assert!(
        runtime
            .mark_stage1_job_succeeded(
                thread_id,
                refreshed_token.as_str(),
                101,
                "raw-101",
                "summary-101",
                Some("rollout-101"),
            )
            .await
            .expect("mark refreshed stage1 success"),
        "refreshed stage1 success should persist output"
    );
    // Second phase-2 run consumes the refreshed ts=101 snapshot.
    let second_phase2_claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim second phase2");
    let (second_phase2_token, second_input_watermark) = match second_phase2_claim {
        Phase2JobClaimOutcome::Claimed {
            ownership_token,
            input_watermark,
        } => (ownership_token, input_watermark),
        other => panic!("unexpected second phase2 claim outcome: {other:?}"),
    };
    let second_selected_outputs = runtime
        .list_stage1_outputs_for_global(1)
        .await
        .expect("list second selected outputs");
    assert_eq!(
        second_selected_outputs[0].source_updated_at.timestamp(),
        101
    );
    assert!(
        runtime
            .mark_global_phase2_job_succeeded(
                second_phase2_token.as_str(),
                second_input_watermark,
                &second_selected_outputs,
            )
            .await
            .expect("mark second phase2 success"),
        "second phase2 success should persist selected rows"
    );
    // The refreshed snapshot now matches the baseline exactly → retained.
    let selection = runtime
        .get_phase2_input_selection(1)
        .await
        .expect("load phase2 input selection after refresh");
    assert_eq!(selection.retained_thread_ids, vec![thread_id]);
    // Verify the baseline timestamp was advanced to the refreshed snapshot.
    let (selected_for_phase2, selected_for_phase2_source_updated_at) =
        sqlx::query_as::<_, (i64, Option<i64>)>(
            "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
        )
        .bind(thread_id.to_string())
        .fetch_one(runtime.pool.as_ref())
        .await
        .expect("load selected snapshot after phase2");
    assert_eq!(selected_for_phase2, 1);
    assert_eq!(selected_for_phase2_source_updated_at, Some(101));
    let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_global_phase2_job_succeeded_only_marks_exact_selected_snapshots() {
    // Scenario: phase 2 captures the ts=100 snapshot list, but before it
    // finishes the thread regenerates at ts=101. Success must only flag rows
    // whose (thread_id, source_updated_at) exactly match what was consumed,
    // so the now-stale ts=100 entry marks nothing: the row stays unselected
    // and the next selection reports no retained threads.
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");
    let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
    runtime
        .upsert_thread(&test_thread_metadata(
            &codex_home,
            thread_id,
            codex_home.join("workspace"),
        ))
        .await
        .expect("upsert thread");
    // Initial stage-1 snapshot at ts=100.
    let initial_claim = runtime
        .try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
        .await
        .expect("claim initial stage1");
    let initial_token = match initial_claim {
        Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected stage1 claim outcome: {other:?}"),
    };
    assert!(
        runtime
            .mark_stage1_job_succeeded(
                thread_id,
                initial_token.as_str(),
                100,
                "raw-100",
                "summary-100",
                Some("rollout-100"),
            )
            .await
            .expect("mark initial stage1 success"),
        "initial stage1 success should persist output"
    );
    let phase2_claim = runtime
        .try_claim_global_phase2_job(owner, 3600)
        .await
        .expect("claim phase2");
    let (phase2_token, input_watermark) = match phase2_claim {
        Phase2JobClaimOutcome::Claimed {
            ownership_token,
            input_watermark,
        } => (ownership_token, input_watermark),
        other => panic!("unexpected phase2 claim outcome: {other:?}"),
    };
    // Capture the selection BEFORE the thread regenerates.
    let selected_outputs = runtime
        .list_stage1_outputs_for_global(1)
        .await
        .expect("list selected outputs");
    assert_eq!(selected_outputs[0].source_updated_at.timestamp(), 100);
    // Regenerate the snapshot to ts=101 while phase 2 is still "running".
    let refreshed_claim = runtime
        .try_claim_stage1_job(thread_id, owner, 101, 3600, 64)
        .await
        .expect("claim refreshed stage1");
    let refreshed_token = match refreshed_claim {
        Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
        other => panic!("unexpected stage1 claim outcome: {other:?}"),
    };
    assert!(
        runtime
            .mark_stage1_job_succeeded(
                thread_id,
                refreshed_token.as_str(),
                101,
                "raw-101",
                "summary-101",
                Some("rollout-101"),
            )
            .await
            .expect("mark refreshed stage1 success"),
        "refreshed stage1 success should persist output"
    );
    // Finish phase 2 with the stale ts=100 selection; the job itself succeeds.
    assert!(
        runtime
            .mark_global_phase2_job_succeeded(
                phase2_token.as_str(),
                input_watermark,
                &selected_outputs,
            )
            .await
            .expect("mark phase2 success"),
        "phase2 success should still complete"
    );
    // No row matches (thread_id, ts=100) anymore, so nothing was flagged.
    let (selected_for_phase2, selected_for_phase2_source_updated_at) =
        sqlx::query_as::<_, (i64, Option<i64>)>(
            "SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
        )
        .bind(thread_id.to_string())
        .fetch_one(runtime.pool.as_ref())
        .await
        .expect("load selected_for_phase2");
    assert_eq!(selected_for_phase2, 0);
    assert_eq!(selected_for_phase2_source_updated_at, None);
    let selection = runtime
        .get_phase2_input_selection(1)
        .await
        .expect("load phase2 input selection");
    assert_eq!(selection.selected.len(), 1);
    assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101);
    assert!(selection.retained_thread_ids.is_empty());
    let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn record_stage1_output_usage_updates_usage_metadata() {
let codex_home = unique_temp_dir();
@ -3395,7 +4043,7 @@ VALUES (?, ?, ?, ?, ?)
assert_eq!(
runtime
.mark_global_phase2_job_succeeded(token_a.as_str(), 300)
.mark_global_phase2_job_succeeded(token_a.as_str(), 300, &[])
.await
.expect("mark stale owner success result"),
false,
@ -3403,7 +4051,7 @@ VALUES (?, ?, ?, ?, ?)
);
assert!(
runtime
.mark_global_phase2_job_succeeded(token_b.as_str(), 300)
.mark_global_phase2_job_succeeded(token_b.as_str(), 300, &[])
.await
.expect("mark takeover owner success"),
"takeover owner should finalize consolidation"
@ -3440,7 +4088,7 @@ VALUES (?, ?, ?, ?, ?)
};
assert!(
runtime
.mark_global_phase2_job_succeeded(token_a.as_str(), 500)
.mark_global_phase2_job_succeeded(token_a.as_str(), 500, &[])
.await
.expect("mark initial phase2 success"),
"initial phase2 success should finalize"

View file

@ -1,4 +1,5 @@
use super::*;
use crate::model::Phase2InputSelection;
use crate::model::Phase2JobClaimOutcome;
use crate::model::Stage1JobClaim;
use crate::model::Stage1JobClaimOutcome;
@ -6,10 +7,12 @@ use crate::model::Stage1Output;
use crate::model::Stage1OutputRow;
use crate::model::Stage1StartupClaimParams;
use crate::model::ThreadRow;
use crate::model::stage1_output_ref_from_parts;
use chrono::Duration;
use sqlx::Executor;
use sqlx::QueryBuilder;
use sqlx::Sqlite;
use std::collections::HashSet;
const JOB_KIND_MEMORY_STAGE1: &str = "memory_stage1";
const JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL: &str = "memory_consolidate_global";
@ -257,6 +260,117 @@ LIMIT ?
.collect::<Result<Vec<_>, _>>()
}
/// Returns the current phase-2 input set along with its diff against the
/// last successful phase-2 selection.
///
/// Query behavior:
/// - current selection is the latest `n` non-empty stage-1 outputs ordered
///   by `source_updated_at DESC, thread_id DESC`
/// - previously selected rows are identified by `selected_for_phase2 = 1`
/// - `previous_selected` contains the current persisted rows that belonged
///   to the last successful phase-2 baseline
/// - `retained_thread_ids` records which current rows still match the exact
///   snapshot selected in the last successful phase-2 run
/// - removed rows are previously selected rows that are still present in
///   `stage1_outputs` but fall outside the current top-`n` selection
pub async fn get_phase2_input_selection(
    &self,
    n: usize,
) -> anyhow::Result<Phase2InputSelection> {
    // A zero-width selection window trivially yields an empty diff.
    if n == 0 {
        return Ok(Phase2InputSelection::default());
    }
    // Current selection: newest `n` stage-1 outputs that have any non-blank
    // content, joined with thread metadata (rollout_path / cwd may be missing
    // if the thread row is gone, hence LEFT JOIN + COALESCE).
    let current_rows = sqlx::query(
        r#"
SELECT
so.thread_id,
COALESCE(t.rollout_path, '') AS rollout_path,
so.source_updated_at,
so.raw_memory,
so.rollout_summary,
so.rollout_slug,
so.generated_at,
so.selected_for_phase2,
so.selected_for_phase2_source_updated_at,
COALESCE(t.cwd, '') AS cwd
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0
ORDER BY so.source_updated_at DESC, so.thread_id DESC
LIMIT ?
"#,
    )
    .bind(n as i64)
    .fetch_all(self.pool.as_ref())
    .await?;
    let mut current_thread_ids = HashSet::with_capacity(current_rows.len());
    let mut selected = Vec::with_capacity(current_rows.len());
    let mut retained_thread_ids = Vec::new();
    for row in current_rows {
        let thread_id = row.try_get::<String, _>("thread_id")?;
        current_thread_ids.insert(thread_id.clone());
        let source_updated_at = row.try_get::<i64, _>("source_updated_at")?;
        // "Retained" requires an exact snapshot match: the row was part of
        // the last successful phase-2 selection AND its `source_updated_at`
        // equals the timestamp recorded at that time. A regenerated snapshot
        // therefore counts as added, not retained.
        if row.try_get::<i64, _>("selected_for_phase2")? != 0
            && row.try_get::<Option<i64>, _>("selected_for_phase2_source_updated_at")?
                == Some(source_updated_at)
        {
            retained_thread_ids.push(ThreadId::try_from(thread_id.clone())?);
        }
        selected.push(Stage1Output::try_from(Stage1OutputRow::try_from_row(
            &row,
        )?)?);
    }
    // Previous baseline: every row still flagged from the last successful
    // phase-2 run, newest first.
    let previous_rows = sqlx::query(
        r#"
SELECT
so.thread_id,
COALESCE(t.rollout_path, '') AS rollout_path,
so.source_updated_at,
so.raw_memory,
so.rollout_summary,
so.rollout_slug
, so.generated_at
, COALESCE(t.cwd, '') AS cwd
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE so.selected_for_phase2 = 1
ORDER BY so.source_updated_at DESC, so.thread_id DESC
"#,
    )
    .fetch_all(self.pool.as_ref())
    .await?;
    let previous_selected = previous_rows
        .iter()
        .map(Stage1OutputRow::try_from_row)
        .map(|row| row.and_then(Stage1Output::try_from))
        .collect::<Result<Vec<_>, _>>()?;
    // Removed = previously selected threads that fell outside the current
    // top-`n` window (their rows still exist; they just lost their slot).
    let mut removed = Vec::new();
    for row in previous_rows {
        let thread_id = row.try_get::<String, _>("thread_id")?;
        if current_thread_ids.contains(thread_id.as_str()) {
            continue;
        }
        removed.push(stage1_output_ref_from_parts(
            thread_id,
            row.try_get("source_updated_at")?,
            row.try_get("rollout_slug")?,
        )?);
    }
    Ok(Phase2InputSelection {
        selected,
        previous_selected,
        retained_thread_ids,
        removed,
    })
}
/// Attempts to claim a stage-1 job for a thread at `source_updated_at`.
///
/// Claim semantics:
@ -454,6 +568,9 @@ WHERE kind = ? AND job_key = ?
/// - sets `status='done'` and `last_success_watermark = input_watermark`
/// - upserts `stage1_outputs` for the thread, replacing existing output only
/// when `source_updated_at` is newer or equal
/// - preserves any existing `selected_for_phase2` baseline until the next
/// successful phase-2 run rewrites the baseline selection, including the
/// snapshot timestamp chosen during that run
/// - persists optional `rollout_slug` for rollout summary artifact naming
/// - enqueues/advances the global phase-2 job watermark using
/// `source_updated_at`
@ -806,12 +923,18 @@ WHERE kind = ? AND job_key = ?
/// - sets `status='done'`, clears lease/errors
/// - advances `last_success_watermark` to
/// `max(existing_last_success_watermark, completed_watermark)`
/// - rewrites `selected_for_phase2` so only the exact selected stage-1
/// snapshots remain marked as part of the latest successful phase-2
/// selection, and persists each selected snapshot's
/// `source_updated_at` for future retained-vs-added diffing
pub async fn mark_global_phase2_job_succeeded(
&self,
ownership_token: &str,
completed_watermark: i64,
selected_outputs: &[Stage1Output],
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let mut tx = self.pool.begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE jobs
@ -830,11 +953,46 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(&mut *tx)
.await?
.rows_affected();
Ok(rows_affected > 0)
if rows_affected == 0 {
tx.commit().await?;
return Ok(false);
}
sqlx::query(
r#"
UPDATE stage1_outputs
SET
selected_for_phase2 = 0,
selected_for_phase2_source_updated_at = NULL
WHERE selected_for_phase2 != 0 OR selected_for_phase2_source_updated_at IS NOT NULL
"#,
)
.execute(&mut *tx)
.await?;
for output in selected_outputs {
sqlx::query(
r#"
UPDATE stage1_outputs
SET
selected_for_phase2 = 1,
selected_for_phase2_source_updated_at = ?
WHERE thread_id = ? AND source_updated_at = ?
"#,
)
.bind(output.source_updated_at.timestamp())
.bind(output.thread_id.to_string())
.bind(output.source_updated_at.timestamp())
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(true)
}
/// Marks the owned running global phase-2 job as failed and schedules retry.