Act on reasoning-included per turn (#9402)
- Reset reasoning-included flag each turn and update compaction test
This commit is contained in:
parent
57ec3a8277
commit
b11e96fb04
12 changed files with 192 additions and 11 deletions
|
|
@ -42,6 +42,10 @@ pub enum ResponseEvent {
|
|||
Created,
|
||||
OutputItemDone(ResponseItem),
|
||||
OutputItemAdded(ResponseItem),
|
||||
/// Emitted when `X-Reasoning-Included: true` is present on the response,
|
||||
/// meaning the server already accounted for past reasoning tokens and the
|
||||
/// client should not re-estimate them.
|
||||
ServerReasoningIncluded(bool),
|
||||
Completed {
|
||||
response_id: String,
|
||||
token_usage: Option<TokenUsage>,
|
||||
|
|
|
|||
|
|
@ -157,6 +157,9 @@ impl Stream for AggregatedStream {
|
|||
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ServerReasoningIncluded(included))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,18 +29,21 @@ use url::Url;
|
|||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
|
||||
pub struct ResponsesWebsocketConnection {
|
||||
stream: Arc<Mutex<Option<WsStream>>>,
|
||||
// TODO (pakrym): is this the right place for timeout?
|
||||
idle_timeout: Duration,
|
||||
server_reasoning_included: bool,
|
||||
}
|
||||
|
||||
impl ResponsesWebsocketConnection {
|
||||
fn new(stream: WsStream, idle_timeout: Duration) -> Self {
|
||||
fn new(stream: WsStream, idle_timeout: Duration, server_reasoning_included: bool) -> Self {
|
||||
Self {
|
||||
stream: Arc::new(Mutex::new(Some(stream))),
|
||||
idle_timeout,
|
||||
server_reasoning_included,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -56,11 +59,17 @@ impl ResponsesWebsocketConnection {
|
|||
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
|
||||
let stream = Arc::clone(&self.stream);
|
||||
let idle_timeout = self.idle_timeout;
|
||||
let server_reasoning_included = self.server_reasoning_included;
|
||||
let request_body = serde_json::to_value(&request).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to encode websocket request: {err}"))
|
||||
})?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if server_reasoning_included {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
|
||||
.await;
|
||||
}
|
||||
let mut guard = stream.lock().await;
|
||||
let Some(ws_stream) = guard.as_mut() else {
|
||||
let _ = tx_event
|
||||
|
|
@ -111,10 +120,12 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
|||
headers.extend(extra_headers);
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let stream = connect_websocket(ws_url, headers, turn_state).await?;
|
||||
let (stream, server_reasoning_included) =
|
||||
connect_websocket(ws_url, headers, turn_state).await?;
|
||||
Ok(ResponsesWebsocketConnection::new(
|
||||
stream,
|
||||
self.provider.stream_idle_timeout,
|
||||
server_reasoning_included,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
@ -137,7 +148,7 @@ async fn connect_websocket(
|
|||
url: Url,
|
||||
headers: HeaderMap,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<WsStream, ApiError> {
|
||||
) -> Result<(WsStream, bool), ApiError> {
|
||||
let mut request = url
|
||||
.clone()
|
||||
.into_client_request()
|
||||
|
|
@ -147,6 +158,7 @@ async fn connect_websocket(
|
|||
let (stream, response) = tokio_tungstenite::connect_async(request)
|
||||
.await
|
||||
.map_err(|err| map_ws_error(err, &url))?;
|
||||
let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER);
|
||||
if let Some(turn_state) = turn_state
|
||||
&& let Some(header_value) = response
|
||||
.headers()
|
||||
|
|
@ -155,7 +167,7 @@ async fn connect_websocket(
|
|||
{
|
||||
let _ = turn_state.set(header_value.to_string());
|
||||
}
|
||||
Ok(stream)
|
||||
Ok((stream, reasoning_included))
|
||||
}
|
||||
|
||||
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ use tokio_util::io::ReaderStream;
|
|||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
|
||||
/// Streams SSE events from an on-disk fixture for tests.
|
||||
pub fn stream_from_fixture(
|
||||
path: impl AsRef<Path>,
|
||||
|
|
@ -58,6 +60,10 @@ pub fn spawn_response_stream(
|
|||
.get("X-Models-Etag")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
let reasoning_included = stream_response
|
||||
.headers
|
||||
.get(X_REASONING_INCLUDED_HEADER)
|
||||
.is_some();
|
||||
if let Some(turn_state) = turn_state.as_ref()
|
||||
&& let Some(header_value) = stream_response
|
||||
.headers
|
||||
|
|
@ -74,6 +80,11 @@ pub fn spawn_response_stream(
|
|||
if let Some(etag) = models_etag {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
|
||||
}
|
||||
if reasoning_included {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
|
||||
.await;
|
||||
}
|
||||
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -809,7 +809,7 @@ impl Session {
|
|||
|
||||
async fn get_total_token_usage(&self) -> i64 {
|
||||
let state = self.state.lock().await;
|
||||
state.get_total_token_usage()
|
||||
state.get_total_token_usage(state.server_reasoning_included())
|
||||
}
|
||||
|
||||
async fn record_initial_history(&self, conversation_history: InitialHistory) {
|
||||
|
|
@ -1618,6 +1618,11 @@ impl Session {
|
|||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn set_server_reasoning_included(&self, included: bool) {
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_server_reasoning_included(included);
|
||||
}
|
||||
|
||||
async fn send_token_count_event(&self, turn_context: &TurnContext) {
|
||||
let (info, rate_limits) = {
|
||||
let state = self.state.lock().await;
|
||||
|
|
@ -3149,6 +3154,9 @@ async fn try_run_sampling_request(
|
|||
active_item = Some(tracked_item);
|
||||
}
|
||||
}
|
||||
ResponseEvent::ServerReasoningIncluded(included) => {
|
||||
sess.set_server_reasoning_included(included).await;
|
||||
}
|
||||
ResponseEvent::RateLimits(snapshot) => {
|
||||
// Update internal state with latest rate limits, but defer sending until
|
||||
// token usage is available to avoid duplicate TokenCount events.
|
||||
|
|
|
|||
|
|
@ -316,6 +316,9 @@ async fn drain_to_completed(
|
|||
sess.record_into_history(std::slice::from_ref(&item), turn_context)
|
||||
.await;
|
||||
}
|
||||
Ok(ResponseEvent::ServerReasoningIncluded(included)) => {
|
||||
sess.set_server_reasoning_included(included).await;
|
||||
}
|
||||
Ok(ResponseEvent::RateLimits(snapshot)) => {
|
||||
sess.update_rate_limits(turn_context, snapshot).await;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -235,12 +235,19 @@ impl ContextManager {
|
|||
token_estimate as usize
|
||||
}
|
||||
|
||||
pub(crate) fn get_total_token_usage(&self) -> i64 {
|
||||
self.token_info
|
||||
/// When true, the server already accounted for past reasoning tokens and
|
||||
/// the client should not re-estimate them.
|
||||
pub(crate) fn get_total_token_usage(&self, server_reasoning_included: bool) -> i64 {
|
||||
let last_tokens = self
|
||||
.token_info
|
||||
.as_ref()
|
||||
.map(|info| info.last_token_usage.total_tokens)
|
||||
.unwrap_or(0)
|
||||
.saturating_add(self.get_non_last_reasoning_items_tokens() as i64)
|
||||
.unwrap_or(0);
|
||||
if server_reasoning_included {
|
||||
last_tokens
|
||||
} else {
|
||||
last_tokens.saturating_add(self.get_non_last_reasoning_items_tokens() as i64)
|
||||
}
|
||||
}
|
||||
|
||||
/// This function enforces a couple of invariants on the in-memory history:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ pub(crate) struct SessionState {
|
|||
pub(crate) session_configuration: SessionConfiguration,
|
||||
pub(crate) history: ContextManager,
|
||||
pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
|
||||
pub(crate) server_reasoning_included: bool,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
|
|
@ -24,6 +25,7 @@ impl SessionState {
|
|||
session_configuration,
|
||||
history,
|
||||
latest_rate_limits: None,
|
||||
server_reasoning_included: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -78,8 +80,17 @@ impl SessionState {
|
|||
self.history.set_token_usage_full(context_window);
|
||||
}
|
||||
|
||||
pub(crate) fn get_total_token_usage(&self) -> i64 {
|
||||
self.history.get_total_token_usage()
|
||||
pub(crate) fn get_total_token_usage(&self, server_reasoning_included: bool) -> i64 {
|
||||
self.history
|
||||
.get_total_token_usage(server_reasoning_included)
|
||||
}
|
||||
|
||||
pub(crate) fn set_server_reasoning_included(&mut self, included: bool) {
|
||||
self.server_reasoning_included = included;
|
||||
}
|
||||
|
||||
pub(crate) fn server_reasoning_included(&self) -> bool {
|
||||
self.server_reasoning_included
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ impl SessionTask for RegularTask {
|
|||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
let run_turn_span = trace_span!("run_turn");
|
||||
sess.set_server_reasoning_included(false).await;
|
||||
sess.services
|
||||
.otel_manager
|
||||
.apply_traceparent_parent(&run_turn_span);
|
||||
|
|
|
|||
|
|
@ -15,10 +15,12 @@ use codex_otel::OtelManager;
|
|||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses::WebSocketConnectionConfig;
|
||||
use core_test_support::responses::WebSocketTestServer;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
use core_test_support::responses::start_websocket_server_with_headers;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use futures::StreamExt;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
|
@ -60,6 +62,40 @@ async fn responses_websocket_streams_request() {
|
|||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_emits_reasoning_included_event() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
|
||||
requests: vec![vec![ev_response_created("resp-1"), ev_completed("resp-1")]],
|
||||
response_headers: vec![("X-Reasoning-Included".to_string(), "true".to_string())],
|
||||
}])
|
||||
.await;
|
||||
|
||||
let harness = websocket_harness(&server).await;
|
||||
let mut session = harness.client.new_session();
|
||||
let prompt = prompt_with_input(vec![message_item("hello")]);
|
||||
|
||||
let mut stream = session
|
||||
.stream(&prompt)
|
||||
.await
|
||||
.expect("websocket stream failed");
|
||||
|
||||
let mut saw_reasoning_included = false;
|
||||
while let Some(event) = stream.next().await {
|
||||
match event.expect("event") {
|
||||
ResponseEvent::ServerReasoningIncluded(true) => {
|
||||
saw_reasoning_included = true;
|
||||
}
|
||||
ResponseEvent::Completed { .. } => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(saw_reasoning_included);
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_appends_on_prefix() {
|
||||
skip_if_no_network!();
|
||||
|
|
|
|||
|
|
@ -32,11 +32,13 @@ use core_test_support::responses::ev_completed;
|
|||
use core_test_support::responses::ev_completed_with_tokens;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_compact_json_once;
|
||||
use core_test_support::responses::mount_response_sequence;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::sse_failed;
|
||||
use core_test_support::responses::sse_response;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
|
@ -2147,3 +2149,85 @@ async fn auto_compact_counts_encrypted_reasoning_before_last_user() {
|
|||
"third turn should include compaction summary item"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn auto_compact_runs_when_reasoning_header_clears_between_turns() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = start_mock_server().await;
|
||||
|
||||
let first_user = "SERVER_INCLUDED_FIRST";
|
||||
let second_user = "SERVER_INCLUDED_SECOND";
|
||||
let third_user = "SERVER_INCLUDED_THIRD";
|
||||
|
||||
let pre_last_reasoning_content = "a".repeat(2_400);
|
||||
let post_last_reasoning_content = "b".repeat(4_000);
|
||||
|
||||
let first_turn = sse(vec![
|
||||
ev_reasoning_item("pre-reasoning", &["pre"], &[&pre_last_reasoning_content]),
|
||||
ev_completed_with_tokens("r1", 10),
|
||||
]);
|
||||
let second_turn = sse(vec![
|
||||
ev_reasoning_item("post-reasoning", &["post"], &[&post_last_reasoning_content]),
|
||||
ev_completed_with_tokens("r2", 80),
|
||||
]);
|
||||
let third_turn = sse(vec![
|
||||
ev_assistant_message("m4", FINAL_REPLY),
|
||||
ev_completed_with_tokens("r4", 1),
|
||||
]);
|
||||
|
||||
let responses = vec![
|
||||
sse_response(first_turn).insert_header("X-Reasoning-Included", "true"),
|
||||
sse_response(second_turn),
|
||||
sse_response(third_turn),
|
||||
];
|
||||
mount_response_sequence(&server, responses).await;
|
||||
|
||||
let compacted_history = vec![
|
||||
codex_protocol::models::ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![codex_protocol::models::ContentItem::OutputText {
|
||||
text: "REMOTE_COMPACT_SUMMARY".to_string(),
|
||||
}],
|
||||
},
|
||||
codex_protocol::models::ResponseItem::Compaction {
|
||||
encrypted_content: "ENCRYPTED_COMPACTION_SUMMARY".to_string(),
|
||||
},
|
||||
];
|
||||
let compact_mock =
|
||||
mount_compact_json_once(&server, serde_json::json!({ "output": compacted_history })).await;
|
||||
|
||||
let codex = test_codex()
|
||||
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
|
||||
.with_config(|config| {
|
||||
set_test_compact_prompt(config);
|
||||
config.model_auto_compact_token_limit = Some(300);
|
||||
config.features.enable(Feature::RemoteCompaction);
|
||||
})
|
||||
.build(&server)
|
||||
.await
|
||||
.expect("build codex")
|
||||
.codex;
|
||||
|
||||
for user in [first_user, second_user, third_user] {
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: user.into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
|
||||
}
|
||||
|
||||
let compact_requests = compact_mock.requests();
|
||||
assert_eq!(
|
||||
compact_requests.len(),
|
||||
1,
|
||||
"remote compaction should run once after the reasoning header clears"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -484,6 +484,7 @@ impl OtelManager {
|
|||
ResponseEvent::ReasoningSummaryPartAdded { .. } => {
|
||||
"reasoning_summary_part_added".into()
|
||||
}
|
||||
ResponseEvent::ServerReasoningIncluded(_) => "server_reasoning_included".into(),
|
||||
ResponseEvent::RateLimits(_) => "rate_limits".into(),
|
||||
ResponseEvent::ModelsEtag(_) => "models_etag".into(),
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue