diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs index 2118cf66e..9a7aab997 100644 --- a/codex-rs/codex-api/src/common.rs +++ b/codex-rs/codex-api/src/common.rs @@ -42,6 +42,10 @@ pub enum ResponseEvent { Created, OutputItemDone(ResponseItem), OutputItemAdded(ResponseItem), + /// Emitted when `X-Reasoning-Included: true` is present on the response, + /// meaning the server already accounted for past reasoning tokens and the + /// client should not re-estimate them. + ServerReasoningIncluded(bool), Completed { response_id: String, token_usage: Option, diff --git a/codex-rs/codex-api/src/endpoint/chat.rs b/codex-rs/codex-api/src/endpoint/chat.rs index cd830a09f..8fe1d2a52 100644 --- a/codex-rs/codex-api/src/endpoint/chat.rs +++ b/codex-rs/codex-api/src/endpoint/chat.rs @@ -157,6 +157,9 @@ impl Stream for AggregatedStream { return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); } + Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included)))) => { + return Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included)))); + } Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); } diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index 3c6cab74e..39e2f2fd0 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -29,18 +29,21 @@ use url::Url; type WsStream = WebSocketStream>; const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state"; +const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; pub struct ResponsesWebsocketConnection { stream: Arc>>, // TODO (pakrym): is this the right place for timeout? idle_timeout: Duration, + server_reasoning_included: bool, } impl ResponsesWebsocketConnection { - fn new(stream: WsStream, idle_timeout: Duration) -> Self { + fn new(stream: WsStream, idle_timeout: Duration, server_reasoning_included: bool) -> Self { Self { stream: Arc::new(Mutex::new(Some(stream))), idle_timeout, + server_reasoning_included, } } @@ -56,11 +59,17 @@ impl ResponsesWebsocketConnection { mpsc::channel::>(1600); let stream = Arc::clone(&self.stream); let idle_timeout = self.idle_timeout; + let server_reasoning_included = self.server_reasoning_included; let request_body = serde_json::to_value(&request).map_err(|err| { ApiError::Stream(format!("failed to encode websocket request: {err}")) })?; tokio::spawn(async move { + if server_reasoning_included { + let _ = tx_event + .send(Ok(ResponseEvent::ServerReasoningIncluded(true))) + .await; + } let mut guard = stream.lock().await; let Some(ws_stream) = guard.as_mut() else { let _ = tx_event @@ -111,10 +120,12 @@ impl ResponsesWebsocketClient { headers.extend(extra_headers); apply_auth_headers(&mut headers, &self.auth); - let stream = connect_websocket(ws_url, headers, turn_state).await?; + let (stream, server_reasoning_included) = + connect_websocket(ws_url, headers, turn_state).await?; Ok(ResponsesWebsocketConnection::new( stream, self.provider.stream_idle_timeout, + server_reasoning_included, )) } } @@ -137,7 +148,7 @@ async fn connect_websocket( url: Url, headers: HeaderMap, turn_state: Option>>, -) -> Result { +) -> Result<(WsStream, bool), ApiError> { let mut request = url .clone() .into_client_request() @@ -147,6 +158,7 @@ async fn connect_websocket( let (stream, response) = tokio_tungstenite::connect_async(request) .await .map_err(|err| map_ws_error(err, &url))?; + let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER); if let Some(turn_state) = turn_state && let Some(header_value) = response .headers() @@ -155,7 +167,7 @@ async fn connect_websocket( { let _ = turn_state.set(header_value.to_string()); } - Ok(stream) + Ok((stream, reasoning_included)) } fn map_ws_error(err: WsError, url: &Url) -> ApiError { diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs index a70111d98..f23975f8d 100644 --- a/codex-rs/codex-api/src/sse/responses.rs +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -25,6 +25,8 @@ use tokio_util::io::ReaderStream; use tracing::debug; use tracing::trace; +const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; + /// Streams SSE events from an on-disk fixture for tests. pub fn stream_from_fixture( path: impl AsRef, @@ -58,6 +60,10 @@ pub fn spawn_response_stream( .get("X-Models-Etag") .and_then(|v| v.to_str().ok()) .map(ToString::to_string); + let reasoning_included = stream_response + .headers + .get(X_REASONING_INCLUDED_HEADER) + .is_some(); if let Some(turn_state) = turn_state.as_ref() && let Some(header_value) = stream_response .headers @@ -74,6 +80,11 @@ pub fn spawn_response_stream( if let Some(etag) = models_etag { let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await; } + if reasoning_included { + let _ = tx_event + .send(Ok(ResponseEvent::ServerReasoningIncluded(true))) + .await; + } process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await; }); diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 3d48810ff..5ca971c12 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -809,7 +809,7 @@ impl Session { async fn get_total_token_usage(&self) -> i64 { let state = self.state.lock().await; - state.get_total_token_usage() + state.get_total_token_usage(state.server_reasoning_included()) } async fn record_initial_history(&self, conversation_history: InitialHistory) { @@ -1618,6 +1618,11 @@ impl Session { self.send_token_count_event(turn_context).await; } + pub(crate) async fn set_server_reasoning_included(&self, included: bool) { + let mut state = self.state.lock().await; + state.set_server_reasoning_included(included); + } + async fn send_token_count_event(&self, turn_context: &TurnContext) { let (info, rate_limits) = { let state = self.state.lock().await; @@ -3149,6 +3154,9 @@ async fn try_run_sampling_request( active_item = Some(tracked_item); } } + ResponseEvent::ServerReasoningIncluded(included) => { + sess.set_server_reasoning_included(included).await; + } ResponseEvent::RateLimits(snapshot) => { // Update internal state with latest rate limits, but defer sending until // token usage is available to avoid duplicate TokenCount events. diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index 250b91415..40e6284de 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -316,6 +316,9 @@ async fn drain_to_completed( sess.record_into_history(std::slice::from_ref(&item), turn_context) .await; } + Ok(ResponseEvent::ServerReasoningIncluded(included)) => { + sess.set_server_reasoning_included(included).await; + } Ok(ResponseEvent::RateLimits(snapshot)) => { sess.update_rate_limits(turn_context, snapshot).await; } diff --git a/codex-rs/core/src/context_manager/history.rs b/codex-rs/core/src/context_manager/history.rs index 0c133bdc2..4feeddc29 100644 --- a/codex-rs/core/src/context_manager/history.rs +++ b/codex-rs/core/src/context_manager/history.rs @@ -235,12 +235,19 @@ impl ContextManager { token_estimate as usize } - pub(crate) fn get_total_token_usage(&self) -> i64 { - self.token_info + /// When true, the server already accounted for past reasoning tokens and + /// the client should not re-estimate them. + pub(crate) fn get_total_token_usage(&self, server_reasoning_included: bool) -> i64 { + let last_tokens = self + .token_info .as_ref() .map(|info| info.last_token_usage.total_tokens) - .unwrap_or(0) - .saturating_add(self.get_non_last_reasoning_items_tokens() as i64) + .unwrap_or(0); + if server_reasoning_included { + last_tokens + } else { + last_tokens.saturating_add(self.get_non_last_reasoning_items_tokens() as i64) + } } /// This function enforces a couple of invariants on the in-memory history: diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index c61d18837..746396949 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -14,6 +14,7 @@ pub(crate) struct SessionState { pub(crate) session_configuration: SessionConfiguration, pub(crate) history: ContextManager, pub(crate) latest_rate_limits: Option, + pub(crate) server_reasoning_included: bool, } impl SessionState { @@ -24,6 +25,7 @@ impl SessionState { session_configuration, history, latest_rate_limits: None, + server_reasoning_included: false, } } @@ -78,8 +80,17 @@ impl SessionState { self.history.set_token_usage_full(context_window); } - pub(crate) fn get_total_token_usage(&self) -> i64 { - self.history.get_total_token_usage() + pub(crate) fn get_total_token_usage(&self, server_reasoning_included: bool) -> i64 { + self.history + .get_total_token_usage(server_reasoning_included) + } + + pub(crate) fn set_server_reasoning_included(&mut self, included: bool) { + self.server_reasoning_included = included; + } + + pub(crate) fn server_reasoning_included(&self) -> bool { + self.server_reasoning_included } } diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index f897d3ce8..cac0cd5da 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -30,6 +30,7 @@ impl SessionTask for RegularTask { ) -> Option { let sess = session.clone_session(); let run_turn_span = trace_span!("run_turn"); + sess.set_server_reasoning_included(false).await; sess.services .otel_manager .apply_traceparent_parent(&run_turn_span); diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index 9d39fb240..1532ac74d 100644 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -15,10 +15,12 @@ use codex_otel::OtelManager; use codex_protocol::ThreadId; use codex_protocol::config_types::ReasoningSummary; use core_test_support::load_default_config_for_test; +use core_test_support::responses::WebSocketConnectionConfig; use core_test_support::responses::WebSocketTestServer; use core_test_support::responses::ev_completed; use core_test_support::responses::ev_response_created; use core_test_support::responses::start_websocket_server; +use core_test_support::responses::start_websocket_server_with_headers; use core_test_support::skip_if_no_network; use futures::StreamExt; use pretty_assertions::assert_eq; @@ -60,6 +62,40 @@ async fn responses_websocket_streams_request() { server.shutdown().await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_websocket_emits_reasoning_included_event() { + skip_if_no_network!(); + + let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig { + requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]], + response_headers: vec![("X-Reasoning-Included".to_string(), "true".to_string())], + }]) + .await; + + let harness = websocket_harness(&server).await; + let mut session = harness.client.new_session(); + let prompt = prompt_with_input(vec![message_item("hello")]); + + let mut stream = session + .stream(&prompt) + .await + .expect("websocket stream failed"); + + let mut saw_reasoning_included = false; + while let Some(event) = stream.next().await { + match event.expect("event") { + ResponseEvent::ServerReasoningIncluded(true) => { + saw_reasoning_included = true; + } + ResponseEvent::Completed { .. } => break, + _ => {} + } + } + + assert!(saw_reasoning_included); + server.shutdown().await; +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn responses_websocket_appends_on_prefix() { skip_if_no_network!(); diff --git a/codex-rs/core/tests/suite/compact.rs b/codex-rs/core/tests/suite/compact.rs index d033c6664..b06f5ef14 100644 --- a/codex-rs/core/tests/suite/compact.rs +++ b/codex-rs/core/tests/suite/compact.rs @@ -32,11 +32,13 @@ use core_test_support::responses::ev_completed; use core_test_support::responses::ev_completed_with_tokens; use core_test_support::responses::ev_function_call; use core_test_support::responses::mount_compact_json_once; +use core_test_support::responses::mount_response_sequence; use core_test_support::responses::mount_sse_once; use core_test_support::responses::mount_sse_once_match; use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::sse; use core_test_support::responses::sse_failed; +use core_test_support::responses::sse_response; use core_test_support::responses::start_mock_server; use pretty_assertions::assert_eq; use serde_json::json; @@ -2147,3 +2149,85 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() { "third turn should include compaction summary item" ); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auto_compact_runs_when_reasoning_header_clears_between_turns() { + skip_if_no_network!(); + + let server = start_mock_server().await; + + let first_user = "SERVER_INCLUDED_FIRST"; + let second_user = "SERVER_INCLUDED_SECOND"; + let third_user = "SERVER_INCLUDED_THIRD"; + + let pre_last_reasoning_content = "a".repeat(2_400); + let post_last_reasoning_content = "b".repeat(4_000); + + let first_turn = sse(vec![ + ev_reasoning_item("pre-reasoning", &["pre"], &[&pre_last_reasoning_content]), + ev_completed_with_tokens("r1", 10), + ]); + let second_turn = sse(vec![ + ev_reasoning_item("post-reasoning", &["post"], &[&post_last_reasoning_content]), + ev_completed_with_tokens("r2", 80), + ]); + let third_turn = sse(vec![ + ev_assistant_message("m4", FINAL_REPLY), + ev_completed_with_tokens("r4", 1), + ]); + + let responses = vec![ + sse_response(first_turn).insert_header("X-Reasoning-Included", "true"), + sse_response(second_turn), + sse_response(third_turn), + ]; + mount_response_sequence(&server, responses).await; + + let compacted_history = vec![ + codex_protocol::models::ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![codex_protocol::models::ContentItem::OutputText { + text: "REMOTE_COMPACT_SUMMARY".to_string(), + }], + }, + codex_protocol::models::ResponseItem::Compaction { + encrypted_content: "ENCRYPTED_COMPACTION_SUMMARY".to_string(), + }, + ]; + let compact_mock = + mount_compact_json_once(&server, serde_json::json!({ "output": compacted_history })).await; + + let codex = test_codex() + .with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) + .with_config(|config| { + set_test_compact_prompt(config); + config.model_auto_compact_token_limit = Some(300); + config.features.enable(Feature::RemoteCompaction); + }) + .build(&server) + .await + .expect("build codex") + .codex; + + for user in [first_user, second_user, third_user] { + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: user.into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + } + + let compact_requests = compact_mock.requests(); + assert_eq!( + compact_requests.len(), + 1, + "remote compaction should run once after the reasoning header clears" + ); +} diff --git a/codex-rs/otel/src/traces/otel_manager.rs b/codex-rs/otel/src/traces/otel_manager.rs index a1aaf0b51..0847fd882 100644 --- a/codex-rs/otel/src/traces/otel_manager.rs +++ b/codex-rs/otel/src/traces/otel_manager.rs @@ -484,6 +484,7 @@ impl OtelManager { ResponseEvent::ReasoningSummaryPartAdded { .. } => { "reasoning_summary_part_added".into() } + ResponseEvent::ServerReasoningIncluded(_) => "server_reasoning_included".into(), ResponseEvent::RateLimits(_) => "rate_limits".into(), ResponseEvent::ModelsEtag(_) => "models_etag".into(), }