Trim compaction input (#10374)

Two fixes:

1. Include trailing tool output in the total context size calculation.
Otherwise, when checking whether compaction should run, we would ignore
newly added outputs.
2. Trim trailing tool outputs/tool calls until the request fits within
the model context size. Otherwise the compaction endpoint will fail to
compact. We only trim items that the model can reproduce on its own
(tool calls, tool call outputs).
This commit is contained in:
pakrym-oai 2026-02-02 19:03:11 -08:00 committed by GitHub
parent 7e07ec8f73
commit cbfd2a37cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 354 additions and 40 deletions

View file

@ -1051,6 +1051,11 @@ impl Session {
state.get_total_token_usage(state.server_reasoning_included())
}
/// Returns the history's estimated token count, or `None` when no
/// estimate is available.
async fn get_estimated_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
    self.state
        .lock()
        .await
        .history
        .estimate_token_count(turn_context)
}
pub(crate) async fn get_base_instructions(&self) -> BaseInstructions {
let state = self.state.lock().await;
BaseInstructions {
@ -3310,6 +3315,7 @@ pub(crate) async fn run_turn(
let model_info = turn_context.client.get_model_info();
let auto_compact_limit = model_info.auto_compact_token_limit().unwrap_or(i64::MAX);
let total_usage_tokens = sess.get_total_token_usage().await;
let event = EventMsg::TurnStarted(TurnStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
collaboration_mode_kind: turn_context.collaboration_mode.mode,
@ -3465,6 +3471,19 @@ pub(crate) async fn run_turn(
let total_usage_tokens = sess.get_total_token_usage().await;
let token_limit_reached = total_usage_tokens >= auto_compact_limit;
let estimated_token_count =
sess.get_estimated_token_count(turn_context.as_ref()).await;
info!(
turn_id = %turn_context.sub_id,
total_usage_tokens,
estimated_token_count = ?estimated_token_count,
auto_compact_limit,
token_limit_reached,
needs_follow_up,
"post sampling token usage"
);
// as long as compaction works well in getting us way below the token limit, we shouldn't worry about being in an infinite loop.
if token_limit_reached && needs_follow_up {
run_auto_compact(&sess, &turn_context).await;

View file

@ -3,6 +3,8 @@ use std::sync::Arc;
use crate::Prompt;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::context_manager::ContextManager;
use crate::context_manager::is_codex_generated_item;
use crate::error::Result as CodexResult;
use crate::protocol::CompactedItem;
use crate::protocol::EventMsg;
@ -11,6 +13,7 @@ use crate::protocol::TurnStartedEvent;
use codex_protocol::items::ContextCompactionItem;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ResponseItem;
use tracing::info;
pub(crate) async fn run_inline_remote_auto_compact_task(
sess: Arc<Session>,
@ -45,7 +48,16 @@ async fn run_remote_compact_task_inner_impl(
let compaction_item = TurnItem::ContextCompaction(ContextCompactionItem::new());
sess.emit_turn_item_started(turn_context, &compaction_item)
.await;
let history = sess.clone_history().await;
let mut history = sess.clone_history().await;
let deleted_items =
trim_function_call_history_to_fit_context_window(&mut history, turn_context.as_ref());
if deleted_items > 0 {
info!(
turn_id = %turn_context.sub_id,
deleted_items,
"trimmed history items before remote compaction"
);
}
// Required to keep `/undo` available after compaction
let ghost_snapshots: Vec<ResponseItem> = history
@ -86,3 +98,31 @@ async fn run_remote_compact_task_inner_impl(
.await;
Ok(())
}
/// Drops trailing codex-generated items from `history` until the estimated
/// token count fits the model context window, returning how many items were
/// removed. Trimming stops at the first trailing item the model did not
/// generate, so user/assistant content is never discarded. If the model has
/// no known context window, nothing is trimmed.
fn trim_function_call_history_to_fit_context_window(
    history: &mut ContextManager,
    turn_context: &TurnContext,
) -> usize {
    let Some(context_window) = turn_context.client.get_model_context_window() else {
        // Without a window size there is nothing to fit against.
        return 0;
    };
    let mut deleted_items = 0usize;
    loop {
        let over_budget = history
            .estimate_token_count(turn_context)
            .is_some_and(|estimated_tokens| estimated_tokens > context_window);
        if !over_budget {
            break;
        }
        let Some(last_item) = history.raw_items().last() else {
            break;
        };
        // Only items the model can reproduce (tool outputs, developer
        // messages) are safe to drop; bail out otherwise, or if removal fails.
        if !is_codex_generated_item(last_item) || !history.remove_last_item() {
            break;
        }
        deleted_items += 1;
    }
    deleted_items
}

View file

@ -93,24 +93,7 @@ impl ContextManager {
let base_tokens = i64::try_from(approx_token_count(&base_instructions)).unwrap_or(i64::MAX);
let items_tokens = self.items.iter().fold(0i64, |acc, item| {
acc + match item {
ResponseItem::GhostSnapshot { .. } => 0,
ResponseItem::Reasoning {
encrypted_content: Some(content),
..
}
| ResponseItem::Compaction {
encrypted_content: content,
} => {
let reasoning_bytes = estimate_reasoning_length(content.len());
i64::try_from(approx_tokens_from_byte_count(reasoning_bytes))
.unwrap_or(i64::MAX)
}
item => {
let serialized = serde_json::to_string(item).unwrap_or_default();
i64::try_from(approx_token_count(&serialized)).unwrap_or(i64::MAX)
}
}
acc.saturating_add(estimate_item_token_count(item))
});
Some(base_tokens.saturating_add(items_tokens))
@ -128,6 +111,15 @@ impl ContextManager {
}
}
/// Pops the most recent history item, also dropping any counterpart entry
/// (e.g. the function call paired with a removed output) via
/// `normalize::remove_corresponding_for`. Returns whether an item was removed.
pub(crate) fn remove_last_item(&mut self) -> bool {
    match self.items.pop() {
        None => false,
        Some(removed) => {
            normalize::remove_corresponding_for(&mut self.items, &removed);
            true
        }
    }
}
/// Replaces the entire stored history with `items`.
pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
self.items = items;
}
@ -207,36 +199,42 @@ impl ContextManager {
);
}
fn get_non_last_reasoning_items_tokens(&self) -> usize {
// get reasoning items excluding all the ones after the last user message
fn get_non_last_reasoning_items_tokens(&self) -> i64 {
// Get reasoning items excluding all the ones after the last user message.
let Some(last_user_index) = self
.items
.iter()
.rposition(|item| matches!(item, ResponseItem::Message { role, .. } if role == "user"))
else {
return 0usize;
return 0;
};
let total_reasoning_bytes = self
.items
self.items
.iter()
.take(last_user_index)
.filter_map(|item| {
if let ResponseItem::Reasoning {
encrypted_content: Some(content),
..
} = item
{
Some(content.len())
} else {
None
}
.filter(|item| {
matches!(
item,
ResponseItem::Reasoning {
encrypted_content: Some(_),
..
}
)
})
.map(estimate_reasoning_length)
.fold(0usize, usize::saturating_add);
.fold(0i64, |acc, item| {
acc.saturating_add(estimate_item_token_count(item))
})
}
let token_estimate = approx_tokens_from_byte_count(total_reasoning_bytes);
token_estimate as usize
/// Sums the estimated token cost of the uninterrupted run of
/// codex-generated items at the tail of the history, stopping at the
/// first trailing item the model did not generate.
fn get_trailing_codex_generated_items_tokens(&self) -> i64 {
    self.items
        .iter()
        .rev()
        .take_while(|item| is_codex_generated_item(item))
        .fold(0i64, |acc, item| {
            acc.saturating_add(estimate_item_token_count(item))
        })
}
/// When true, the server already accounted for past reasoning tokens and
@ -247,10 +245,13 @@ impl ContextManager {
.as_ref()
.map(|info| info.last_token_usage.total_tokens)
.unwrap_or(0);
let trailing_codex_generated_tokens = self.get_trailing_codex_generated_items_tokens();
if server_reasoning_included {
last_tokens
last_tokens.saturating_add(trailing_codex_generated_tokens)
} else {
last_tokens.saturating_add(self.get_non_last_reasoning_items_tokens() as i64)
last_tokens
.saturating_add(self.get_non_last_reasoning_items_tokens())
.saturating_add(trailing_codex_generated_tokens)
}
}
@ -332,6 +333,33 @@ fn estimate_reasoning_length(encoded_len: usize) -> usize {
.saturating_sub(650)
}
/// Approximates how many tokens a single history item contributes.
/// Ghost snapshots cost nothing; encrypted reasoning/compaction payloads
/// are estimated from their encoded byte length; everything else is
/// estimated from its JSON serialization.
fn estimate_item_token_count(item: &ResponseItem) -> i64 {
    match item {
        ResponseItem::GhostSnapshot { .. } => 0,
        ResponseItem::Reasoning {
            encrypted_content: Some(content),
            ..
        }
        | ResponseItem::Compaction {
            encrypted_content: content,
        } => {
            // Opaque encrypted payloads: estimate from byte count rather
            // than serializing them as JSON.
            let byte_count = estimate_reasoning_length(content.len());
            i64::try_from(approx_tokens_from_byte_count(byte_count)).unwrap_or(i64::MAX)
        }
        other => {
            let json = serde_json::to_string(other).unwrap_or_default();
            i64::try_from(approx_token_count(&json)).unwrap_or(i64::MAX)
        }
    }
}
/// True for items codex produced itself and can reproduce if trimmed:
/// tool call outputs and developer-role messages.
pub(crate) fn is_codex_generated_item(item: &ResponseItem) -> bool {
    match item {
        ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => true,
        ResponseItem::Message { role, .. } => role == "developer",
        _ => false,
    }
}
pub(crate) fn is_user_turn_boundary(item: &ResponseItem) -> bool {
let ResponseItem::Message { role, content, .. } = item else {
return false;

View file

@ -60,6 +60,23 @@ fn user_input_text_msg(text: &str) -> ResponseItem {
}
}
/// Test helper: builds a `FunctionCallOutput` item with the given call id
/// and output text.
fn function_call_output(call_id: &str, content: &str) -> ResponseItem {
    let output = FunctionCallOutputPayload {
        content: content.to_string(),
        ..Default::default()
    };
    ResponseItem::FunctionCallOutput {
        call_id: call_id.to_string(),
        output,
    }
}
/// Test helper: builds a `CustomToolCallOutput` item with the given call id
/// and output text.
fn custom_tool_call_output(call_id: &str, output: &str) -> ResponseItem {
    ResponseItem::CustomToolCallOutput {
        output: output.to_string(),
        call_id: call_id.to_string(),
    }
}
fn reasoning_msg(text: &str) -> ResponseItem {
ResponseItem::Reasoning {
id: String::new(),
@ -168,6 +185,63 @@ fn non_last_reasoning_tokens_ignore_entries_after_last_user() {
assert_eq!(history.get_non_last_reasoning_items_tokens(), 32);
}
#[test]
fn trailing_codex_generated_tokens_stop_at_first_non_generated_item() {
    let tail_function = function_call_output("call-tail-1", "tail function output");
    let tail_custom = custom_tool_call_output("call-tail-2", "tail custom output");
    let history = create_history_with_items(vec![
        function_call_output("call-earlier", "earlier output"),
        user_msg("boundary item"),
        tail_function.clone(),
        tail_custom.clone(),
    ]);
    // Only the two outputs after the user message count toward the tail;
    // the earlier output is cut off by the user boundary.
    let expected = estimate_item_token_count(&tail_function)
        .saturating_add(estimate_item_token_count(&tail_custom));
    assert_eq!(history.get_trailing_codex_generated_items_tokens(), expected);
}
#[test]
fn trailing_codex_generated_tokens_exclude_function_call_tail() {
    // A bare FunctionCall (no output) is not a codex-generated item, so
    // the trailing sum must stop immediately and report zero.
    let call = ResponseItem::FunctionCall {
        id: None,
        name: "not-generated".to_string(),
        arguments: "{}".to_string(),
        call_id: "call-tail".to_string(),
    };
    let history = create_history_with_items(vec![call]);
    assert_eq!(history.get_trailing_codex_generated_items_tokens(), 0);
}
#[test]
fn total_token_usage_includes_only_trailing_codex_generated_items() {
    let trailing_output = custom_tool_call_output("tool-tail", "trailing output");
    let mut history = create_history_with_items(vec![
        function_call_output("call-before-message", "not trailing"),
        user_msg("boundary"),
        assistant_msg("assistant boundary"),
        trailing_output.clone(),
    ]);
    let usage = TokenUsage {
        total_tokens: 100,
        ..Default::default()
    };
    history.update_token_info(&usage, None);
    // Only the output past the assistant message is added on top of the
    // server-reported 100 tokens; the pre-boundary output is excluded.
    let expected = 100 + estimate_item_token_count(&trailing_output);
    assert_eq!(history.get_total_token_usage(true), expected);
}
#[test]
fn get_history_for_prompt_drops_ghost_commits() {
let items = vec![ResponseItem::GhostSnapshot {
@ -222,6 +296,30 @@ fn remove_first_item_removes_matching_call_for_output() {
assert_eq!(h.raw_items(), vec![]);
}
#[test]
fn remove_last_item_removes_matching_call_for_output() {
    let call = ResponseItem::FunctionCall {
        id: None,
        name: "do_it".to_string(),
        arguments: "{}".to_string(),
        call_id: "call-delete-last".to_string(),
    };
    let output = ResponseItem::FunctionCallOutput {
        call_id: "call-delete-last".to_string(),
        output: FunctionCallOutputPayload {
            content: "ok".to_string(),
            ..Default::default()
        },
    };
    let mut h = create_history_with_items(vec![user_msg("before tool call"), call, output]);
    // Popping the output must also drop its paired call, leaving only the
    // user message behind.
    assert!(h.remove_last_item());
    assert_eq!(h.raw_items(), vec![user_msg("before tool call")]);
}
#[test]
fn replace_last_turn_images_replaces_tool_output_images() {
let items = vec![

View file

@ -2,4 +2,5 @@ mod history;
mod normalize;
pub(crate) use history::ContextManager;
pub(crate) use history::is_codex_generated_item;
pub(crate) use history::is_user_turn_boundary;

View file

@ -222,6 +222,134 @@ async fn remote_compact_runs_automatically() -> Result<()> {
Ok(())
}
// End-to-end check that remote compaction trims the trailing tool
// call/output pair which would overflow the model context window, while
// keeping the earlier small call/output pair and all user messages.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_compact_trims_function_call_history_to_fit_context_window() -> Result<()> {
    skip_if_no_network!(Ok(()));
    let first_user_message = "turn with retained shell call";
    let second_user_message = "turn with trimmed shell call";
    let retained_call_id = "retained-call";
    let trimmed_call_id = "trimmed-call";
    let retained_command = "echo retained-shell-output";
    // Emits ~6000 bytes of output — presumably enough to exceed the
    // 2_000-token window configured below (TODO confirm against the
    // estimator's byte-per-token ratio).
    let trimmed_command = "yes x | head -n 3000";
    let harness = TestCodexHarness::with_builder(
        test_codex()
            .with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
            .with_config(|config| {
                // Remote compaction enabled + a deliberately small window so
                // the oversized output forces trimming.
                config.features.enable(Feature::RemoteCompaction);
                config.model_context_window = Some(2_000);
            }),
    )
    .await?;
    let codex = harness.test().codex.clone();
    // Scripted model responses: turn 1 issues the small retained shell
    // call then finishes; turn 2 issues the oversized trimmed shell call.
    let response_log = responses::mount_sse_sequence(
        harness.server(),
        vec![
            sse(vec![
                responses::ev_shell_command_call(retained_call_id, retained_command),
                responses::ev_completed("retained-call-response"),
            ]),
            sse(vec![
                responses::ev_assistant_message("retained-assistant", "retained complete"),
                responses::ev_completed("retained-final-response"),
            ]),
            sse(vec![
                responses::ev_shell_command_call(trimmed_call_id, trimmed_command),
                responses::ev_completed("trimmed-call-response"),
            ]),
            sse(vec![responses::ev_completed("trimmed-final-response")]),
        ],
    )
    .await;
    // Turn 1: retained shell call.
    codex
        .submit(Op::UserInput {
            items: vec![UserInput::Text {
                text: first_user_message.into(),
                text_elements: Vec::new(),
            }],
            final_output_json_schema: None,
        })
        .await?;
    wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await;
    // Turn 2: oversized shell call that should later be trimmed.
    codex
        .submit(Op::UserInput {
            items: vec![UserInput::Text {
                text: second_user_message.into(),
                text_elements: Vec::new(),
            }],
            final_output_json_schema: None,
        })
        .await?;
    wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await;
    // Capture the single request made to the compaction endpoint.
    let compact_mock =
        responses::mount_compact_json_once(harness.server(), serde_json::json!({ "output": [] }))
            .await;
    codex.submit(Op::Compact).await?;
    wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await;
    // Sanity: both shell calls actually executed and produced outputs
    // before compaction ran, so any absence below is due to trimming.
    assert!(
        response_log
            .function_call_output_text(retained_call_id)
            .is_some(),
        "expected retained shell call to produce function_call_output before compaction"
    );
    assert!(
        response_log
            .function_call_output_text(trimmed_call_id)
            .is_some(),
        "expected trimmed shell call to produce function_call_output before compaction"
    );
    let compact_request = compact_mock.single_request();
    let user_messages = compact_request.message_input_texts("user");
    // User messages are never trimmed — both must survive.
    assert!(
        user_messages
            .iter()
            .any(|message| message == first_user_message),
        "expected compact request to retain earlier user history"
    );
    assert!(
        user_messages
            .iter()
            .any(|message| message == second_user_message),
        "expected compact request to retain the user boundary message"
    );
    // The small pair fits the window and is kept.
    assert!(
        compact_request.has_function_call(retained_call_id)
            && compact_request
                .function_call_output_text(retained_call_id)
                .is_some(),
        "expected compact request to keep the older function call/result pair"
    );
    // The oversized trailing pair is dropped — output and its paired call.
    assert!(
        !compact_request.has_function_call(trimmed_call_id)
            && compact_request
                .function_call_output_text(trimmed_call_id)
                .is_none(),
        "expected compact request to drop the trailing function call/result pair past the boundary"
    );
    assert_eq!(
        compact_request.inputs_of_type("function_call").len(),
        1,
        "expected exactly one function call after trimming"
    );
    assert_eq!(
        compact_request.inputs_of_type("function_call_output").len(),
        1,
        "expected exactly one function call output after trimming"
    );
    Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_manual_compact_emits_context_compaction_items() -> Result<()> {
skip_if_no_network!(Ok(()));