Add model client sessions (#9102)

Maintain a long-running session.
This commit is contained in:
pakrym-oai 2026-01-12 17:15:56 -08:00 committed by GitHub
parent 87f7226cca
commit 490c1c1fdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 874 additions and 196 deletions

41
codex-rs/Cargo.lock generated
View file

@ -984,8 +984,10 @@ dependencies = [
"thiserror 2.0.17",
"tokio",
"tokio-test",
"tokio-tungstenite",
"tokio-util",
"tracing",
"url",
"wiremock",
]
@ -2126,6 +2128,7 @@ dependencies = [
"codex-protocol",
"codex-utils-absolute-path",
"codex-utils-cargo-bin",
"futures",
"notify",
"pretty_assertions",
"regex-lite",
@ -2134,6 +2137,7 @@ dependencies = [
"shlex",
"tempfile",
"tokio",
"tokio-tungstenite",
"walkdir",
"wiremock",
]
@ -2361,6 +2365,12 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "data-encoding"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea"
[[package]]
name = "dbus"
version = "0.9.9"
@ -7117,6 +7127,18 @@ dependencies = [
"tokio-stream",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.18"
@ -7511,6 +7533,25 @@ dependencies = [
"ratatui-core",
]
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.3.1",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror 1.0.69",
"url",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.18.0"

View file

@ -209,6 +209,7 @@ tiny_http = "0.12"
tokio = "1"
tokio-stream = "0.1.18"
tokio-test = "0.4"
tokio-tungstenite = "0.21.0"
tokio-util = "0.7.18"
toml = "0.9.5"
toml_edit = "0.24.0"

View file

@ -14,11 +14,13 @@ http = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] }
tokio = { workspace = true, features = ["macros", "net", "rt", "sync", "time"] }
tokio-tungstenite = { workspace = true }
tracing = { workspace = true }
eventsource-stream = { workspace = true }
regex-lite = { workspace = true }
tokio-util = { workspace = true, features = ["codec"] }
url = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }

View file

@ -2,4 +2,5 @@ pub mod chat;
pub mod compact;
pub mod models;
pub mod responses;
pub mod responses_websocket;
mod streaming;

View file

@ -0,0 +1,280 @@
use crate::auth::AuthProvider;
use crate::common::Prompt as ApiPrompt;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::endpoint::responses::ResponsesOptions;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::requests::ResponsesRequest;
use crate::requests::ResponsesRequestBuilder;
use crate::requests::responses::Compression;
use crate::sse::responses::ResponsesStreamEvent;
use crate::sse::responses::process_responses_event;
use codex_client::TransportError;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use serde_json::Value;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tracing::debug;
use tracing::trace;
use tracing::warn;
use url::Url;
/// Convenience alias for a tungstenite stream over a plain or TLS TCP socket.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Client that streams Responses API turns over a WebSocket transport
/// instead of HTTP + SSE. Generic over the auth source so tests can inject
/// a fake `AuthProvider`.
pub struct ResponsesWebsocketClient<A: AuthProvider> {
    // Endpoint configuration: base URL, default headers, idle timeout.
    provider: Provider,
    // Supplies the bearer token / account id applied to each connection.
    auth: A,
}
impl<A: AuthProvider> ResponsesWebsocketClient<A> {
    /// Creates a client for `provider`, authenticating each connection via `auth`.
    pub fn new(provider: Provider, auth: A) -> Self {
        Self { provider, auth }
    }

    /// Streams a pre-built Responses request over a fresh WebSocket connection.
    pub async fn stream_request(
        &self,
        request: ResponsesRequest,
    ) -> Result<ResponseStream, ApiError> {
        self.stream(request.body, request.headers, request.compression)
            .await
    }

    /// Builds a Responses request from `prompt` + `options` and streams it.
    ///
    /// Mirrors the HTTP Responses client's request construction; see the TODO
    /// below about sharing that code path.
    pub async fn stream_prompt(
        &self,
        model: &str,
        prompt: &ApiPrompt,
        options: ResponsesOptions,
    ) -> Result<ResponseStream, ApiError> {
        // Destructure so a newly added option field becomes a compile error
        // here rather than being silently dropped.
        let ResponsesOptions {
            reasoning,
            include,
            prompt_cache_key,
            text,
            store_override,
            conversation_id,
            session_source,
            extra_headers,
            compression,
        } = options;
        // TODO (pakrym): share with HTTP based Responses API client
        let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
            .tools(&prompt.tools)
            .parallel_tool_calls(prompt.parallel_tool_calls)
            .reasoning(reasoning)
            .include(include)
            .prompt_cache_key(prompt_cache_key)
            .text(text)
            .conversation(conversation_id)
            .session_source(session_source)
            .store_override(store_override)
            .extra_headers(extra_headers)
            .compression(compression)
            .build(&self.provider)?;
        self.stream_request(request).await
    }

    /// Opens a WebSocket to the provider's `responses` endpoint, sends `body`
    /// as the request, and returns a channel-backed stream of response events.
    ///
    /// The actual send/receive loop runs in a spawned task; errors from it are
    /// delivered through the returned stream rather than from this call.
    pub async fn stream(
        &self,
        body: Value,
        extra_headers: HeaderMap,
        compression: Compression,
    ) -> Result<ResponseStream, ApiError> {
        // Zstd request compression is an HTTP-transport feature; over the
        // socket we fall back to the plain payload instead of failing.
        if compression == Compression::Zstd {
            warn!(
                "request compression is not supported for websocket streaming; sending uncompressed payload"
            );
        }
        // NOTE(review): assumes the provider base URL uses a ws:// or wss://
        // scheme for this wire API — confirm against provider configuration.
        let ws_url = Url::parse(&self.provider.url_for_path("responses"))
            .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
        // Provider defaults first, then per-request extras, then auth headers
        // (later inserts win on key collisions).
        let mut headers = self.provider.headers.clone();
        headers.extend(extra_headers);
        apply_auth_headers(&mut headers, &self.auth);
        let connection = connect_websocket(ws_url, headers).await?;
        let (tx_event, rx_event) =
            mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
        let idle_timeout = self.provider.stream_idle_timeout;
        // TODO (pakrym): surface rate limits
        // TODO (pakrym): check models etags
        tokio::spawn(async move {
            // Forward a terminal error into the channel so the consumer sees
            // why the stream ended instead of just a closed receiver.
            if let Err(err) = run_websocket_response_stream(
                connection.stream,
                tx_event.clone(),
                body,
                idle_timeout,
            )
            .await
            {
                let _ = tx_event.send(Err(err)).await;
            }
        });
        Ok(ResponseStream { rx_event })
    }
}
// TODO (pakrym): share with /auth
/// Stamps authentication headers onto an outgoing connection request.
///
/// Inserts `Authorization: Bearer <token>` and `ChatGPT-Account-ID` when the
/// auth source provides them; values that fail header validation are skipped
/// silently rather than aborting the connection.
fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
    if let Some(token) = auth.bearer_token() {
        if let Ok(value) = HeaderValue::from_str(&format!("Bearer {token}")) {
            let _ = headers.insert(http::header::AUTHORIZATION, value);
        }
    }
    if let Some(account_id) = auth.account_id() {
        if let Ok(value) = HeaderValue::from_str(&account_id) {
            let _ = headers.insert("ChatGPT-Account-ID", value);
        }
    }
}
/// Wrapper around an established WebSocket stream returned by
/// `connect_websocket`; exists so the connect step has a named result type.
struct WebSocketConnection {
    stream: WsStream,
}
/// Establishes a WebSocket connection to `url`, sending `headers` with the
/// upgrade request. Handshake failures are mapped to `ApiError` via
/// `map_ws_error` so HTTP-level rejections keep their status and body.
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConnection, ApiError> {
    // Keep the original URL around for error reporting; the clone is consumed
    // when building the client handshake request.
    let mut request = match url.clone().into_client_request() {
        Ok(request) => request,
        Err(err) => {
            return Err(ApiError::Stream(format!(
                "failed to build websocket request: {err}"
            )));
        }
    };
    request.headers_mut().extend(headers);
    match tokio_tungstenite::connect_async(request).await {
        Ok((stream, _response)) => Ok(WebSocketConnection { stream }),
        Err(err) => Err(map_ws_error(err, &url)),
    }
}
/// Converts a tungstenite error into the crate's `ApiError`.
///
/// HTTP handshake rejections become transport HTTP errors carrying status,
/// headers, and (UTF-8) body; closed-socket variants become stream errors;
/// everything else is reported as a network transport error.
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
    match err {
        WsError::Http(response) => {
            // Body is optional and may not be valid UTF-8; drop it in that case.
            let body = response
                .body()
                .as_ref()
                .and_then(|bytes| String::from_utf8(bytes.clone()).ok());
            ApiError::Transport(TransportError::Http {
                status: response.status(),
                url: Some(url.to_string()),
                headers: Some(response.headers().clone()),
                body,
            })
        }
        WsError::ConnectionClosed | WsError::AlreadyClosed => {
            ApiError::Stream("websocket closed".to_string())
        }
        WsError::Io(io_err) => ApiError::Transport(TransportError::Network(io_err.to_string())),
        other => ApiError::Transport(TransportError::Network(other.to_string())),
    }
}
/// Drives one request/response exchange over an established WebSocket.
///
/// Sends `request_body` as a single text frame, then forwards parsed response
/// events into `tx_event` until a `Completed` event arrives, the peer closes,
/// a protocol/parse error occurs, or no frame arrives within `idle_timeout`.
/// The socket is closed (best effort, errors ignored) on every exit path.
async fn run_websocket_response_stream(
    mut ws_stream: WsStream,
    tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
    request_body: Value,
    idle_timeout: Duration,
) -> Result<(), ApiError> {
    // Serialize up front so an encoding failure never leaves a half-open socket.
    let request_text = match serde_json::to_string(&request_body) {
        Ok(text) => text,
        Err(err) => {
            let _ = ws_stream.close(None).await;
            return Err(ApiError::Stream(format!(
                "failed to encode websocket request: {err}"
            )));
        }
    };
    if let Err(err) = ws_stream.send(Message::Text(request_text)).await {
        let _ = ws_stream.close(None).await;
        return Err(ApiError::Stream(format!(
            "failed to send websocket request: {err}"
        )));
    }
    loop {
        // The idle timeout applies per frame, not to the whole response.
        let response = tokio::time::timeout(idle_timeout, ws_stream.next())
            .await
            .map_err(|_| ApiError::Stream("idle timeout waiting for websocket".into()));
        let message = match response {
            Ok(Some(Ok(msg))) => msg,
            // Transport-level read error.
            Ok(Some(Err(err))) => {
                let _ = ws_stream.close(None).await;
                return Err(ApiError::Stream(err.to_string()));
            }
            // Stream ended without a `response.completed` event.
            Ok(None) => {
                let _ = ws_stream.close(None).await;
                return Err(ApiError::Stream(
                    "stream closed before response.completed".into(),
                ));
            }
            // Idle timeout elapsed.
            Err(err) => {
                let _ = ws_stream.close(None).await;
                return Err(err);
            }
        };
        match message {
            Message::Text(text) => {
                trace!("websocket event: {text}");
                // Unparseable events are skipped (logged at debug) rather than
                // aborting the stream, matching lenient SSE event handling.
                let event = match serde_json::from_str::<ResponsesStreamEvent>(&text) {
                    Ok(event) => event,
                    Err(err) => {
                        debug!("failed to parse websocket event: {err}, data: {text}");
                        continue;
                    }
                };
                match process_responses_event(event) {
                    Ok(Some(event)) => {
                        // Check before sending: the send consumes the event.
                        let is_completed = matches!(event, ResponseEvent::Completed { .. });
                        let _ = tx_event.send(Ok(event)).await;
                        if is_completed {
                            // Normal termination: fall through to the final close.
                            break;
                        }
                    }
                    // Event recognized but produces nothing for the consumer.
                    Ok(None) => {}
                    Err(error) => {
                        let _ = ws_stream.close(None).await;
                        return Err(error.into_api_error());
                    }
                }
            }
            Message::Binary(_) => {
                // The Responses protocol is text-only over this transport.
                let _ = ws_stream.close(None).await;
                return Err(ApiError::Stream("unexpected binary websocket event".into()));
            }
            Message::Ping(payload) => {
                // Answer keepalive pings ourselves; failure to pong means the
                // connection is effectively dead.
                if ws_stream.send(Message::Pong(payload)).await.is_err() {
                    let _ = ws_stream.close(None).await;
                    return Err(ApiError::Stream("websocket ping failed".into()));
                }
            }
            Message::Pong(_) => {}
            Message::Close(_) => {
                let _ = ws_stream.close(None).await;
                return Err(ApiError::Stream(
                    "websocket closed before response.completed".into(),
                ));
            }
            // Frame/raw variants are internal to tungstenite; ignore.
            _ => {}
        }
    }
    let _ = ws_stream.close(None).await;
    Ok(())
}

View file

@ -25,6 +25,7 @@ pub use crate::endpoint::compact::CompactClient;
pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;
pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient;
pub use crate::error::ApiError;
pub use crate::provider::Provider;
pub use crate::provider::WireApi;

View file

@ -126,7 +126,7 @@ struct ResponseCompletedOutputTokensDetails {
}
#[derive(Deserialize, Debug)]
struct ResponsesStreamEvent {
pub struct ResponsesStreamEvent {
#[serde(rename = "type")]
kind: String,
response: Option<Value>,
@ -149,7 +149,7 @@ impl ResponsesEventError {
}
}
fn process_responses_event(
pub fn process_responses_event(
event: ResponsesStreamEvent,
) -> std::result::Result<Option<ResponseEvent>, ResponsesEventError> {
match event.kind.as_str() {

View file

@ -13,6 +13,7 @@ use codex_api::ReqwestTransport;
use codex_api::ResponseStream as ApiResponseStream;
use codex_api::ResponsesClient as ApiResponsesClient;
use codex_api::ResponsesOptions as ApiResponsesOptions;
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
use codex_api::SseTelemetry;
use codex_api::TransportError;
use codex_api::common::Reasoning;
@ -57,8 +58,8 @@ use crate::model_provider_info::WireApi;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;
#[derive(Debug, Clone)]
pub struct ModelClient {
#[derive(Debug)]
struct ModelClientState {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_info: ModelInfo,
@ -70,6 +71,16 @@ pub struct ModelClient {
session_source: SessionSource,
}
#[derive(Debug, Clone)]
pub struct ModelClient {
state: Arc<ModelClientState>,
}
#[derive(Debug, Clone)]
pub struct ModelClientSession {
state: Arc<ModelClientState>,
}
#[allow(clippy::too_many_arguments)]
impl ModelClient {
pub fn new(
@ -84,20 +95,30 @@ impl ModelClient {
session_source: SessionSource,
) -> Self {
Self {
config,
auth_manager,
model_info,
otel_manager,
provider,
conversation_id,
effort,
summary,
session_source,
state: Arc::new(ModelClientState {
config,
auth_manager,
model_info,
otel_manager,
provider,
conversation_id,
effort,
summary,
session_source,
}),
}
}
pub fn new_session(&self) -> ModelClientSession {
ModelClientSession {
state: Arc::clone(&self.state),
}
}
}
impl ModelClient {
pub fn get_model_context_window(&self) -> Option<i64> {
let model_info = self.get_model_info();
let model_info = &self.state.model_info;
let effective_context_window_percent = model_info.effective_context_window_percent;
model_info.context_window.map(|context_window| {
context_window.saturating_mul(effective_context_window_percent) / 100
@ -105,39 +126,210 @@ impl ModelClient {
}
pub fn config(&self) -> Arc<Config> {
Arc::clone(&self.config)
Arc::clone(&self.state.config)
}
pub fn provider(&self) -> &ModelProviderInfo {
&self.provider
&self.state.provider
}
pub fn get_provider(&self) -> ModelProviderInfo {
self.state.provider.clone()
}
pub fn get_otel_manager(&self) -> OtelManager {
self.state.otel_manager.clone()
}
pub fn get_session_source(&self) -> SessionSource {
self.state.session_source.clone()
}
/// Returns the currently configured model slug.
pub fn get_model(&self) -> String {
self.state.model_info.slug.clone()
}
pub fn get_model_info(&self) -> ModelInfo {
self.state.model_info.clone()
}
/// Returns the current reasoning effort setting.
pub fn get_reasoning_effort(&self) -> Option<ReasoningEffortConfig> {
self.state.effort
}
/// Returns the current reasoning summary setting.
pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig {
self.state.summary
}
pub fn get_auth_manager(&self) -> Option<Arc<AuthManager>> {
self.state.auth_manager.clone()
}
/// Compacts the current conversation history using the Compact endpoint.
///
/// This is a unary call (no streaming) that returns a new list of
/// `ResponseItem`s representing the compacted transcript.
pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result<Vec<ResponseItem>> {
if prompt.input.is_empty() {
return Ok(Vec::new());
}
let auth_manager = self.state.auth_manager.clone();
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
let api_provider = self
.state
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry = self.build_request_telemetry();
let client = ApiCompactClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry));
let instructions = prompt
.get_full_instructions(&self.state.model_info)
.into_owned();
let payload = ApiCompactionInput {
model: &self.state.model_info.slug,
input: &prompt.input,
instructions: &instructions,
};
let mut extra_headers = ApiHeaderMap::new();
if let SessionSource::SubAgent(sub) = &self.state.session_source {
let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub {
label.clone()
} else {
serde_json::to_value(sub)
.ok()
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
.unwrap_or_else(|| "other".to_string())
};
if let Ok(val) = HeaderValue::from_str(&subagent) {
extra_headers.insert("x-openai-subagent", val);
}
}
client
.compact_input(&payload, extra_headers)
.await
.map_err(map_api_error)
}
}
impl ModelClientSession {
/// Streams a single model turn using either the Responses or Chat
/// Completions wire API, depending on the configured provider.
///
/// For Chat providers, the underlying stream is optionally aggregated
/// based on the `show_raw_agent_reasoning` flag in the config.
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
match self.provider.wire_api {
match self.state.provider.wire_api {
WireApi::Responses => self.stream_responses_api(prompt).await,
WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await,
WireApi::Chat => {
let api_stream = self.stream_chat_completions(prompt).await?;
if self.config.show_raw_agent_reasoning {
if self.state.config.show_raw_agent_reasoning {
Ok(map_response_stream(
api_stream.streaming_mode(),
self.otel_manager.clone(),
self.state.otel_manager.clone(),
))
} else {
Ok(map_response_stream(
api_stream.aggregate(),
self.otel_manager.clone(),
self.state.otel_manager.clone(),
))
}
}
}
}
fn build_responses_request(&self, prompt: &Prompt) -> Result<ApiPrompt> {
let model_info = self.state.model_info.clone();
let instructions = prompt.get_full_instructions(&model_info).into_owned();
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
Ok(build_api_prompt(prompt, instructions, tools_json))
}
fn build_responses_options(
&self,
prompt: &Prompt,
compression: Compression,
) -> ApiResponsesOptions {
let model_info = &self.state.model_info;
let default_reasoning_effort = model_info.default_reasoning_level;
let reasoning = if model_info.supports_reasoning_summaries {
Some(Reasoning {
effort: self.state.effort.or(default_reasoning_effort),
summary: if self.state.summary == ReasoningSummaryConfig::None {
None
} else {
Some(self.state.summary)
},
})
} else {
None
};
let include = if reasoning.is_some() {
vec!["reasoning.encrypted_content".to_string()]
} else {
Vec::new()
};
let verbosity = if model_info.support_verbosity {
self.state
.config
.model_verbosity
.or(model_info.default_verbosity)
} else {
if self.state.config.model_verbosity.is_some() {
warn!(
"model_verbosity is set but ignored as the model does not support verbosity: {}",
model_info.slug
);
}
None
};
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
let conversation_id = self.state.conversation_id.to_string();
ApiResponsesOptions {
reasoning,
include,
prompt_cache_key: Some(conversation_id.clone()),
text,
store_override: None,
conversation_id: Some(conversation_id),
session_source: Some(self.state.session_source.clone()),
extra_headers: beta_feature_headers(&self.state.config),
compression,
}
}
fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression {
if self
.state
.config
.features
.enabled(Feature::EnableRequestCompression)
&& auth.is_some_and(|auth| auth.mode == AuthMode::ChatGPT)
&& self.state.provider.is_openai()
{
Compression::Zstd
} else {
Compression::None
}
}
/// Streams a turn via the OpenAI Chat Completions API.
///
/// This path is only used when the provider is configured with
@ -149,13 +341,13 @@ impl ModelClient {
));
}
let auth_manager = self.auth_manager.clone();
let model_info = self.get_model_info();
let auth_manager = self.state.auth_manager.clone();
let model_info = self.state.model_info.clone();
let instructions = prompt.get_full_instructions(&model_info).into_owned();
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
let conversation_id = self.conversation_id.to_string();
let session_source = self.session_source.clone();
let conversation_id = self.state.conversation_id.to_string();
let session_source = self.state.session_source.clone();
let mut auth_recovery = auth_manager
.as_ref()
@ -166,9 +358,10 @@ impl ModelClient {
None => None,
};
let api_provider = self
.state
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry();
let client = ApiChatClient::new(transport, api_provider, api_auth)
@ -176,7 +369,7 @@ impl ModelClient {
let stream_result = client
.stream_prompt(
&self.get_model(),
&self.state.model_info.slug,
&api_prompt,
Some(conversation_id.clone()),
Some(session_source.clone()),
@ -203,52 +396,14 @@ impl ModelClient {
async fn stream_responses_api(&self, prompt: &Prompt) -> Result<ResponseStream> {
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
warn!(path, "Streaming from fixture");
let stream = codex_api::stream_from_fixture(path, self.provider.stream_idle_timeout())
.map_err(map_api_error)?;
return Ok(map_response_stream(stream, self.otel_manager.clone()));
let stream =
codex_api::stream_from_fixture(path, self.state.provider.stream_idle_timeout())
.map_err(map_api_error)?;
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
}
let auth_manager = self.auth_manager.clone();
let model_info = self.get_model_info();
let instructions = prompt.get_full_instructions(&model_info).into_owned();
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
let default_reasoning_effort = model_info.default_reasoning_level;
let reasoning = if model_info.supports_reasoning_summaries {
Some(Reasoning {
effort: self.effort.or(default_reasoning_effort),
summary: if self.summary == ReasoningSummaryConfig::None {
None
} else {
Some(self.summary)
},
})
} else {
None
};
let include: Vec<String> = if reasoning.is_some() {
vec!["reasoning.encrypted_content".to_string()]
} else {
vec![]
};
let verbosity = if model_info.support_verbosity {
self.config.model_verbosity.or(model_info.default_verbosity)
} else {
if self.config.model_verbosity.is_some() {
warn!(
"model_verbosity is set but ignored as the model does not support verbosity: {}",
model_info.slug
);
}
None
};
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
let api_prompt = build_api_prompt(prompt, instructions.clone(), tools_json);
let conversation_id = self.conversation_id.to_string();
let session_source = self.session_source.clone();
let auth_manager = self.state.auth_manager.clone();
let api_prompt = self.build_responses_request(prompt)?;
let mut auth_recovery = auth_manager
.as_ref()
@ -259,47 +414,26 @@ impl ModelClient {
None => None,
};
let api_provider = self
.state
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry();
let compression = if self
.config
.features
.enabled(Feature::EnableRequestCompression)
&& auth
.as_ref()
.is_some_and(|auth| auth.mode == AuthMode::ChatGPT)
&& self.provider.is_openai()
{
Compression::Zstd
} else {
Compression::None
};
let compression = self.responses_request_compression(auth.as_ref());
let client = ApiResponsesClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let options = ApiResponsesOptions {
reasoning: reasoning.clone(),
include: include.clone(),
prompt_cache_key: Some(conversation_id.clone()),
text: text.clone(),
store_override: None,
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
extra_headers: beta_feature_headers(&self.config),
compression,
};
let options = self.build_responses_options(prompt, compression);
let stream_result = client
.stream_prompt(&self.get_model(), &api_prompt, options)
.stream_prompt(&self.state.model_info.slug, &api_prompt, options)
.await;
match stream_result {
Ok(stream) => {
return Ok(map_response_stream(stream, self.otel_manager.clone()));
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
}
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UNAUTHORIZED =>
@ -312,106 +446,61 @@ impl ModelClient {
}
}
pub fn get_provider(&self) -> ModelProviderInfo {
self.provider.clone()
}
/// Streams a turn via the Responses API over WebSocket transport.
async fn stream_responses_websocket(&self, prompt: &Prompt) -> Result<ResponseStream> {
let auth_manager = self.state.auth_manager.clone();
let api_prompt = self.build_responses_request(prompt)?;
pub fn get_otel_manager(&self) -> OtelManager {
self.otel_manager.clone()
}
pub fn get_session_source(&self) -> SessionSource {
self.session_source.clone()
}
/// Returns the currently configured model slug.
pub fn get_model(&self) -> String {
self.model_info.slug.clone()
}
pub fn get_model_info(&self) -> ModelInfo {
self.model_info.clone()
}
/// Returns the current reasoning effort setting.
pub fn get_reasoning_effort(&self) -> Option<ReasoningEffortConfig> {
self.effort
}
/// Returns the current reasoning summary setting.
pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig {
self.summary
}
pub fn get_auth_manager(&self) -> Option<Arc<AuthManager>> {
self.auth_manager.clone()
}
/// Compacts the current conversation history using the Compact endpoint.
///
/// This is a unary call (no streaming) that returns a new list of
/// `ResponseItem`s representing the compacted transcript.
pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result<Vec<ResponseItem>> {
if prompt.input.is_empty() {
return Ok(Vec::new());
}
let auth_manager = self.auth_manager.clone();
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
let api_provider = self
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry = self.build_request_telemetry();
let client = ApiCompactClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry));
let instructions = prompt
.get_full_instructions(&self.get_model_info())
.into_owned();
let payload = ApiCompactionInput {
model: &self.get_model(),
input: &prompt.input,
instructions: &instructions,
};
let mut extra_headers = ApiHeaderMap::new();
if let SessionSource::SubAgent(sub) = &self.session_source {
let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub {
label.clone()
} else {
serde_json::to_value(sub)
.ok()
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
.unwrap_or_else(|| "other".to_string())
let mut auth_recovery = auth_manager
.as_ref()
.map(super::auth::AuthManager::unauthorized_recovery);
loop {
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
if let Ok(val) = HeaderValue::from_str(&subagent) {
extra_headers.insert("x-openai-subagent", val);
let api_provider = self
.state
.provider
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
let compression = self.responses_request_compression(auth.as_ref());
let options = self.build_responses_options(prompt, compression);
let client = ApiWebSocketResponsesClient::new(api_provider, api_auth);
let stream_result = client
.stream_prompt(&self.state.model_info.slug, &api_prompt, options)
.await;
match stream_result {
Ok(stream) => {
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
}
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UNAUTHORIZED =>
{
handle_unauthorized(status, &mut auth_recovery).await?;
continue;
}
Err(err) => return Err(map_api_error(err)),
}
}
client
.compact_input(&payload, extra_headers)
.await
.map_err(map_api_error)
}
}
impl ModelClient {
/// Builds request and SSE telemetry for streaming API calls (Chat/Responses).
fn build_streaming_telemetry(&self) -> (Arc<dyn RequestTelemetry>, Arc<dyn SseTelemetry>) {
let telemetry = Arc::new(ApiTelemetry::new(self.otel_manager.clone()));
let telemetry = Arc::new(ApiTelemetry::new(self.state.otel_manager.clone()));
let request_telemetry: Arc<dyn RequestTelemetry> = telemetry.clone();
let sse_telemetry: Arc<dyn SseTelemetry> = telemetry;
(request_telemetry, sse_telemetry)
}
}
impl ModelClient {
/// Builds request telemetry for unary API calls (e.g., Compact endpoint).
fn build_request_telemetry(&self) -> Arc<dyn RequestTelemetry> {
let telemetry = Arc::new(ApiTelemetry::new(self.otel_manager.clone()));
let telemetry = Arc::new(ApiTelemetry::new(self.state.otel_manager.clone()));
let request_telemetry: Arc<dyn RequestTelemetry> = telemetry;
request_telemetry
}

View file

@ -78,6 +78,7 @@ use tracing::warn;
use crate::ModelProviderInfo;
use crate::WireApi;
use crate::client::ModelClient;
use crate::client::ModelClientSession;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::compact::collect_user_messages;
@ -2672,12 +2673,15 @@ async fn run_model_turn(
output_schema: turn_context.final_output_json_schema.clone(),
};
let client_session = turn_context.client.new_session();
let mut retries = 0;
loop {
let err = match try_run_turn(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
&client_session,
Arc::clone(&turn_diff_tracker),
&prompt,
cancellation_token.child_token(),
@ -2769,6 +2773,7 @@ async fn try_run_turn(
router: Arc<ToolRouter>,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
client_session: &ModelClientSession,
turn_diff_tracker: SharedTurnDiffTracker,
prompt: &Prompt,
cancellation_token: CancellationToken,
@ -2797,9 +2802,7 @@ async fn try_run_turn(
);
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context
.client
.clone()
let mut stream = client_session
.stream(prompt)
.instrument(trace_span!("stream_request"))
.or_cancel(&cancellation_token)

View file

@ -297,7 +297,8 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let mut stream = turn_context.client.clone().stream(prompt).await?;
let client_session = turn_context.client.new_session();
let mut stream = client_session.stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {

View file

@ -126,6 +126,7 @@ pub use codex_protocol::protocol;
pub use codex_protocol::config_types as protocol_config_types;
pub use client::ModelClient;
pub use client::ModelClientSession;
pub use client_common::Prompt;
pub use client_common::REVIEW_PROMPT;
pub use client_common::ResponseEvent;

View file

@ -42,6 +42,10 @@ pub enum WireApi {
/// The Responses API exposed by OpenAI at `/v1/responses`.
Responses,
/// Experimental: Responses API over WebSocket transport.
#[serde(rename = "responses_websocket")]
ResponsesWebsocket,
/// Regular Chat Completions compatible with `/v1/chat/completions`.
#[default]
Chat,
@ -156,6 +160,7 @@ impl ModelProviderInfo {
query_params: self.query_params.clone(),
wire: match self.wire_api {
WireApi::Responses => ApiWireApi::Responses,
WireApi::ResponsesWebsocket => ApiWireApi::Responses,
WireApi::Chat => ApiWireApi::Chat,
},
headers,

View file

@ -98,7 +98,8 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
summary,
conversation_id,
SessionSource::Exec,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = input;

View file

@ -99,7 +99,8 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
summary,
conversation_id,
SessionSource::Exec,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {

View file

@ -15,11 +15,13 @@ codex-core = { workspace = true, features = ["test-support"] }
codex-protocol = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-cargo-bin = { workspace = true }
futures = { workspace = true }
notify = { workspace = true }
regex-lite = { workspace = true }
serde_json = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["time"] }
tokio = { workspace = true, features = ["net", "time"] }
tokio-tungstenite = { workspace = true }
walkdir = { workspace = true }
wiremock = { workspace = true }
shlex = { workspace = true }

View file

@ -1,3 +1,4 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
@ -5,7 +6,12 @@ use std::time::Duration;
use anyhow::Result;
use base64::Engine;
use codex_protocol::openai_models::ModelsResponse;
use futures::SinkExt;
use futures::StreamExt;
use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
@ -199,6 +205,47 @@ impl ResponsesRequest {
}
}
#[derive(Debug, Clone)]
pub struct WebSocketRequest {
body: Value,
}
impl WebSocketRequest {
pub fn body_json(&self) -> Value {
self.body.clone()
}
}
/// In-process WebSocket server used to exercise the responses-over-WebSocket
/// transport in tests.
pub struct WebSocketTestServer {
    // `ws://host:port` address test clients should connect to.
    uri: String,
    // One inner `Vec` of recorded requests per accepted connection.
    connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
    // Fires once to ask the accept loop to stop.
    shutdown: oneshot::Sender<()>,
    // Handle to the background accept/serve task.
    task: tokio::task::JoinHandle<()>,
}

impl WebSocketTestServer {
    /// WebSocket URI of the mock server.
    pub fn uri(&self) -> &str {
        self.uri.as_str()
    }

    /// All requests recorded so far, grouped per connection.
    pub fn connections(&self) -> Vec<Vec<WebSocketRequest>> {
        self.connections.lock().unwrap().clone()
    }

    /// Requests from the only connection made so far.
    ///
    /// # Panics
    /// Panics unless exactly one connection has been recorded.
    pub fn single_connection(&self) -> Vec<WebSocketRequest> {
        let recorded = self.connections.lock().unwrap();
        match recorded.as_slice() {
            [only] => only.clone(),
            other => panic!("expected 1 connection, got {}", other.len()),
        }
    }

    /// Signals the server to stop and waits for the background task to exit.
    pub async fn shutdown(self) {
        let _ = self.shutdown.send(());
        let _ = self.task.await;
    }
}
#[derive(Debug, Clone)]
pub struct ModelsMock {
requests: Arc<Mutex<Vec<wiremock::Request>>>,
@ -724,6 +771,91 @@ pub async fn start_mock_server() -> MockServer {
server
}
/// Starts a lightweight WebSocket server for `/v1/responses` tests.
///
/// Each connection consumes a queue of request/event sequences. For each
/// request message, the server records the payload and streams the matching
/// events as WebSocket text frames before moving to the next request.
///
/// `connections` is a queue of scripts: one entry per expected connection,
/// each entry holding, per expected request, the list of JSON events to
/// stream back. Connections are served one at a time (no per-connection
/// spawn), which keeps recorded request ordering deterministic for tests.
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
    // Port 0 lets the OS pick a free port; tests read it back via `uri()`.
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind websocket server");
    let addr = listener.local_addr().expect("websocket server address");
    let uri = format!("ws://{addr}");
    // Shared log of requests, grouped per connection; cloned handle is moved
    // into the background task while the original is returned to the caller.
    let connections_log = Arc::new(Mutex::new(Vec::new()));
    let requests = Arc::clone(&connections_log);
    // Pending per-connection scripts, popped front-to-back as clients connect.
    let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
    let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
    let task = tokio::spawn(async move {
        loop {
            // Race accept against shutdown so `shutdown()` can interrupt an
            // idle server that is blocked waiting for a client.
            let accept_res = tokio::select! {
                _ = &mut shutdown_rx => return,
                accept_res = listener.accept() => accept_res,
            };
            let (stream, _) = match accept_res {
                Ok(value) => value,
                Err(_) => return,
            };
            // A failed WebSocket handshake skips this client but keeps the
            // server alive for subsequent connections.
            let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
                Ok(ws) => ws,
                Err(_) => continue,
            };
            // Scoped block drops the std::sync::Mutex guard before any await.
            let connection_requests = {
                let mut pending = connections.lock().unwrap();
                pending.pop_front()
            };
            // No script left for this connection: close it and keep serving.
            let Some(connection_requests) = connection_requests else {
                let _ = ws_stream.close(None).await;
                continue;
            };
            let mut connection_log = Vec::new();
            for request_events in connection_requests {
                // Client closed or errored mid-script: stop this connection.
                let Some(Ok(message)) = ws_stream.next().await else {
                    break;
                };
                // Only frames with a parseable JSON body are recorded.
                if let Some(body) = parse_ws_request_body(message) {
                    connection_log.push(WebSocketRequest { body });
                }
                // Stream the scripted events back as text frames; events that
                // fail to serialize are skipped rather than aborting the test.
                for event in &request_events {
                    let Ok(payload) = serde_json::to_string(event) else {
                        continue;
                    };
                    if ws_stream.send(Message::Text(payload)).await.is_err() {
                        break;
                    }
                }
            }
            requests.lock().unwrap().push(connection_log);
            let _ = ws_stream.close(None).await;
            // All scripted connections consumed: the server's job is done.
            if connections.lock().unwrap().is_empty() {
                return;
            }
        }
    });
    WebSocketTestServer {
        uri,
        connections: connections_log,
        shutdown: shutdown_tx,
        task,
    }
}
/// Extracts a JSON request body from a WebSocket frame, if it carries one.
fn parse_ws_request_body(message: Message) -> Option<Value> {
    // Only data frames can carry a request; ping/pong/close are ignored.
    let bytes = match message {
        Message::Text(text) => text.into_bytes(),
        Message::Binary(bytes) => bytes,
        _ => return None,
    };
    serde_json::from_slice(&bytes).ok()
}
#[derive(Clone)]
pub struct FunctionCallResponseMocks {
pub function_call: ResponseMock,

View file

@ -91,7 +91,8 @@ async fn responses_stream_includes_subagent_header_on_review() {
summary,
conversation_id,
session_source,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
@ -186,7 +187,8 @@ async fn responses_stream_includes_subagent_header_on_other() {
summary,
conversation_id,
session_source,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
@ -279,7 +281,8 @@ async fn responses_respects_model_info_overrides_from_config() {
summary,
conversation_id,
session_source,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {

View file

@ -1181,7 +1181,8 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
summary,
conversation_id,
SessionSource::Exec,
);
)
.new_session();
let mut prompt = Prompt::default();
prompt.input.push(ResponseItem::Reasoning {

View file

@ -71,3 +71,4 @@ mod user_notification;
mod user_shell_cmd;
mod view_image;
mod web_search_cached;
mod websocket;

View file

@ -67,7 +67,7 @@ async fn retries_on_early_close() {
name: "openai".into(),
base_url: Some(format!("{}/v1", server.uri())),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// ModelClientSession will return an error if the environment variable for the
// provider is not set.
env_key: Some("PATH".into()),
env_key_instructions: None,

View file

@ -0,0 +1,112 @@
use codex_core::AuthManager;
use codex_core::CodexAuth;
use codex_core::ContentItem;
use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_core::models_manager::manager::ModelsManager;
use codex_core::protocol::SessionSource;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::start_websocket_server;
use futures::StreamExt;
use std::sync::Arc;
use tempfile::TempDir;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_streams_request() {
    // Mock server scripted with one connection that answers a single request
    // with a created + completed event pair.
    let server = start_websocket_server(vec![vec![vec![
        ev_response_created("resp-1"),
        ev_completed("resp-1"),
    ]]])
    .await;

    // Provider wired to the WebSocket transport; retries disabled so a
    // transport failure surfaces immediately instead of being masked.
    let provider = ModelProviderInfo {
        name: "mock-ws".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::ResponsesWebsocket,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        requires_openai_auth: false,
    };

    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home).await;
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    let reasoning_effort = config.model_reasoning_effort;
    let reasoning_summary = config.model_reasoning_summary;
    let model = ModelsManager::get_model_offline(config.model.as_deref());
    config.model = Some(model.clone());
    let config = Arc::new(config);
    let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);

    let conversation_id = ThreadId::new();
    let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
    let otel_manager = OtelManager::new(
        conversation_id,
        model.as_str(),
        model_info.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        auth_manager.get_auth_mode(),
        false,
        "test".to_string(),
        SessionSource::Exec,
    );

    let session = ModelClient::new(
        Arc::clone(&config),
        None,
        model_info,
        otel_manager,
        provider,
        reasoning_effort,
        reasoning_summary,
        conversation_id,
        SessionSource::Exec,
    )
    .new_session();

    let mut prompt = Prompt::default();
    prompt.input = vec![ResponseItem::Message {
        id: None,
        role: "user".into(),
        content: vec![ContentItem::InputText {
            text: "hello".into(),
        }],
    }];

    let mut events = session
        .stream(&prompt)
        .await
        .expect("websocket stream failed");
    // Drain events until the turn completes (or the stream ends early).
    loop {
        match events.next().await {
            Some(Ok(ResponseEvent::Completed { .. })) | None => break,
            Some(_) => {}
        }
    }

    // Exactly one connection carrying exactly one request was observed, and
    // the request body reflects the prompt we sent.
    let requests = server.single_connection();
    assert_eq!(requests.len(), 1);
    let body = requests.first().cloned().unwrap().body_json();
    assert_eq!(body["model"].as_str(), Some(model.as_str()));
    assert_eq!(body["stream"].as_bool(), Some(true));
    assert_eq!(body["input"].as_array().map(|items| items.len()), Some(1));
    server.shutdown().await;
}

View file

@ -102,7 +102,7 @@ pub enum Op {
/// Policy to use for tool calls such as `local_shell`.
sandbox_policy: SandboxPolicy,
/// Must be a valid model slug for the [`crate::client::ModelClient`]
/// Must be a valid model slug for the configured client session
/// associated with this conversation.
model: String,