From e4bfa763f66a31bf5507c5d743da91b3da14f6aa Mon Sep 17 00:00:00 2001 From: jif-oai Date: Wed, 25 Feb 2026 13:48:40 +0000 Subject: [PATCH] feat: record memory usage (#12761) --- codex-rs/core/src/codex.rs | 18 ++-- codex-rs/core/src/memories/citations.rs | 62 ++++++++++++ codex-rs/core/src/memories/mod.rs | 1 + codex-rs/core/src/stream_events_utils.rs | 59 ++++++++--- .../state/migrations/0016_memory_usage.sql | 2 + codex-rs/state/src/runtime.rs | 99 +++++++++++++++++++ codex-rs/state/src/runtime/memories.rs | 37 +++++++ 7 files changed, 254 insertions(+), 24 deletions(-) create mode 100644 codex-rs/core/src/memories/citations.rs create mode 100644 codex-rs/state/migrations/0016_memory_usage.sql diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index a58f2e1e0..9487151cc 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -43,6 +43,7 @@ use crate::stream_events_utils::handle_non_tool_response_item; use crate::stream_events_utils::handle_output_item_done; use crate::stream_events_utils::last_assistant_message_from_item; use crate::stream_events_utils::raw_assistant_output_text_from_item; +use crate::stream_events_utils::record_completed_response_item; use crate::terminal; use crate::truncate::TruncationPolicy; use crate::turn_metadata::TurnMetadataState; @@ -6068,7 +6069,7 @@ async fn handle_assistant_item_done_in_plan_mode( { maybe_complete_plan_item_from_message(sess, turn_context, state, item).await; - if let Some(turn_item) = handle_non_tool_response_item(item, true).await { + if let Some(turn_item) = handle_non_tool_response_item(item, true) { emit_turn_item_in_plan_mode( sess, turn_context, @@ -6079,8 +6080,7 @@ async fn handle_assistant_item_done_in_plan_mode( .await; } - sess.record_conversation_items(turn_context, std::slice::from_ref(item)) - .await; + record_completed_response_item(sess, turn_context, item).await; if let Some(agent_message) = last_assistant_message_from_item(item, true) { *last_agent_message = Some(agent_message); } @@ -6254,7 +6254,7 @@ async fn try_run_sampling_request( needs_follow_up |= output_result.needs_follow_up; } ResponseEvent::OutputItemAdded(item) => { - if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await { + if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) { let mut turn_item = turn_item; let mut seeded_parsed: Option = None; let mut seeded_item_id: Option = None; @@ -6453,10 +6453,12 @@ async fn try_run_sampling_request( } pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option { - responses - .iter() - .rev() - .find_map(|item| last_assistant_message_from_item(item, false)) + for item in responses.iter().rev() { + if let Some(message) = last_assistant_message_from_item(item, false) { + return Some(message); + } + } + None } use crate::memories::prompts::build_memory_tool_developer_instructions; diff --git a/codex-rs/core/src/memories/citations.rs b/codex-rs/core/src/memories/citations.rs new file mode 100644 index 000000000..91c777826 --- /dev/null +++ b/codex-rs/core/src/memories/citations.rs @@ -0,0 +1,62 @@ +use codex_protocol::ThreadId; + +pub fn get_thread_id_from_citations(citations: Vec) -> Vec { + let mut result = Vec::new(); + for citation in citations { + let mut ids_block = None; + for (open, close) in [ + ("", ""), + ("", ""), + ] { + if let Some((_, rest)) = citation.split_once(open) + && let Some((ids, _)) = rest.split_once(close) + { + ids_block = Some(ids); + break; + } + } + + if let Some(ids_block) = ids_block { + for id in ids_block + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + { + if let Ok(thread_id) = ThreadId::try_from(id) { + result.push(thread_id); + } + } + } + } + result +} + +#[cfg(test)] +mod tests { + use super::get_thread_id_from_citations; + use codex_protocol::ThreadId; + use pretty_assertions::assert_eq; + + #[test] + fn get_thread_id_from_citations_extracts_thread_ids() { + let first = ThreadId::new(); + let second = ThreadId::new(); + + let citations = vec![format!( + "\n\nMEMORY.md:1-2|note=[x]\n\n\n{first}\nnot-a-uuid\n{second}\n\n" + )]; + + assert_eq!(get_thread_id_from_citations(citations), vec![first, second]); + } + + #[test] + fn get_thread_id_from_citations_supports_legacy_rollout_ids() { + let thread_id = ThreadId::new(); + + let citations = vec![format!( + "\n\n{thread_id}\n\n" + )]; + + assert_eq!(get_thread_id_from_citations(citations), vec![thread_id]); + } +} diff --git a/codex-rs/core/src/memories/mod.rs b/codex-rs/core/src/memories/mod.rs index 1f892646c..a90eba6be 100644 --- a/codex-rs/core/src/memories/mod.rs +++ b/codex-rs/core/src/memories/mod.rs @@ -4,6 +4,7 @@ //! - Phase 1: select rollouts, extract stage-1 raw memories, persist stage-1 outputs, and enqueue consolidation. //! - Phase 2: claim a global consolidation lock, materialize consolidation inputs, and dispatch one consolidation agent. +pub(crate) mod citations; mod phase1; mod phase2; pub(crate) mod prompts; diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index 753336554..fe72db09f 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -11,7 +11,9 @@ use crate::codex::TurnContext; use crate::error::CodexErr; use crate::error::Result; use crate::function_tool::FunctionCallError; +use crate::memories::citations::get_thread_id_from_citations; use crate::parse_turn_item; +use crate::state_db; use crate::tools::parallel::ToolCallRuntime; use crate::tools::router::ToolRouter; use codex_protocol::models::FunctionCallOutputBody; @@ -24,7 +26,7 @@ use tracing::debug; use tracing::instrument; fn strip_hidden_assistant_markup(text: &str, plan_mode: bool) -> String { - let (without_citations, _citations) = strip_citations(text); + let (without_citations, _) = strip_citations(text); if plan_mode { strip_proposed_plan_blocks(&without_citations) } else { @@ -48,6 +50,36 @@ pub(crate) fn raw_assistant_output_text_from_item(item: &ResponseItem) -> Option None } +/// Persist a completed model response item and record any cited memory usage. +pub(crate) async fn record_completed_response_item( + sess: &Session, + turn_context: &TurnContext, + item: &ResponseItem, +) { + sess.record_conversation_items(turn_context, std::slice::from_ref(item)) + .await; + record_stage1_output_usage_for_completed_item(turn_context, item).await; +} + +async fn record_stage1_output_usage_for_completed_item( + turn_context: &TurnContext, + item: &ResponseItem, +) { + let Some(raw_text) = raw_assistant_output_text_from_item(item) else { + return; + }; + + let (_, citations) = strip_citations(&raw_text); + let thread_ids = get_thread_id_from_citations(citations); + if thread_ids.is_empty() { + return; + } + + if let Some(db) = state_db::get_state_db(turn_context.config.as_ref(), None).await { + let _ = db.record_stage1_output_usage(&thread_ids).await; + } +} + /// Handle a completed output item from the model stream, recording it and /// queuing any tool execution futures. This records items immediately so /// history and rollout stay in sync even if the turn is later cancelled. @@ -88,8 +120,7 @@ pub(crate) async fn handle_output_item_done( payload_preview ); - ctx.sess - .record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item)) + record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item) .await; let cancellation_token = ctx.cancellation_token.child_token(); @@ -104,7 +135,7 @@ pub(crate) async fn handle_output_item_done( } // No tool call: convert messages/reasoning into turn items and mark them as complete. Ok(None) => { - if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await { + if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) { if previously_active_item.is_none() { ctx.sess .emit_turn_item_started(&ctx.turn_context, &turn_item) @@ -116,8 +147,7 @@ pub(crate) async fn handle_output_item_done( .await; } - ctx.sess - .record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item)) + record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item) .await; let last_agent_message = last_assistant_message_from_item(&item, plan_mode); @@ -138,8 +168,7 @@ pub(crate) async fn handle_output_item_done( ..Default::default() }, }; - ctx.sess - .record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item)) + record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item) .await; if let Some(response_item) = response_input_to_response_item(&response) { ctx.sess @@ -161,8 +190,7 @@ pub(crate) async fn handle_output_item_done( ..Default::default() }, }; - ctx.sess - .record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item)) + record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item) .await; if let Some(response_item) = response_input_to_response_item(&response) { ctx.sess @@ -184,7 +212,7 @@ pub(crate) async fn handle_output_item_done( Ok(output) } -pub(crate) async fn handle_non_tool_response_item( +pub(crate) fn handle_non_tool_response_item( item: &ResponseItem, plan_mode: bool, ) -> Option { @@ -286,13 +314,12 @@ mod tests { } } - #[tokio::test] - async fn handle_non_tool_response_item_strips_citations_from_assistant_message() { + #[test] + fn handle_non_tool_response_item_strips_citations_from_assistant_message() { let item = assistant_output_text("hellodoc1 world"); - let turn_item = handle_non_tool_response_item(&item, false) - .await - .expect("assistant message should parse"); + let turn_item = + handle_non_tool_response_item(&item, false).expect("assistant message should parse"); let TurnItem::AgentMessage(agent_message) = turn_item else { panic!("expected agent message"); diff --git a/codex-rs/state/migrations/0016_memory_usage.sql b/codex-rs/state/migrations/0016_memory_usage.sql new file mode 100644 index 000000000..1067ab2e7 --- /dev/null +++ b/codex-rs/state/migrations/0016_memory_usage.sql @@ -0,0 +1,2 @@ +ALTER TABLE stage1_outputs ADD COLUMN usage_count INTEGER; +ALTER TABLE stage1_outputs ADD COLUMN last_usage INTEGER; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index 513345fbb..f0c623050 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -3124,6 +3124,105 @@ VALUES (?, ?, ?, ?, ?) let _ = tokio::fs::remove_dir_all(codex_home).await; } + #[tokio::test] + async fn record_stage1_output_usage_updates_usage_metadata() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a"); + let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b"); + let missing = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("missing id"); + let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id"); + + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_a, + codex_home.join("workspace-a"), + )) + .await + .expect("upsert thread a"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + thread_b, + codex_home.join("workspace-b"), + )) + .await + .expect("upsert thread b"); + + let claim_a = runtime + .try_claim_stage1_job(thread_a, owner, 100, 3600, 64) + .await + .expect("claim stage1 a"); + let token_a = match claim_a { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for a: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded(thread_a, token_a.as_str(), 100, "raw a", "sum a", None) + .await + .expect("mark stage1 succeeded a") + ); + + let claim_b = runtime + .try_claim_stage1_job(thread_b, owner, 101, 3600, 64) + .await + .expect("claim stage1 b"); + let token_b = match claim_b { + Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token, + other => panic!("unexpected stage1 claim outcome for b: {other:?}"), + }; + assert!( + runtime + .mark_stage1_job_succeeded(thread_b, token_b.as_str(), 101, "raw b", "sum b", None) + .await + .expect("mark stage1 succeeded b") + ); + + let updated_rows = runtime + .record_stage1_output_usage(&[thread_a, thread_a, thread_b, missing]) + .await + .expect("record stage1 output usage"); + assert_eq!(updated_rows, 3); + + let row_a = + sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_a.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load stage1 usage row a"); + let row_b = + sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?") + .bind(thread_b.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("load stage1 usage row b"); + + assert_eq!( + row_a + .try_get::("usage_count") + .expect("usage_count a"), + 2 + ); + assert_eq!( + row_b + .try_get::("usage_count") + .expect("usage_count b"), + 1 + ); + + let last_usage_a = row_a.try_get::("last_usage").expect("last_usage a"); + let last_usage_b = row_b.try_get::("last_usage").expect("last_usage b"); + assert_eq!(last_usage_a, last_usage_b); + assert!(last_usage_a > 0); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + #[tokio::test] async fn mark_stage1_job_succeeded_enqueues_global_consolidation() { let codex_home = unique_temp_dir(); diff --git a/codex-rs/state/src/runtime/memories.rs b/codex-rs/state/src/runtime/memories.rs index 1899aec8b..6d5ffc888 100644 --- a/codex-rs/state/src/runtime/memories.rs +++ b/codex-rs/state/src/runtime/memories.rs @@ -49,6 +49,43 @@ WHERE kind = ? OR kind = ? Ok(()) } + /// Record usage for cited stage-1 outputs. + /// + /// Each thread id increments `usage_count` by one and sets `last_usage` to + /// the current Unix timestamp. Missing rows are ignored. + pub async fn record_stage1_output_usage( + &self, + thread_ids: &[ThreadId], + ) -> anyhow::Result { + if thread_ids.is_empty() { + return Ok(0); + } + + let now = Utc::now().timestamp(); + let mut tx = self.pool.begin().await?; + let mut updated_rows = 0; + + for thread_id in thread_ids { + updated_rows += sqlx::query( + r#" +UPDATE stage1_outputs +SET + usage_count = COALESCE(usage_count, 0) + 1, + last_usage = ? +WHERE thread_id = ? + "#, + ) + .bind(now) + .bind(thread_id.to_string()) + .execute(&mut *tx) + .await? + .rows_affected() as usize; + } + + tx.commit().await?; + Ok(updated_rows) + } + /// Selects and claims stage-1 startup jobs for stale threads. /// /// Query behavior: