diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 7c9909355..ba84b6b43 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1051,6 +1051,11 @@ impl Session { state.get_total_token_usage(state.server_reasoning_included()) } + async fn get_estimated_token_count(&self, turn_context: &TurnContext) -> Option { + let state = self.state.lock().await; + state.history.estimate_token_count(turn_context) + } + pub(crate) async fn get_base_instructions(&self) -> BaseInstructions { let state = self.state.lock().await; BaseInstructions { @@ -3310,6 +3315,7 @@ pub(crate) async fn run_turn( let model_info = turn_context.client.get_model_info(); let auto_compact_limit = model_info.auto_compact_token_limit().unwrap_or(i64::MAX); let total_usage_tokens = sess.get_total_token_usage().await; + let event = EventMsg::TurnStarted(TurnStartedEvent { model_context_window: turn_context.client.get_model_context_window(), collaboration_mode_kind: turn_context.collaboration_mode.mode, @@ -3465,6 +3471,19 @@ pub(crate) async fn run_turn( let total_usage_tokens = sess.get_total_token_usage().await; let token_limit_reached = total_usage_tokens >= auto_compact_limit; + let estimated_token_count = + sess.get_estimated_token_count(turn_context.as_ref()).await; + + info!( + turn_id = %turn_context.sub_id, + total_usage_tokens, + estimated_token_count = ?estimated_token_count, + auto_compact_limit, + token_limit_reached, + needs_follow_up, + "post sampling token usage" + ); + // as long as compaction works well in getting us way below the token limit, we shouldn't worry about being in an infinite loop. if token_limit_reached && needs_follow_up { run_auto_compact(&sess, &turn_context).await; diff --git a/codex-rs/core/src/compact_remote.rs b/codex-rs/core/src/compact_remote.rs index 9f7a8dfea..12bc769ce 100644 --- a/codex-rs/core/src/compact_remote.rs +++ b/codex-rs/core/src/compact_remote.rs @@ -3,6 +3,8 @@ use std::sync::Arc; use crate::Prompt; use crate::codex::Session; use crate::codex::TurnContext; +use crate::context_manager::ContextManager; +use crate::context_manager::is_codex_generated_item; use crate::error::Result as CodexResult; use crate::protocol::CompactedItem; use crate::protocol::EventMsg; @@ -11,6 +13,7 @@ use crate::protocol::TurnStartedEvent; use codex_protocol::items::ContextCompactionItem; use codex_protocol::items::TurnItem; use codex_protocol::models::ResponseItem; +use tracing::info; pub(crate) async fn run_inline_remote_auto_compact_task( sess: Arc, @@ -45,7 +48,16 @@ async fn run_remote_compact_task_inner_impl( let compaction_item = TurnItem::ContextCompaction(ContextCompactionItem::new()); sess.emit_turn_item_started(turn_context, &compaction_item) .await; - let history = sess.clone_history().await; + let mut history = sess.clone_history().await; + let deleted_items = + trim_function_call_history_to_fit_context_window(&mut history, turn_context.as_ref()); + if deleted_items > 0 { + info!( + turn_id = %turn_context.sub_id, + deleted_items, + "trimmed history items before remote compaction" + ); + } // Required to keep `/undo` available after compaction let ghost_snapshots: Vec = history @@ -86,3 +98,31 @@ async fn run_remote_compact_task_inner_impl( .await; Ok(()) } + +fn trim_function_call_history_to_fit_context_window( + history: &mut ContextManager, + turn_context: &TurnContext, +) -> usize { + let mut deleted_items = 0usize; + let Some(context_window) = turn_context.client.get_model_context_window() else { + return deleted_items; + }; + + while history + .estimate_token_count(turn_context) + .is_some_and(|estimated_tokens| estimated_tokens > context_window) + { + let Some(last_item) = history.raw_items().last() else { + break; + }; + if !is_codex_generated_item(last_item) { + break; + } + if !history.remove_last_item() { + break; + } + deleted_items += 1; + } + + deleted_items +} diff --git a/codex-rs/core/src/context_manager/history.rs b/codex-rs/core/src/context_manager/history.rs index 92dcef5ba..a29f7df7e 100644 --- a/codex-rs/core/src/context_manager/history.rs +++ b/codex-rs/core/src/context_manager/history.rs @@ -93,24 +93,7 @@ impl ContextManager { let base_tokens = i64::try_from(approx_token_count(&base_instructions)).unwrap_or(i64::MAX); let items_tokens = self.items.iter().fold(0i64, |acc, item| { - acc + match item { - ResponseItem::GhostSnapshot { .. } => 0, - ResponseItem::Reasoning { - encrypted_content: Some(content), - .. - } - | ResponseItem::Compaction { - encrypted_content: content, - } => { - let reasoning_bytes = estimate_reasoning_length(content.len()); - i64::try_from(approx_tokens_from_byte_count(reasoning_bytes)) - .unwrap_or(i64::MAX) - } - item => { - let serialized = serde_json::to_string(item).unwrap_or_default(); - i64::try_from(approx_token_count(&serialized)).unwrap_or(i64::MAX) - } - } + acc.saturating_add(estimate_item_token_count(item)) }); Some(base_tokens.saturating_add(items_tokens)) @@ -128,6 +111,15 @@ impl ContextManager { } } + pub(crate) fn remove_last_item(&mut self) -> bool { + if let Some(removed) = self.items.pop() { + normalize::remove_corresponding_for(&mut self.items, &removed); + true + } else { + false + } + } + pub(crate) fn replace(&mut self, items: Vec) { self.items = items; } @@ -207,36 +199,42 @@ impl ContextManager { ); } - fn get_non_last_reasoning_items_tokens(&self) -> usize { - // get reasoning items excluding all the ones after the last user message + fn get_non_last_reasoning_items_tokens(&self) -> i64 { + // Get reasoning items excluding all the ones after the last user message. let Some(last_user_index) = self .items .iter() .rposition(|item| matches!(item, ResponseItem::Message { role, .. } if role == "user")) else { - return 0usize; + return 0; }; - let total_reasoning_bytes = self - .items + self.items .iter() .take(last_user_index) - .filter_map(|item| { - if let ResponseItem::Reasoning { - encrypted_content: Some(content), - .. - } = item - { - Some(content.len()) - } else { - None - } + .filter(|item| { + matches!( + item, + ResponseItem::Reasoning { + encrypted_content: Some(_), + .. + } + ) }) - .map(estimate_reasoning_length) - .fold(0usize, usize::saturating_add); + .fold(0i64, |acc, item| { + acc.saturating_add(estimate_item_token_count(item)) + }) + } - let token_estimate = approx_tokens_from_byte_count(total_reasoning_bytes); - token_estimate as usize + fn get_trailing_codex_generated_items_tokens(&self) -> i64 { + let mut total = 0i64; + for item in self.items.iter().rev() { + if !is_codex_generated_item(item) { + break; + } + total = total.saturating_add(estimate_item_token_count(item)); + } + total } /// When true, the server already accounted for past reasoning tokens and @@ -247,10 +245,13 @@ impl ContextManager { .as_ref() .map(|info| info.last_token_usage.total_tokens) .unwrap_or(0); + let trailing_codex_generated_tokens = self.get_trailing_codex_generated_items_tokens(); if server_reasoning_included { - last_tokens + last_tokens.saturating_add(trailing_codex_generated_tokens) } else { - last_tokens.saturating_add(self.get_non_last_reasoning_items_tokens() as i64) + last_tokens + .saturating_add(self.get_non_last_reasoning_items_tokens()) + .saturating_add(trailing_codex_generated_tokens) } } @@ -332,6 +333,33 @@ fn estimate_reasoning_length(encoded_len: usize) -> usize { .saturating_sub(650) } +fn estimate_item_token_count(item: &ResponseItem) -> i64 { + match item { + ResponseItem::GhostSnapshot { .. } => 0, + ResponseItem::Reasoning { + encrypted_content: Some(content), + .. + } + | ResponseItem::Compaction { + encrypted_content: content, + } => { + let reasoning_bytes = estimate_reasoning_length(content.len()); + i64::try_from(approx_tokens_from_byte_count(reasoning_bytes)).unwrap_or(i64::MAX) + } + item => { + let serialized = serde_json::to_string(item).unwrap_or_default(); + i64::try_from(approx_token_count(&serialized)).unwrap_or(i64::MAX) + } + } +} + +pub(crate) fn is_codex_generated_item(item: &ResponseItem) -> bool { + matches!( + item, + ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } + ) || matches!(item, ResponseItem::Message { role, .. } if role == "developer") +} + pub(crate) fn is_user_turn_boundary(item: &ResponseItem) -> bool { let ResponseItem::Message { role, content, .. } = item else { return false; diff --git a/codex-rs/core/src/context_manager/history_tests.rs b/codex-rs/core/src/context_manager/history_tests.rs index d31f731a9..a6eba62f1 100644 --- a/codex-rs/core/src/context_manager/history_tests.rs +++ b/codex-rs/core/src/context_manager/history_tests.rs @@ -60,6 +60,23 @@ fn user_input_text_msg(text: &str) -> ResponseItem { } } +fn function_call_output(call_id: &str, content: &str) -> ResponseItem { + ResponseItem::FunctionCallOutput { + call_id: call_id.to_string(), + output: FunctionCallOutputPayload { + content: content.to_string(), + ..Default::default() + }, + } +} + +fn custom_tool_call_output(call_id: &str, output: &str) -> ResponseItem { + ResponseItem::CustomToolCallOutput { + call_id: call_id.to_string(), + output: output.to_string(), + } +} + fn reasoning_msg(text: &str) -> ResponseItem { ResponseItem::Reasoning { id: String::new(), @@ -168,6 +185,63 @@ fn non_last_reasoning_tokens_ignore_entries_after_last_user() { assert_eq!(history.get_non_last_reasoning_items_tokens(), 32); } +#[test] +fn trailing_codex_generated_tokens_stop_at_first_non_generated_item() { + let earlier_output = function_call_output("call-earlier", "earlier output"); + let trailing_function_output = function_call_output("call-tail-1", "tail function output"); + let trailing_custom_output = custom_tool_call_output("call-tail-2", "tail custom output"); + let history = create_history_with_items(vec![ + earlier_output, + user_msg("boundary item"), + trailing_function_output.clone(), + trailing_custom_output.clone(), + ]); + let expected_tokens = estimate_item_token_count(&trailing_function_output) + .saturating_add(estimate_item_token_count(&trailing_custom_output)); + + assert_eq!( + history.get_trailing_codex_generated_items_tokens(), + expected_tokens + ); +} + +#[test] +fn trailing_codex_generated_tokens_exclude_function_call_tail() { + let history = create_history_with_items(vec![ResponseItem::FunctionCall { + id: None, + name: "not-generated".to_string(), + arguments: "{}".to_string(), + call_id: "call-tail".to_string(), + }]); + + assert_eq!(history.get_trailing_codex_generated_items_tokens(), 0); +} + +#[test] +fn total_token_usage_includes_only_trailing_codex_generated_items() { + let non_trailing_output = function_call_output("call-before-message", "not trailing"); + let trailing_assistant = assistant_msg("assistant boundary"); + let trailing_output = custom_tool_call_output("tool-tail", "trailing output"); + let mut history = create_history_with_items(vec![ + non_trailing_output, + user_msg("boundary"), + trailing_assistant, + trailing_output.clone(), + ]); + history.update_token_info( + &TokenUsage { + total_tokens: 100, + ..Default::default() + }, + None, + ); + + assert_eq!( + history.get_total_token_usage(true), + 100 + estimate_item_token_count(&trailing_output) + ); +} + #[test] fn get_history_for_prompt_drops_ghost_commits() { let items = vec![ResponseItem::GhostSnapshot { @@ -222,6 +296,30 @@ fn remove_first_item_removes_matching_call_for_output() { assert_eq!(h.raw_items(), vec![]); } +#[test] +fn remove_last_item_removes_matching_call_for_output() { + let items = vec![ + user_msg("before tool call"), + ResponseItem::FunctionCall { + id: None, + name: "do_it".to_string(), + arguments: "{}".to_string(), + call_id: "call-delete-last".to_string(), + }, + ResponseItem::FunctionCallOutput { + call_id: "call-delete-last".to_string(), + output: FunctionCallOutputPayload { + content: "ok".to_string(), + ..Default::default() + }, + }, + ]; + let mut h = create_history_with_items(items); + + assert!(h.remove_last_item()); + assert_eq!(h.raw_items(), vec![user_msg("before tool call")]); +} + #[test] fn replace_last_turn_images_replaces_tool_output_images() { let items = vec![ diff --git a/codex-rs/core/src/context_manager/mod.rs b/codex-rs/core/src/context_manager/mod.rs index baae93c77..22e9682fe 100644 --- a/codex-rs/core/src/context_manager/mod.rs +++ b/codex-rs/core/src/context_manager/mod.rs @@ -2,4 +2,5 @@ mod history; mod normalize; pub(crate) use history::ContextManager; +pub(crate) use history::is_codex_generated_item; pub(crate) use history::is_user_turn_boundary; diff --git a/codex-rs/core/tests/suite/compact_remote.rs b/codex-rs/core/tests/suite/compact_remote.rs index e5446f4a3..55985c9b7 100644 --- a/codex-rs/core/tests/suite/compact_remote.rs +++ b/codex-rs/core/tests/suite/compact_remote.rs @@ -222,6 +222,134 @@ async fn remote_compact_runs_automatically() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn remote_compact_trims_function_call_history_to_fit_context_window() -> Result<()> { + skip_if_no_network!(Ok(())); + + let first_user_message = "turn with retained shell call"; + let second_user_message = "turn with trimmed shell call"; + let retained_call_id = "retained-call"; + let trimmed_call_id = "trimmed-call"; + let retained_command = "echo retained-shell-output"; + let trimmed_command = "yes x | head -n 3000"; + + let harness = TestCodexHarness::with_builder( + test_codex() + .with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) + .with_config(|config| { + config.features.enable(Feature::RemoteCompaction); + config.model_context_window = Some(2_000); + }), + ) + .await?; + let codex = harness.test().codex.clone(); + + let response_log = responses::mount_sse_sequence( + harness.server(), + vec![ + sse(vec![ + responses::ev_shell_command_call(retained_call_id, retained_command), + responses::ev_completed("retained-call-response"), + ]), + sse(vec![ + responses::ev_assistant_message("retained-assistant", "retained complete"), + responses::ev_completed("retained-final-response"), + ]), + sse(vec![ + responses::ev_shell_command_call(trimmed_call_id, trimmed_command), + responses::ev_completed("trimmed-call-response"), + ]), + sse(vec![responses::ev_completed("trimmed-final-response")]), + ], + ) + .await; + + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: first_user_message.into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; + + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: second_user_message.into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; + + let compact_mock = + responses::mount_compact_json_once(harness.server(), serde_json::json!({ "output": [] })) + .await; + + codex.submit(Op::Compact).await?; + wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await; + + assert!( + response_log + .function_call_output_text(retained_call_id) + .is_some(), + "expected retained shell call to produce function_call_output before compaction" + ); + assert!( + response_log + .function_call_output_text(trimmed_call_id) + .is_some(), + "expected trimmed shell call to produce function_call_output before compaction" + ); + + let compact_request = compact_mock.single_request(); + let user_messages = compact_request.message_input_texts("user"); + assert!( + user_messages + .iter() + .any(|message| message == first_user_message), + "expected compact request to retain earlier user history" + ); + assert!( + user_messages + .iter() + .any(|message| message == second_user_message), + "expected compact request to retain the user boundary message" + ); + + assert!( + compact_request.has_function_call(retained_call_id) + && compact_request + .function_call_output_text(retained_call_id) + .is_some(), + "expected compact request to keep the older function call/result pair" + ); + assert!( + !compact_request.has_function_call(trimmed_call_id) + && compact_request + .function_call_output_text(trimmed_call_id) + .is_none(), + "expected compact request to drop the trailing function call/result pair past the boundary" + ); + + assert_eq!( + compact_request.inputs_of_type("function_call").len(), + 1, + "expected exactly one function call after trimming" + ); + assert_eq!( + compact_request.inputs_of_type("function_call_output").len(), + 1, + "expected exactly one function call output after trimming" + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn remote_manual_compact_emits_context_compaction_items() -> Result<()> { skip_if_no_network!(Ok(()));