feat: record memory usage (#12761)
This commit is contained in:
parent
5441130e0a
commit
e4bfa763f6
7 changed files with 254 additions and 24 deletions
|
|
@ -43,6 +43,7 @@ use crate::stream_events_utils::handle_non_tool_response_item;
|
|||
use crate::stream_events_utils::handle_output_item_done;
|
||||
use crate::stream_events_utils::last_assistant_message_from_item;
|
||||
use crate::stream_events_utils::raw_assistant_output_text_from_item;
|
||||
use crate::stream_events_utils::record_completed_response_item;
|
||||
use crate::terminal;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
use crate::turn_metadata::TurnMetadataState;
|
||||
|
|
@ -6068,7 +6069,7 @@ async fn handle_assistant_item_done_in_plan_mode(
|
|||
{
|
||||
maybe_complete_plan_item_from_message(sess, turn_context, state, item).await;
|
||||
|
||||
if let Some(turn_item) = handle_non_tool_response_item(item, true).await {
|
||||
if let Some(turn_item) = handle_non_tool_response_item(item, true) {
|
||||
emit_turn_item_in_plan_mode(
|
||||
sess,
|
||||
turn_context,
|
||||
|
|
@ -6079,8 +6080,7 @@ async fn handle_assistant_item_done_in_plan_mode(
|
|||
.await;
|
||||
}
|
||||
|
||||
sess.record_conversation_items(turn_context, std::slice::from_ref(item))
|
||||
.await;
|
||||
record_completed_response_item(sess, turn_context, item).await;
|
||||
if let Some(agent_message) = last_assistant_message_from_item(item, true) {
|
||||
*last_agent_message = Some(agent_message);
|
||||
}
|
||||
|
|
@ -6254,7 +6254,7 @@ async fn try_run_sampling_request(
|
|||
needs_follow_up |= output_result.needs_follow_up;
|
||||
}
|
||||
ResponseEvent::OutputItemAdded(item) => {
|
||||
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await {
|
||||
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) {
|
||||
let mut turn_item = turn_item;
|
||||
let mut seeded_parsed: Option<ParsedAssistantTextDelta> = None;
|
||||
let mut seeded_item_id: Option<String> = None;
|
||||
|
|
@ -6453,10 +6453,12 @@ async fn try_run_sampling_request(
|
|||
}
|
||||
|
||||
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
|
||||
responses
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|item| last_assistant_message_from_item(item, false))
|
||||
for item in responses.iter().rev() {
|
||||
if let Some(message) = last_assistant_message_from_item(item, false) {
|
||||
return Some(message);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
use crate::memories::prompts::build_memory_tool_developer_instructions;
|
||||
|
|
|
|||
62
codex-rs/core/src/memories/citations.rs
Normal file
62
codex-rs/core/src/memories/citations.rs
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
use codex_protocol::ThreadId;
|
||||
|
||||
pub fn get_thread_id_from_citations(citations: Vec<String>) -> Vec<ThreadId> {
|
||||
let mut result = Vec::new();
|
||||
for citation in citations {
|
||||
let mut ids_block = None;
|
||||
for (open, close) in [
|
||||
("<thread_ids>", "</thread_ids>"),
|
||||
("<rollout_ids>", "</rollout_ids>"),
|
||||
] {
|
||||
if let Some((_, rest)) = citation.split_once(open)
|
||||
&& let Some((ids, _)) = rest.split_once(close)
|
||||
{
|
||||
ids_block = Some(ids);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ids_block) = ids_block {
|
||||
for id in ids_block
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|line| !line.is_empty())
|
||||
{
|
||||
if let Ok(thread_id) = ThreadId::try_from(id) {
|
||||
result.push(thread_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::get_thread_id_from_citations;
    use codex_protocol::ThreadId;
    use pretty_assertions::assert_eq;

    /// Ids inside `<thread_ids>` come back in order; unparseable lines are dropped.
    #[test]
    fn get_thread_id_from_citations_extracts_thread_ids() {
        let first = ThreadId::new();
        let second = ThreadId::new();

        let citations = vec![format!(
            "<memory_citation>\n<citation_entries>\nMEMORY.md:1-2|note=[x]\n</citation_entries>\n<thread_ids>\n{first}\nnot-a-uuid\n{second}\n</thread_ids>\n</memory_citation>"
        )];

        assert_eq!(get_thread_id_from_citations(citations), vec![first, second]);
    }

    /// The older `<rollout_ids>` tag is still understood.
    #[test]
    fn get_thread_id_from_citations_supports_legacy_rollout_ids() {
        let thread_id = ThreadId::new();

        let citations = vec![format!(
            "<memory_citation>\n<rollout_ids>\n{thread_id}\n</rollout_ids>\n</memory_citation>"
        )];

        assert_eq!(get_thread_id_from_citations(citations), vec![thread_id]);
    }

    /// When both blocks are complete, only `<thread_ids>` is read.
    #[test]
    fn get_thread_id_from_citations_prefers_thread_ids_block() {
        let preferred = ThreadId::new();
        let ignored = ThreadId::new();

        let citations = vec![format!(
            "<thread_ids>\n{preferred}\n</thread_ids>\n<rollout_ids>\n{ignored}\n</rollout_ids>"
        )];

        assert_eq!(get_thread_id_from_citations(citations), vec![preferred]);
    }

    /// An unterminated `<thread_ids>` block falls through to `<rollout_ids>`.
    #[test]
    fn get_thread_id_from_citations_falls_back_when_thread_ids_unterminated() {
        let dangling = ThreadId::new();
        let fallback = ThreadId::new();

        let citations = vec![format!(
            "<thread_ids>\n{dangling}\n<rollout_ids>\n{fallback}\n</rollout_ids>"
        )];

        assert_eq!(get_thread_id_from_citations(citations), vec![fallback]);
    }

    /// Citations without any id block contribute nothing.
    #[test]
    fn get_thread_id_from_citations_ignores_citations_without_id_block() {
        let citations = vec![
            "<memory_citation>\n<citation_entries>\nMEMORY.md:1-2|note=[x]\n</citation_entries>\n</memory_citation>"
                .to_string(),
            String::new(),
        ];

        assert_eq!(
            get_thread_id_from_citations(citations),
            Vec::<ThreadId>::new()
        );
    }
}
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
//! - Phase 1: select rollouts, extract stage-1 raw memories, persist stage-1 outputs, and enqueue consolidation.
|
||||
//! - Phase 2: claim a global consolidation lock, materialize consolidation inputs, and dispatch one consolidation agent.
|
||||
|
||||
pub(crate) mod citations;
|
||||
mod phase1;
|
||||
mod phase2;
|
||||
pub(crate) mod prompts;
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ use crate::codex::TurnContext;
|
|||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::memories::citations::get_thread_id_from_citations;
|
||||
use crate::parse_turn_item;
|
||||
use crate::state_db;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
|
|
@ -24,7 +26,7 @@ use tracing::debug;
|
|||
use tracing::instrument;
|
||||
|
||||
fn strip_hidden_assistant_markup(text: &str, plan_mode: bool) -> String {
|
||||
let (without_citations, _citations) = strip_citations(text);
|
||||
let (without_citations, _) = strip_citations(text);
|
||||
if plan_mode {
|
||||
strip_proposed_plan_blocks(&without_citations)
|
||||
} else {
|
||||
|
|
@ -48,6 +50,36 @@ pub(crate) fn raw_assistant_output_text_from_item(item: &ResponseItem) -> Option
|
|||
None
|
||||
}
|
||||
|
||||
/// Persist a completed model response item and record any cited memory usage.
|
||||
pub(crate) async fn record_completed_response_item(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
item: &ResponseItem,
|
||||
) {
|
||||
sess.record_conversation_items(turn_context, std::slice::from_ref(item))
|
||||
.await;
|
||||
record_stage1_output_usage_for_completed_item(turn_context, item).await;
|
||||
}
|
||||
|
||||
async fn record_stage1_output_usage_for_completed_item(
|
||||
turn_context: &TurnContext,
|
||||
item: &ResponseItem,
|
||||
) {
|
||||
let Some(raw_text) = raw_assistant_output_text_from_item(item) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (_, citations) = strip_citations(&raw_text);
|
||||
let thread_ids = get_thread_id_from_citations(citations);
|
||||
if thread_ids.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(db) = state_db::get_state_db(turn_context.config.as_ref(), None).await {
|
||||
let _ = db.record_stage1_output_usage(&thread_ids).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a completed output item from the model stream, recording it and
|
||||
/// queuing any tool execution futures. This records items immediately so
|
||||
/// history and rollout stay in sync even if the turn is later cancelled.
|
||||
|
|
@ -88,8 +120,7 @@ pub(crate) async fn handle_output_item_done(
|
|||
payload_preview
|
||||
);
|
||||
|
||||
ctx.sess
|
||||
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
|
||||
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
|
||||
.await;
|
||||
|
||||
let cancellation_token = ctx.cancellation_token.child_token();
|
||||
|
|
@ -104,7 +135,7 @@ pub(crate) async fn handle_output_item_done(
|
|||
}
|
||||
// No tool call: convert messages/reasoning into turn items and mark them as complete.
|
||||
Ok(None) => {
|
||||
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await {
|
||||
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode) {
|
||||
if previously_active_item.is_none() {
|
||||
ctx.sess
|
||||
.emit_turn_item_started(&ctx.turn_context, &turn_item)
|
||||
|
|
@ -116,8 +147,7 @@ pub(crate) async fn handle_output_item_done(
|
|||
.await;
|
||||
}
|
||||
|
||||
ctx.sess
|
||||
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
|
||||
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
|
||||
.await;
|
||||
let last_agent_message = last_assistant_message_from_item(&item, plan_mode);
|
||||
|
||||
|
|
@ -138,8 +168,7 @@ pub(crate) async fn handle_output_item_done(
|
|||
..Default::default()
|
||||
},
|
||||
};
|
||||
ctx.sess
|
||||
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
|
||||
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
|
||||
.await;
|
||||
if let Some(response_item) = response_input_to_response_item(&response) {
|
||||
ctx.sess
|
||||
|
|
@ -161,8 +190,7 @@ pub(crate) async fn handle_output_item_done(
|
|||
..Default::default()
|
||||
},
|
||||
};
|
||||
ctx.sess
|
||||
.record_conversation_items(&ctx.turn_context, std::slice::from_ref(&item))
|
||||
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
|
||||
.await;
|
||||
if let Some(response_item) = response_input_to_response_item(&response) {
|
||||
ctx.sess
|
||||
|
|
@ -184,7 +212,7 @@ pub(crate) async fn handle_output_item_done(
|
|||
Ok(output)
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_non_tool_response_item(
|
||||
pub(crate) fn handle_non_tool_response_item(
|
||||
item: &ResponseItem,
|
||||
plan_mode: bool,
|
||||
) -> Option<TurnItem> {
|
||||
|
|
@ -286,13 +314,12 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
|
||||
#[test]
|
||||
fn handle_non_tool_response_item_strips_citations_from_assistant_message() {
|
||||
let item = assistant_output_text("hello<oai-mem-citation>doc1</oai-mem-citation> world");
|
||||
|
||||
let turn_item = handle_non_tool_response_item(&item, false)
|
||||
.await
|
||||
.expect("assistant message should parse");
|
||||
let turn_item =
|
||||
handle_non_tool_response_item(&item, false).expect("assistant message should parse");
|
||||
|
||||
let TurnItem::AgentMessage(agent_message) = turn_item else {
|
||||
panic!("expected agent message");
|
||||
|
|
|
|||
2
codex-rs/state/migrations/0016_memory_usage.sql
Normal file
2
codex-rs/state/migrations/0016_memory_usage.sql
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
ALTER TABLE stage1_outputs ADD COLUMN usage_count INTEGER;
|
||||
ALTER TABLE stage1_outputs ADD COLUMN last_usage INTEGER;
|
||||
|
|
@ -3124,6 +3124,105 @@ VALUES (?, ?, ?, ?, ?)
|
|||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_stage1_output_usage_updates_usage_metadata() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id a");
|
||||
let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id b");
|
||||
let missing = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("missing id");
|
||||
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
|
||||
|
||||
runtime
|
||||
.upsert_thread(&test_thread_metadata(
|
||||
&codex_home,
|
||||
thread_a,
|
||||
codex_home.join("workspace-a"),
|
||||
))
|
||||
.await
|
||||
.expect("upsert thread a");
|
||||
runtime
|
||||
.upsert_thread(&test_thread_metadata(
|
||||
&codex_home,
|
||||
thread_b,
|
||||
codex_home.join("workspace-b"),
|
||||
))
|
||||
.await
|
||||
.expect("upsert thread b");
|
||||
|
||||
let claim_a = runtime
|
||||
.try_claim_stage1_job(thread_a, owner, 100, 3600, 64)
|
||||
.await
|
||||
.expect("claim stage1 a");
|
||||
let token_a = match claim_a {
|
||||
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected stage1 claim outcome for a: {other:?}"),
|
||||
};
|
||||
assert!(
|
||||
runtime
|
||||
.mark_stage1_job_succeeded(thread_a, token_a.as_str(), 100, "raw a", "sum a", None)
|
||||
.await
|
||||
.expect("mark stage1 succeeded a")
|
||||
);
|
||||
|
||||
let claim_b = runtime
|
||||
.try_claim_stage1_job(thread_b, owner, 101, 3600, 64)
|
||||
.await
|
||||
.expect("claim stage1 b");
|
||||
let token_b = match claim_b {
|
||||
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
|
||||
other => panic!("unexpected stage1 claim outcome for b: {other:?}"),
|
||||
};
|
||||
assert!(
|
||||
runtime
|
||||
.mark_stage1_job_succeeded(thread_b, token_b.as_str(), 101, "raw b", "sum b", None)
|
||||
.await
|
||||
.expect("mark stage1 succeeded b")
|
||||
);
|
||||
|
||||
let updated_rows = runtime
|
||||
.record_stage1_output_usage(&[thread_a, thread_a, thread_b, missing])
|
||||
.await
|
||||
.expect("record stage1 output usage");
|
||||
assert_eq!(updated_rows, 3);
|
||||
|
||||
let row_a =
|
||||
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
|
||||
.bind(thread_a.to_string())
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("load stage1 usage row a");
|
||||
let row_b =
|
||||
sqlx::query("SELECT usage_count, last_usage FROM stage1_outputs WHERE thread_id = ?")
|
||||
.bind(thread_b.to_string())
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("load stage1 usage row b");
|
||||
|
||||
assert_eq!(
|
||||
row_a
|
||||
.try_get::<i64, _>("usage_count")
|
||||
.expect("usage_count a"),
|
||||
2
|
||||
);
|
||||
assert_eq!(
|
||||
row_b
|
||||
.try_get::<i64, _>("usage_count")
|
||||
.expect("usage_count b"),
|
||||
1
|
||||
);
|
||||
|
||||
let last_usage_a = row_a.try_get::<i64, _>("last_usage").expect("last_usage a");
|
||||
let last_usage_b = row_b.try_get::<i64, _>("last_usage").expect("last_usage b");
|
||||
assert_eq!(last_usage_a, last_usage_b);
|
||||
assert!(last_usage_a > 0);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mark_stage1_job_succeeded_enqueues_global_consolidation() {
|
||||
let codex_home = unique_temp_dir();
|
||||
|
|
|
|||
|
|
@ -49,6 +49,43 @@ WHERE kind = ? OR kind = ?
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Record usage for cited stage-1 outputs.
|
||||
///
|
||||
/// Each thread id increments `usage_count` by one and sets `last_usage` to
|
||||
/// the current Unix timestamp. Missing rows are ignored.
|
||||
pub async fn record_stage1_output_usage(
|
||||
&self,
|
||||
thread_ids: &[ThreadId],
|
||||
) -> anyhow::Result<usize> {
|
||||
if thread_ids.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let now = Utc::now().timestamp();
|
||||
let mut tx = self.pool.begin().await?;
|
||||
let mut updated_rows = 0;
|
||||
|
||||
for thread_id in thread_ids {
|
||||
updated_rows += sqlx::query(
|
||||
r#"
|
||||
UPDATE stage1_outputs
|
||||
SET
|
||||
usage_count = COALESCE(usage_count, 0) + 1,
|
||||
last_usage = ?
|
||||
WHERE thread_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(now)
|
||||
.bind(thread_id.to_string())
|
||||
.execute(&mut *tx)
|
||||
.await?
|
||||
.rows_affected() as usize;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(updated_rows)
|
||||
}
|
||||
|
||||
/// Selects and claims stage-1 startup jobs for stale threads.
|
||||
///
|
||||
/// Query behavior:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue