Act on reasoning-included per turn (#9402)

- Reset reasoning-included flag each turn and update compaction test
Ahmed Ibrahim 2026-01-19 11:23:25 -08:00 committed by GitHub
parent 57ec3a8277
commit b11e96fb04
12 changed files with 192 additions and 11 deletions


@@ -42,6 +42,10 @@ pub enum ResponseEvent {
Created,
OutputItemDone(ResponseItem),
OutputItemAdded(ResponseItem),
/// Emitted when `X-Reasoning-Included: true` is present on the response,
/// meaning the server already accounted for past reasoning tokens and the
/// client should not re-estimate them.
ServerReasoningIncluded(bool),
Completed {
response_id: String,
token_usage: Option<TokenUsage>,
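
The session-side hunks below act on this event; condensed, the per-turn handling is roughly the following standalone sketch (simplified types, not the actual Session API; in the diff the flag is cleared in RegularTask and set from try_run_sampling_request):

// Sketch only: mirrors how the flag is cleared at turn start and set when the
// ServerReasoningIncluded event arrives during that turn.
#[derive(Default)]
struct TurnFlags {
    server_reasoning_included: bool,
}

enum Event {
    ServerReasoningIncluded(bool),
    Completed,
}

fn run_turn(flags: &mut TurnFlags, events: impl IntoIterator<Item = Event>) {
    // Reset each turn so a header from a previous turn cannot leak forward.
    flags.server_reasoning_included = false;
    for event in events {
        match event {
            Event::ServerReasoningIncluded(included) => {
                flags.server_reasoning_included = included;
            }
            Event::Completed => break,
        }
    }
}

fn main() {
    let mut flags = TurnFlags::default();
    run_turn(&mut flags, [Event::ServerReasoningIncluded(true), Event::Completed]);
    assert!(flags.server_reasoning_included);
    // A later turn whose response lacks the header falls back to client-side estimates.
    run_turn(&mut flags, [Event::Completed]);
    assert!(!flags.server_reasoning_included);
}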


@@ -157,6 +157,9 @@ impl Stream for AggregatedStream {
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
}
Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included))));
}
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
}


@@ -29,18 +29,21 @@ use url::Url;
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
pub struct ResponsesWebsocketConnection {
stream: Arc<Mutex<Option<WsStream>>>,
// TODO (pakrym): is this the right place for timeout?
idle_timeout: Duration,
server_reasoning_included: bool,
}
impl ResponsesWebsocketConnection {
fn new(stream: WsStream, idle_timeout: Duration) -> Self {
fn new(stream: WsStream, idle_timeout: Duration, server_reasoning_included: bool) -> Self {
Self {
stream: Arc::new(Mutex::new(Some(stream))),
idle_timeout,
server_reasoning_included,
}
}
@@ -56,11 +59,17 @@ impl ResponsesWebsocketConnection {
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
let stream = Arc::clone(&self.stream);
let idle_timeout = self.idle_timeout;
let server_reasoning_included = self.server_reasoning_included;
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
tokio::spawn(async move {
if server_reasoning_included {
let _ = tx_event
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
.await;
}
let mut guard = stream.lock().await;
let Some(ws_stream) = guard.as_mut() else {
let _ = tx_event
@@ -111,10 +120,12 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
headers.extend(extra_headers);
apply_auth_headers(&mut headers, &self.auth);
let stream = connect_websocket(ws_url, headers, turn_state).await?;
let (stream, server_reasoning_included) =
connect_websocket(ws_url, headers, turn_state).await?;
Ok(ResponsesWebsocketConnection::new(
stream,
self.provider.stream_idle_timeout,
server_reasoning_included,
))
}
}
@@ -137,7 +148,7 @@ async fn connect_websocket(
url: Url,
headers: HeaderMap,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<WsStream, ApiError> {
) -> Result<(WsStream, bool), ApiError> {
let mut request = url
.clone()
.into_client_request()
@@ -147,6 +158,7 @@
let (stream, response) = tokio_tungstenite::connect_async(request)
.await
.map_err(|err| map_ws_error(err, &url))?;
let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER);
if let Some(turn_state) = turn_state
&& let Some(header_value) = response
.headers()
@@ -155,7 +167,7 @@
{
let _ = turn_state.set(header_value.to_string());
}
Ok(stream)
Ok((stream, reasoning_included))
}
fn map_ws_error(err: WsError, url: &Url) -> ApiError {


@@ -25,6 +25,8 @@ use tokio_util::io::ReaderStream;
use tracing::debug;
use tracing::trace;
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
/// Streams SSE events from an on-disk fixture for tests.
pub fn stream_from_fixture(
path: impl AsRef<Path>,
@@ -58,6 +60,10 @@ pub fn spawn_response_stream(
.get("X-Models-Etag")
.and_then(|v| v.to_str().ok())
.map(ToString::to_string);
let reasoning_included = stream_response
.headers
.get(X_REASONING_INCLUDED_HEADER)
.is_some();
if let Some(turn_state) = turn_state.as_ref()
&& let Some(header_value) = stream_response
.headers
@@ -74,6 +80,11 @@
if let Some(etag) = models_etag {
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
}
if reasoning_included {
let _ = tx_event
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
.await;
}
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
});


@@ -809,7 +809,7 @@ impl Session {
async fn get_total_token_usage(&self) -> i64 {
let state = self.state.lock().await;
state.get_total_token_usage()
state.get_total_token_usage(state.server_reasoning_included())
}
async fn record_initial_history(&self, conversation_history: InitialHistory) {
@@ -1618,6 +1618,11 @@ impl Session {
self.send_token_count_event(turn_context).await;
}
pub(crate) async fn set_server_reasoning_included(&self, included: bool) {
let mut state = self.state.lock().await;
state.set_server_reasoning_included(included);
}
async fn send_token_count_event(&self, turn_context: &TurnContext) {
let (info, rate_limits) = {
let state = self.state.lock().await;
@@ -3149,6 +3154,9 @@ async fn try_run_sampling_request(
active_item = Some(tracked_item);
}
}
ResponseEvent::ServerReasoningIncluded(included) => {
sess.set_server_reasoning_included(included).await;
}
ResponseEvent::RateLimits(snapshot) => {
// Update internal state with latest rate limits, but defer sending until
// token usage is available to avoid duplicate TokenCount events.


@@ -316,6 +316,9 @@ async fn drain_to_completed(
sess.record_into_history(std::slice::from_ref(&item), turn_context)
.await;
}
Ok(ResponseEvent::ServerReasoningIncluded(included)) => {
sess.set_server_reasoning_included(included).await;
}
Ok(ResponseEvent::RateLimits(snapshot)) => {
sess.update_rate_limits(turn_context, snapshot).await;
}


@@ -235,12 +235,19 @@ impl ContextManager {
token_estimate as usize
}
pub(crate) fn get_total_token_usage(&self) -> i64 {
self.token_info
/// When true, the server already accounted for past reasoning tokens and
/// the client should not re-estimate them.
pub(crate) fn get_total_token_usage(&self, server_reasoning_included: bool) -> i64 {
let last_tokens = self
.token_info
.as_ref()
.map(|info| info.last_token_usage.total_tokens)
.unwrap_or(0)
.saturating_add(self.get_non_last_reasoning_items_tokens() as i64)
.unwrap_or(0);
if server_reasoning_included {
last_tokens
} else {
last_tokens.saturating_add(self.get_non_last_reasoning_items_tokens() as i64)
}
}
/// This function enforces a couple of invariants on the in-memory history:
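
Taken on its own, the new accounting is equivalent to this standalone sketch (simplified argument names; non_last_reasoning_estimate stands in for get_non_last_reasoning_items_tokens()):

fn total_token_usage(
    last_token_usage_total: i64,
    non_last_reasoning_estimate: i64,
    server_reasoning_included: bool,
) -> i64 {
    if server_reasoning_included {
        // The server already accounted for past reasoning tokens, so its
        // reported total is used as-is.
        last_token_usage_total
    } else {
        // Previous behavior: add the client-side estimate for reasoning
        // items other than the most recent one.
        last_token_usage_total.saturating_add(non_last_reasoning_estimate)
    }
}

fn main() {
    assert_eq!(total_token_usage(80, 600, true), 80);
    assert_eq!(total_token_usage(80, 600, false), 680);
}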


@@ -14,6 +14,7 @@ pub(crate) struct SessionState {
pub(crate) session_configuration: SessionConfiguration,
pub(crate) history: ContextManager,
pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
pub(crate) server_reasoning_included: bool,
}
impl SessionState {
@@ -24,6 +25,7 @@ impl SessionState {
session_configuration,
history,
latest_rate_limits: None,
server_reasoning_included: false,
}
}
@@ -78,8 +80,17 @@ impl SessionState {
self.history.set_token_usage_full(context_window);
}
pub(crate) fn get_total_token_usage(&self) -> i64 {
self.history.get_total_token_usage()
pub(crate) fn get_total_token_usage(&self, server_reasoning_included: bool) -> i64 {
self.history
.get_total_token_usage(server_reasoning_included)
}
pub(crate) fn set_server_reasoning_included(&mut self, included: bool) {
self.server_reasoning_included = included;
}
pub(crate) fn server_reasoning_included(&self) -> bool {
self.server_reasoning_included
}
}


@@ -30,6 +30,7 @@ impl SessionTask for RegularTask {
) -> Option<String> {
let sess = session.clone_session();
let run_turn_span = trace_span!("run_turn");
sess.set_server_reasoning_included(false).await;
sess.services
.otel_manager
.apply_traceparent_parent(&run_turn_span);


@@ -15,10 +15,12 @@ use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::WebSocketTestServer;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::start_websocket_server;
use core_test_support::responses::start_websocket_server_with_headers;
use core_test_support::skip_if_no_network;
use futures::StreamExt;
use pretty_assertions::assert_eq;
@@ -60,6 +62,40 @@ async fn responses_websocket_streams_request() {
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_emits_reasoning_included_event() {
skip_if_no_network!();
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]],
response_headers: vec![("X-Reasoning-Included".to_string(), "true".to_string())],
}])
.await;
let harness = websocket_harness(&server).await;
let mut session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
let mut stream = session
.stream(&prompt)
.await
.expect("websocket stream failed");
let mut saw_reasoning_included = false;
while let Some(event) = stream.next().await {
match event.expect("event") {
ResponseEvent::ServerReasoningIncluded(true) => {
saw_reasoning_included = true;
}
ResponseEvent::Completed { .. } => break,
_ => {}
}
}
assert!(saw_reasoning_included);
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_appends_on_prefix() {
skip_if_no_network!();


@@ -32,11 +32,13 @@ use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_completed_with_tokens;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_compact_json_once;
use core_test_support::responses::mount_response_sequence;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_once_match;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::sse_failed;
use core_test_support::responses::sse_response;
use core_test_support::responses::start_mock_server;
use pretty_assertions::assert_eq;
use serde_json::json;
@@ -2147,3 +2149,85 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {
"third turn should include compaction summary item"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn auto_compact_runs_when_reasoning_header_clears_between_turns() {
skip_if_no_network!();
let server = start_mock_server().await;
let first_user = "SERVER_INCLUDED_FIRST";
let second_user = "SERVER_INCLUDED_SECOND";
let third_user = "SERVER_INCLUDED_THIRD";
let pre_last_reasoning_content = "a".repeat(2_400);
let post_last_reasoning_content = "b".repeat(4_000);
let first_turn = sse(vec![
ev_reasoning_item("pre-reasoning", &["pre"], &[&pre_last_reasoning_content]),
ev_completed_with_tokens("r1", 10),
]);
let second_turn = sse(vec![
ev_reasoning_item("post-reasoning", &["post"], &[&post_last_reasoning_content]),
ev_completed_with_tokens("r2", 80),
]);
let third_turn = sse(vec![
ev_assistant_message("m4", FINAL_REPLY),
ev_completed_with_tokens("r4", 1),
]);
let responses = vec![
sse_response(first_turn).insert_header("X-Reasoning-Included", "true"),
sse_response(second_turn),
sse_response(third_turn),
];
mount_response_sequence(&server, responses).await;
let compacted_history = vec![
codex_protocol::models::ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![codex_protocol::models::ContentItem::OutputText {
text: "REMOTE_COMPACT_SUMMARY".to_string(),
}],
},
codex_protocol::models::ResponseItem::Compaction {
encrypted_content: "ENCRYPTED_COMPACTION_SUMMARY".to_string(),
},
];
let compact_mock =
mount_compact_json_once(&server, serde_json::json!({ "output": compacted_history })).await;
let codex = test_codex()
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
.with_config(|config| {
set_test_compact_prompt(config);
config.model_auto_compact_token_limit = Some(300);
config.features.enable(Feature::RemoteCompaction);
})
.build(&server)
.await
.expect("build codex")
.codex;
for user in [first_user, second_user, third_user] {
codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: user.into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
}
let compact_requests = compact_mock.requests();
assert_eq!(
compact_requests.len(),
1,
"remote compaction should run once after the reasoning header clears"
);
}
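
For orientation on the numbers in this test (the estimator's chars-per-token ratio is not shown in this diff, so the figures below assume roughly 4 characters per token purely for illustration): while X-Reasoning-Included is present only the server-reported totals count, so turn one stays far below the 300-token auto-compact limit; once the header clears, the 2_400-character reasoning item from turn one is re-estimated client-side and the running total crosses the limit, which is why exactly one remote compaction is expected.

fn main() {
    // Illustrative only: assumes ~4 characters per token, an assumption made
    // for this sketch rather than a value taken from the diff.
    let estimate_tokens = |chars: usize| (chars / 4) as i64;
    let limit: i64 = 300;

    // Turn one arrives with X-Reasoning-Included, so only the reported
    // usage (10 tokens) is counted.
    let after_turn_one: i64 = 10;
    assert!(after_turn_one < limit);

    // Turn two has no header, so the 2_400-char reasoning item from turn
    // one is re-estimated and added to turn two's reported 80 tokens.
    let after_turn_two = 80 + estimate_tokens(2_400);
    assert!(after_turn_two > limit);
}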


@@ -484,6 +484,7 @@ impl OtelManager {
ResponseEvent::ReasoningSummaryPartAdded { .. } => {
"reasoning_summary_part_added".into()
}
ResponseEvent::ServerReasoningIncluded(_) => "server_reasoning_included".into(),
ResponseEvent::RateLimits(_) => "rate_limits".into(),
ResponseEvent::ModelsEtag(_) => "models_etag".into(),
}