feat: record memory usage (#12761)

This commit is contained in:
jif-oai 2026-02-25 13:48:40 +00:00 committed by GitHub
parent 5441130e0a
commit e4bfa763f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 254 additions and 24 deletions

View file

@ -43,6 +43,7 @@ use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
use crate::stream_events_utils::last_assistant_message_from_item;
use crate::stream_events_utils::raw_assistant_output_text_from_item;
use crate::stream_events_utils::record_completed_response_item;
use crate::terminal;
use crate::truncate::TruncationPolicy;
use crate::turn_metadata::TurnMetadataState;
@ -6068,7 +6069,7 @@ async fn handle_assistant_item_done_in_plan_mode(
{
maybe_complete_plan_item_from_message(sess, turn_context, state, item).await;
if let Some(turn_item) = handle_non_tool_response_item(item, true).await {
if let Some(turn_item) = handle_non_tool_response_item(item, true) {
emit_turn_item_in_plan_mode(
sess,
turn_context,
@ -6079,8 +6080,7 @@ async fn handle_assistant_item_done_in_plan_mode(
.await;
}
sess.record_conversation_items(turn_context, std::slice::from_ref(item))
.await;
record_completed_response_item(sess, turn_context, item).await;
if let Some(agent_message) = last_assistant_message_from_item(item, true) {
*last_agent_message = Some(agent_message);
}
@ -6254,7 +6254,7 @@ async fn try_run_sampling_request(
needs_follow_up |= output_result.needs_follow_up;
}
ResponseEvent::OutputItemAdded(item) => {
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await {
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) {
let mut turn_item = turn_item;
let mut seeded_parsed: Option<ParsedAssistantTextDelta> = None;
let mut seeded_item_id: Option<String> = None;
@ -6453,10 +6453,12 @@ async fn try_run_sampling_request(
}
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
responses
.iter()
.rev()
.find_map(|item| last_assistant_message_from_item(item, false))
for item in responses.iter().rev() {
if let Some(message) = last_assistant_message_from_item(item, false) {
return Some(message);
}
}
None
}
use crate::memories::prompts::build_memory_tool_developer_instructions;

View file

@ -0,0 +1,62 @@
use codex_protocol::ThreadId;
pub fn get_thread_id_from_citations(citations: Vec<String>) -> Vec<ThreadId> {
let mut result = Vec::new();
for citation in citations {
let mut ids_block = None;
for (open, close) in [
("<thread_ids>", "</thread_ids>"),
("<rollout_ids>", "</rollout_ids>"),
] {
if let Some((_, rest)) = citation.split_once(open)
&& let Some((ids, _)) = rest.split_once(close)
{
ids_block = Some(ids);
break;
}
}
if let Some(ids_block) = ids_block {
for id in ids_block
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
{
if let Ok(thread_id) = ThreadId::try_from(id) {
result.push(thread_id);
}
}
}
}
result
}
#[cfg(test)]
mod tests {
use super::get_thread_id_from_citations;
use codex_protocol::ThreadId;
use pretty_assertions::assert_eq;
#[test]
fn get_thread_id_from_citations_extracts_thread_ids() {
let first = ThreadId::new();
let second = ThreadId::new();
let citations = vec![format!(
"<memory_citation>\n<citation_entries>\nMEMORY.md:1-2|note=[x]\n</citation_entries>\n<thread_ids>\n{first}\nnot-a-uuid\n{second}\n</thread_ids>\n</memory_citation>"
)];
assert_eq!(get_thread_id_from_citations(citations), vec![first, second]);
}
#[test]
fn get_thread_id_from_citations_supports_legacy_rollout_ids() {
let thread_id = ThreadId::new();
let citations = vec![format!(
"<memory_citation>\n<rollout_ids>\n{thread_id}\n</rollout_ids>\n</memory_citation>"
)];
assert_eq!(get_thread_id_from_citations(citations), vec![thread_id]);
}
}

View file

@ -4,6 +4,7 @@
//! - Phase 1: select rollouts, extract stage-1 raw memories, persist stage-1 outputs, and enqueue consolidation.
//! - Phase 2: claim a global consolidation lock, materialize consolidation inputs, and dispatch one consolidation agent.
pub(crate) mod citations;
mod phase1;
mod phase2;
pub(crate) mod prompts;

View file

@ -11,7 +11,9 @@ use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::error::Result;
use crate::function_tool::FunctionCallError;
use crate::memories::citations::get_thread_id_from_citations;
use crate::parse_turn_item;
use crate::state_db;
use crate::tools::parallel::ToolCallRuntime;
use crate::tools::router::ToolRouter;
use codex_protocol::models::FunctionCallOutputBody;
@ -24,7 +26,7 @@ use tracing::debug;
use tracing::instrument;
fn strip_hidden_assistant_markup(text: &str, plan_mode: bool) -> String {
let (without_citations, _citations) = strip_citations(text);
let (without_citations, _) = strip_citations(text);
if plan_mode {
strip_proposed_plan_blocks(&without_citations)
} else {
@ -48,6 +50,36 @@ pub(crate) fn raw_assistant_output_text_from_item(item: &ResponseItem) -> Option
None
}
/// Persist a completed model response item and record any cited memory usage.
pub(crate) async fn record_completed_response_item(
sess: &Session,
turn_context: &TurnContext,
item: &ResponseItem,
) {
sess.record_conversation_items(turn_context, std::slice::from_ref(item))
.await;
record_stage1_output_usage_for_completed_item(turn_context, item).await;
}
async fn record_stage1_output_usage_for_completed_item(
turn_context: &TurnContext,
item: &ResponseItem,
) {
let Some(raw_text) = raw_assistant_output_text_from_item(item) else {
return;
};
let (_, citations) = strip_citations(&raw_text);
let thread_ids = get_thread_id_from_citations(citations);
if thread_ids.is_empty() {
return;
}
if let Some(db) = state_db::get_state_db(turn_context.config.as_ref(), None).await {
let _ = db.record_stage1_output_usage(&thread_ids).await;
}
}
/// Handle a completed output item from the model stream, recording it and
/// queuing any tool execution futures. This records items immediately so
/// history and rollout stay in sync even if the turn is later cancelled.
@ -88,8 +120,7 @@ pub(crate) async fn handle_output_item_done(
payload_preview
);
ctx.sess
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
let cancellation_token = ctx.cancellation_token.child_token();
@ -104,7 +135,7 @@ pub(crate) async fn handle_output_item_done(
}
// No tool call: convert messages/reasoning into turn items and mark them as complete.
Ok(None) => {
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await {
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) {
if previously_active_item.is_none() {
ctx.sess
.emit_turn_item_started(&ctx.turn_context, &turn_item)
@ -116,8 +147,7 @@ pub(crate) async fn handle_output_item_done(
.await;
}
ctx.sess
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
let last_agent_message = last_assistant_message_from_item(&item, plan_mode);
@ -138,8 +168,7 @@ pub(crate) async fn handle_output_item_done(
..Default::default()
},
};
ctx.sess
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
if let Some(response_item) = response_input_to_response_item(&response) {
ctx.sess
@ -161,8 +190,7 @@ pub(crate) async fn handle_output_item_done(
..Default::default()
},
};
ctx.sess
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
if let Some(response_item) = response_input_to_response_item(&response) {
ctx.sess
@ -184,7 +212,7 @@ pub(crate) async fn handle_output_item_done(
Ok(output)
}
pub(crate) async fn handle_non_tool_response_item(
pub(crate) fn handle_non_tool_response_item(
item: &ResponseItem,
plan_mode: bool,
) -> Option<TurnItem> {
@ -286,13 +314,12 @@ mod tests {
}
}
#[tokio::test]
async fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
#[test]
fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
let item = assistant_output_text("hello<oai-mem-citation>doc1</oai-mem-citation> world");
let turn_item = handle_non_tool_response_item(&item, false)
.await
.expect("assistant message should parse");
let turn_item =
handle_non_tool_response_item(&item, false).expect("assistant message should parse");
let TurnItem::AgentMessage(agent_message) = turn_item else {
panic!("expected agent message");

View file

@ -0,0 +1,2 @@
ALTER TABLE stage1_outputs ADD COLUMN usage_count INTEGER;
ALTER TABLE stage1_outputs ADD COLUMN last_usage INTEGER;

View file

@ -3124,6 +3124,105 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn record_stage1_output_usage_updates_usage_metadata() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a");
let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b");
let missing = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("missing id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_a,
codex_home.join("workspace-a"),
))
.await
.expect("upsert thread a");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_b,
codex_home.join("workspace-b"),
))
.await
.expect("upsert thread b");
let claim_a = runtime
.try_claim_stage1_job(thread_a, owner, 100, 3600, 64)
.await
.expect("claim stage1 a");
let token_a = match claim_a {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome for a: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(thread_a, token_a.as_str(), 100, "raw a", "sum a", None)
.await
.expect("mark stage1 succeeded a")
);
let claim_b = runtime
.try_claim_stage1_job(thread_b, owner, 101, 3600, 64)
.await
.expect("claim stage1 b");
let token_b = match claim_b {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome for b: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(thread_b, token_b.as_str(), 101, "raw b", "sum b", None)
.await
.expect("mark stage1 succeeded b")
);
let updated_rows = runtime
.record_stage1_output_usage(&[thread_a, thread_a, thread_b, missing])
.await
.expect("record stage1 output usage");
assert_eq!(updated_rows, 3);
let row_a =
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_a.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("load stage1 usage row a");
let row_b =
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
.bind(thread_b.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("load stage1 usage row b");
assert_eq!(
row_a
.try_get::<i64, _>("usage_count")
.expect("usage_count a"),
2
);
assert_eq!(
row_b
.try_get::<i64, _>("usage_count")
.expect("usage_count b"),
1
);
let last_usage_a = row_a.try_get::<i64, _>("last_usage").expect("last_usage a");
let last_usage_b = row_b.try_get::<i64, _>("last_usage").expect("last_usage b");
assert_eq!(last_usage_a, last_usage_b);
assert!(last_usage_a > 0);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_stage1_job_succeeded_enqueues_global_consolidation() {
let codex_home = unique_temp_dir();

View file

@ -49,6 +49,43 @@ WHERE kind = ? OR kind = ?
Ok(())
}
/// Record usage for cited stage-1 outputs.
///
/// Each thread id increments `usage_count` by one and sets `last_usage` to
/// the current Unix timestamp. Missing rows are ignored.
pub async fn record_stage1_output_usage(
&self,
thread_ids: &[ThreadId],
) -> anyhow::Result<usize> {
if thread_ids.is_empty() {
return Ok(0);
}
let now = Utc::now().timestamp();
let mut tx = self.pool.begin().await?;
let mut updated_rows = 0;
for thread_id in thread_ids {
updated_rows += sqlx::query(
r#"
UPDATE stage1_outputs
SET
usage_count = COALESCE(usage_count, 0) + 1,
last_usage = ?
WHERE thread_id = ?
"#,
)
.bind(now)
.bind(thread_id.to_string())
.execute(&mut *tx)
.await?
.rows_affected() as usize;
}
tx.commit().await?;
Ok(updated_rows)
}
/// Selects and claims stage-1 startup jobs for stale threads.
///
/// Query behavior: