parent
87f7226cca
commit
490c1c1fdd
22 changed files with 874 additions and 196 deletions
41
codex-rs/Cargo.lock
generated
41
codex-rs/Cargo.lock
generated
|
|
@ -984,8 +984,10 @@ dependencies = [
|
|||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"url",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
|
|
@ -2126,6 +2128,7 @@ dependencies = [
|
|||
"codex-protocol",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-cargo-bin",
|
||||
"futures",
|
||||
"notify",
|
||||
"pretty_assertions",
|
||||
"regex-lite",
|
||||
|
|
@ -2134,6 +2137,7 @@ dependencies = [
|
|||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"walkdir",
|
||||
"wiremock",
|
||||
]
|
||||
|
|
@ -2361,6 +2365,12 @@ dependencies = [
|
|||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea"
|
||||
|
||||
[[package]]
|
||||
name = "dbus"
|
||||
version = "0.9.9"
|
||||
|
|
@ -7117,6 +7127,18 @@ dependencies = [
|
|||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.18"
|
||||
|
|
@ -7511,6 +7533,25 @@ dependencies = [
|
|||
"ratatui-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.3.1",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.8.5",
|
||||
"sha1",
|
||||
"thiserror 1.0.69",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.18.0"
|
||||
|
|
|
|||
|
|
@ -209,6 +209,7 @@ tiny_http = "0.12"
|
|||
tokio = "1"
|
||||
tokio-stream = "0.1.18"
|
||||
tokio-test = "0.4"
|
||||
tokio-tungstenite = "0.21.0"
|
||||
tokio-util = "0.7.18"
|
||||
toml = "0.9.5"
|
||||
toml_edit = "0.24.0"
|
||||
|
|
|
|||
|
|
@ -14,11 +14,13 @@ http = { workspace = true }
|
|||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] }
|
||||
tokio = { workspace = true, features = ["macros", "net", "rt", "sync", "time"] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
url = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -2,4 +2,5 @@ pub mod chat;
|
|||
pub mod compact;
|
||||
pub mod models;
|
||||
pub mod responses;
|
||||
pub mod responses_websocket;
|
||||
mod streaming;
|
||||
|
|
|
|||
280
codex-rs/codex-api/src/endpoint/responses_websocket.rs
Normal file
280
codex-rs/codex-api/src/endpoint/responses_websocket.rs
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::endpoint::responses::ResponsesOptions;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::ResponsesRequest;
|
||||
use crate::requests::ResponsesRequestBuilder;
|
||||
use crate::requests::responses::Compression;
|
||||
use crate::sse::responses::ResponsesStreamEvent;
|
||||
use crate::sse::responses::process_responses_event;
|
||||
use codex_client::TransportError;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Error as WsError;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use url::Url;
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
|
||||
pub struct ResponsesWebsocketClient<A: AuthProvider> {
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
}
|
||||
|
||||
impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
||||
pub fn new(provider: Provider, auth: A) -> Self {
|
||||
Self { provider, auth }
|
||||
}
|
||||
|
||||
pub async fn stream_request(
|
||||
&self,
|
||||
request: ResponsesRequest,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers, request.compression)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
options: ResponsesOptions,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let ResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
extra_headers,
|
||||
compression,
|
||||
} = options;
|
||||
|
||||
// TODO (pakrym): share with HTTP based Responses API client
|
||||
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
|
||||
.tools(&prompt.tools)
|
||||
.parallel_tool_calls(prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.compression(compression)
|
||||
.build(&self.provider)?;
|
||||
|
||||
self.stream_request(request).await
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
compression: Compression,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
if compression == Compression::Zstd {
|
||||
warn!(
|
||||
"request compression is not supported for websocket streaming; sending uncompressed payload"
|
||||
);
|
||||
}
|
||||
|
||||
let ws_url = Url::parse(&self.provider.url_for_path("responses"))
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?;
|
||||
let mut headers = self.provider.headers.clone();
|
||||
headers.extend(extra_headers);
|
||||
apply_auth_headers(&mut headers, &self.auth);
|
||||
|
||||
let connection = connect_websocket(ws_url, headers).await?;
|
||||
|
||||
let (tx_event, rx_event) =
|
||||
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
|
||||
let idle_timeout = self.provider.stream_idle_timeout;
|
||||
|
||||
// TODO (pakrym): surface rate limits
|
||||
// TODO (pakrym): check models etags
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = run_websocket_response_stream(
|
||||
connection.stream,
|
||||
tx_event.clone(),
|
||||
body,
|
||||
idle_timeout,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (pakrym): share with /auth
|
||||
fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
|
||||
if let Some(token) = auth.bearer_token()
|
||||
&& let Ok(header) = HeaderValue::from_str(&format!("Bearer {token}"))
|
||||
{
|
||||
let _ = headers.insert(http::header::AUTHORIZATION, header);
|
||||
}
|
||||
if let Some(account_id) = auth.account_id()
|
||||
&& let Ok(header) = HeaderValue::from_str(&account_id)
|
||||
{
|
||||
let _ = headers.insert("ChatGPT-Account-ID", header);
|
||||
}
|
||||
}
|
||||
|
||||
struct WebSocketConnection {
|
||||
stream: WsStream,
|
||||
}
|
||||
|
||||
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WebSocketConnection, ApiError> {
|
||||
let mut request = url
|
||||
.clone()
|
||||
.into_client_request()
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?;
|
||||
request.headers_mut().extend(headers);
|
||||
|
||||
let (stream, _) = tokio_tungstenite::connect_async(request)
|
||||
.await
|
||||
.map_err(|err| map_ws_error(err, &url))?;
|
||||
Ok(WebSocketConnection { stream })
|
||||
}
|
||||
|
||||
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
match err {
|
||||
WsError::Http(response) => {
|
||||
let status = response.status();
|
||||
let headers = response.headers().clone();
|
||||
let body = response
|
||||
.body()
|
||||
.as_ref()
|
||||
.and_then(|bytes| String::from_utf8(bytes.clone()).ok());
|
||||
ApiError::Transport(TransportError::Http {
|
||||
status,
|
||||
url: Some(url.to_string()),
|
||||
headers: Some(headers),
|
||||
body,
|
||||
})
|
||||
}
|
||||
WsError::ConnectionClosed | WsError::AlreadyClosed => {
|
||||
ApiError::Stream("websocket closed".to_string())
|
||||
}
|
||||
WsError::Io(err) => ApiError::Transport(TransportError::Network(err.to_string())),
|
||||
other => ApiError::Transport(TransportError::Network(other.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_websocket_response_stream(
|
||||
mut ws_stream: WsStream,
|
||||
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
|
||||
request_body: Value,
|
||||
idle_timeout: Duration,
|
||||
) -> Result<(), ApiError> {
|
||||
let request_text = match serde_json::to_string(&request_body) {
|
||||
Ok(text) => text,
|
||||
Err(err) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to encode websocket request: {err}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = ws_stream.send(Message::Text(request_text)).await {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to send websocket request: {err}"
|
||||
)));
|
||||
}
|
||||
|
||||
loop {
|
||||
let response = tokio::time::timeout(idle_timeout, ws_stream.next())
|
||||
.await
|
||||
.map_err(|_| ApiError::Stream("idle timeout waiting for websocket".into()));
|
||||
let message = match response {
|
||||
Ok(Some(Ok(msg))) => msg,
|
||||
Ok(Some(Err(err))) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(err.to_string()));
|
||||
}
|
||||
Ok(None) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
));
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
match message {
|
||||
Message::Text(text) => {
|
||||
trace!("websocket event: {text}");
|
||||
let event = match serde_json::from_str::<ResponsesStreamEvent>(&text) {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
debug!("failed to parse websocket event: {err}, data: {text}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match process_responses_event(event) {
|
||||
Ok(Some(event)) => {
|
||||
let is_completed = matches!(event, ResponseEvent::Completed { .. });
|
||||
let _ = tx_event.send(Ok(event)).await;
|
||||
if is_completed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(error) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(error.into_api_error());
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Binary(_) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream("unexpected binary websocket event".into()));
|
||||
}
|
||||
Message::Ping(payload) => {
|
||||
if ws_stream.send(Message::Pong(payload)).await.is_err() {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream("websocket ping failed".into()));
|
||||
}
|
||||
}
|
||||
Message::Pong(_) => {}
|
||||
Message::Close(_) => {
|
||||
let _ = ws_stream.close(None).await;
|
||||
return Err(ApiError::Stream(
|
||||
"websocket closed before response.completed".into(),
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = ws_stream.close(None).await;
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -25,6 +25,7 @@ pub use crate::endpoint::compact::CompactClient;
|
|||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
pub use crate::endpoint::responses::ResponsesOptions;
|
||||
pub use crate::endpoint::responses_websocket::ResponsesWebsocketClient;
|
||||
pub use crate::error::ApiError;
|
||||
pub use crate::provider::Provider;
|
||||
pub use crate::provider::WireApi;
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ struct ResponseCompletedOutputTokensDetails {
|
|||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ResponsesStreamEvent {
|
||||
pub struct ResponsesStreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
response: Option<Value>,
|
||||
|
|
@ -149,7 +149,7 @@ impl ResponsesEventError {
|
|||
}
|
||||
}
|
||||
|
||||
fn process_responses_event(
|
||||
pub fn process_responses_event(
|
||||
event: ResponsesStreamEvent,
|
||||
) -> std::result::Result<Option<ResponseEvent>, ResponsesEventError> {
|
||||
match event.kind.as_str() {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use codex_api::ReqwestTransport;
|
|||
use codex_api::ResponseStream as ApiResponseStream;
|
||||
use codex_api::ResponsesClient as ApiResponsesClient;
|
||||
use codex_api::ResponsesOptions as ApiResponsesOptions;
|
||||
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
|
||||
use codex_api::SseTelemetry;
|
||||
use codex_api::TransportError;
|
||||
use codex_api::common::Reasoning;
|
||||
|
|
@ -57,8 +58,8 @@ use crate::model_provider_info::WireApi;
|
|||
use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
||||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelClient {
|
||||
#[derive(Debug)]
|
||||
struct ModelClientState {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
model_info: ModelInfo,
|
||||
|
|
@ -70,6 +71,16 @@ pub struct ModelClient {
|
|||
session_source: SessionSource,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelClient {
|
||||
state: Arc<ModelClientState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelClientSession {
|
||||
state: Arc<ModelClientState>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl ModelClient {
|
||||
pub fn new(
|
||||
|
|
@ -84,20 +95,30 @@ impl ModelClient {
|
|||
session_source: SessionSource,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
auth_manager,
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
conversation_id,
|
||||
effort,
|
||||
summary,
|
||||
session_source,
|
||||
state: Arc::new(ModelClientState {
|
||||
config,
|
||||
auth_manager,
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
conversation_id,
|
||||
effort,
|
||||
summary,
|
||||
session_source,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_session(&self) -> ModelClientSession {
|
||||
ModelClientSession {
|
||||
state: Arc::clone(&self.state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelClient {
|
||||
pub fn get_model_context_window(&self) -> Option<i64> {
|
||||
let model_info = self.get_model_info();
|
||||
let model_info = &self.state.model_info;
|
||||
let effective_context_window_percent = model_info.effective_context_window_percent;
|
||||
model_info.context_window.map(|context_window| {
|
||||
context_window.saturating_mul(effective_context_window_percent) / 100
|
||||
|
|
@ -105,39 +126,210 @@ impl ModelClient {
|
|||
}
|
||||
|
||||
pub fn config(&self) -> Arc<Config> {
|
||||
Arc::clone(&self.config)
|
||||
Arc::clone(&self.state.config)
|
||||
}
|
||||
|
||||
pub fn provider(&self) -> &ModelProviderInfo {
|
||||
&self.provider
|
||||
&self.state.provider
|
||||
}
|
||||
|
||||
pub fn get_provider(&self) -> ModelProviderInfo {
|
||||
self.state.provider.clone()
|
||||
}
|
||||
|
||||
pub fn get_otel_manager(&self) -> OtelManager {
|
||||
self.state.otel_manager.clone()
|
||||
}
|
||||
|
||||
pub fn get_session_source(&self) -> SessionSource {
|
||||
self.state.session_source.clone()
|
||||
}
|
||||
|
||||
/// Returns the currently configured model slug.
|
||||
pub fn get_model(&self) -> String {
|
||||
self.state.model_info.slug.clone()
|
||||
}
|
||||
|
||||
pub fn get_model_info(&self) -> ModelInfo {
|
||||
self.state.model_info.clone()
|
||||
}
|
||||
|
||||
/// Returns the current reasoning effort setting.
|
||||
pub fn get_reasoning_effort(&self) -> Option<ReasoningEffortConfig> {
|
||||
self.state.effort
|
||||
}
|
||||
|
||||
/// Returns the current reasoning summary setting.
|
||||
pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig {
|
||||
self.state.summary
|
||||
}
|
||||
|
||||
pub fn get_auth_manager(&self) -> Option<Arc<AuthManager>> {
|
||||
self.state.auth_manager.clone()
|
||||
}
|
||||
|
||||
/// Compacts the current conversation history using the Compact endpoint.
|
||||
///
|
||||
/// This is a unary call (no streaming) that returns a new list of
|
||||
/// `ResponseItem`s representing the compacted transcript.
|
||||
pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result<Vec<ResponseItem>> {
|
||||
if prompt.input.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let auth_manager = self.state.auth_manager.clone();
|
||||
let auth = match auth_manager.as_ref() {
|
||||
Some(manager) => manager.auth().await,
|
||||
None => None,
|
||||
};
|
||||
let api_provider = self
|
||||
.state
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let request_telemetry = self.build_request_telemetry();
|
||||
let client = ApiCompactClient::new(transport, api_provider, api_auth)
|
||||
.with_telemetry(Some(request_telemetry));
|
||||
|
||||
let instructions = prompt
|
||||
.get_full_instructions(&self.state.model_info)
|
||||
.into_owned();
|
||||
let payload = ApiCompactionInput {
|
||||
model: &self.state.model_info.slug,
|
||||
input: &prompt.input,
|
||||
instructions: &instructions,
|
||||
};
|
||||
|
||||
let mut extra_headers = ApiHeaderMap::new();
|
||||
if let SessionSource::SubAgent(sub) = &self.state.session_source {
|
||||
let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub {
|
||||
label.clone()
|
||||
} else {
|
||||
serde_json::to_value(sub)
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
|
||||
.unwrap_or_else(|| "other".to_string())
|
||||
};
|
||||
if let Ok(val) = HeaderValue::from_str(&subagent) {
|
||||
extra_headers.insert("x-openai-subagent", val);
|
||||
}
|
||||
}
|
||||
|
||||
client
|
||||
.compact_input(&payload, extra_headers)
|
||||
.await
|
||||
.map_err(map_api_error)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelClientSession {
|
||||
/// Streams a single model turn using either the Responses or Chat
|
||||
/// Completions wire API, depending on the configured provider.
|
||||
///
|
||||
/// For Chat providers, the underlying stream is optionally aggregated
|
||||
/// based on the `show_raw_agent_reasoning` flag in the config.
|
||||
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
match self.provider.wire_api {
|
||||
match self.state.provider.wire_api {
|
||||
WireApi::Responses => self.stream_responses_api(prompt).await,
|
||||
WireApi::ResponsesWebsocket => self.stream_responses_websocket(prompt).await,
|
||||
WireApi::Chat => {
|
||||
let api_stream = self.stream_chat_completions(prompt).await?;
|
||||
|
||||
if self.config.show_raw_agent_reasoning {
|
||||
if self.state.config.show_raw_agent_reasoning {
|
||||
Ok(map_response_stream(
|
||||
api_stream.streaming_mode(),
|
||||
self.otel_manager.clone(),
|
||||
self.state.otel_manager.clone(),
|
||||
))
|
||||
} else {
|
||||
Ok(map_response_stream(
|
||||
api_stream.aggregate(),
|
||||
self.otel_manager.clone(),
|
||||
self.state.otel_manager.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_responses_request(&self, prompt: &Prompt) -> Result<ApiPrompt> {
|
||||
let model_info = self.state.model_info.clone();
|
||||
let instructions = prompt.get_full_instructions(&model_info).into_owned();
|
||||
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
Ok(build_api_prompt(prompt, instructions, tools_json))
|
||||
}
|
||||
|
||||
fn build_responses_options(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
compression: Compression,
|
||||
) -> ApiResponsesOptions {
|
||||
let model_info = &self.state.model_info;
|
||||
|
||||
let default_reasoning_effort = model_info.default_reasoning_level;
|
||||
let reasoning = if model_info.supports_reasoning_summaries {
|
||||
Some(Reasoning {
|
||||
effort: self.state.effort.or(default_reasoning_effort),
|
||||
summary: if self.state.summary == ReasoningSummaryConfig::None {
|
||||
None
|
||||
} else {
|
||||
Some(self.state.summary)
|
||||
},
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let include = if reasoning.is_some() {
|
||||
vec!["reasoning.encrypted_content".to_string()]
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let verbosity = if model_info.support_verbosity {
|
||||
self.state
|
||||
.config
|
||||
.model_verbosity
|
||||
.or(model_info.default_verbosity)
|
||||
} else {
|
||||
if self.state.config.model_verbosity.is_some() {
|
||||
warn!(
|
||||
"model_verbosity is set but ignored as the model does not support verbosity: {}",
|
||||
model_info.slug
|
||||
);
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
|
||||
let conversation_id = self.state.conversation_id.to_string();
|
||||
|
||||
ApiResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key: Some(conversation_id.clone()),
|
||||
text,
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id),
|
||||
session_source: Some(self.state.session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.state.config),
|
||||
compression,
|
||||
}
|
||||
}
|
||||
|
||||
fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression {
|
||||
if self
|
||||
.state
|
||||
.config
|
||||
.features
|
||||
.enabled(Feature::EnableRequestCompression)
|
||||
&& auth.is_some_and(|auth| auth.mode == AuthMode::ChatGPT)
|
||||
&& self.state.provider.is_openai()
|
||||
{
|
||||
Compression::Zstd
|
||||
} else {
|
||||
Compression::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Streams a turn via the OpenAI Chat Completions API.
|
||||
///
|
||||
/// This path is only used when the provider is configured with
|
||||
|
|
@ -149,13 +341,13 @@ impl ModelClient {
|
|||
));
|
||||
}
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_info = self.get_model_info();
|
||||
let auth_manager = self.state.auth_manager.clone();
|
||||
let model_info = self.state.model_info.clone();
|
||||
let instructions = prompt.get_full_instructions(&model_info).into_owned();
|
||||
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
|
||||
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
|
||||
let conversation_id = self.conversation_id.to_string();
|
||||
let session_source = self.session_source.clone();
|
||||
let conversation_id = self.state.conversation_id.to_string();
|
||||
let session_source = self.state.session_source.clone();
|
||||
|
||||
let mut auth_recovery = auth_manager
|
||||
.as_ref()
|
||||
|
|
@ -166,9 +358,10 @@ impl ModelClient {
|
|||
None => None,
|
||||
};
|
||||
let api_provider = self
|
||||
.state
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry();
|
||||
let client = ApiChatClient::new(transport, api_provider, api_auth)
|
||||
|
|
@ -176,7 +369,7 @@ impl ModelClient {
|
|||
|
||||
let stream_result = client
|
||||
.stream_prompt(
|
||||
&self.get_model(),
|
||||
&self.state.model_info.slug,
|
||||
&api_prompt,
|
||||
Some(conversation_id.clone()),
|
||||
Some(session_source.clone()),
|
||||
|
|
@ -203,52 +396,14 @@ impl ModelClient {
|
|||
async fn stream_responses_api(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
|
||||
warn!(path, "Streaming from fixture");
|
||||
let stream = codex_api::stream_from_fixture(path, self.provider.stream_idle_timeout())
|
||||
.map_err(map_api_error)?;
|
||||
return Ok(map_response_stream(stream, self.otel_manager.clone()));
|
||||
let stream =
|
||||
codex_api::stream_from_fixture(path, self.state.provider.stream_idle_timeout())
|
||||
.map_err(map_api_error)?;
|
||||
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
|
||||
}
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_info = self.get_model_info();
|
||||
let instructions = prompt.get_full_instructions(&model_info).into_owned();
|
||||
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
|
||||
let default_reasoning_effort = model_info.default_reasoning_level;
|
||||
let reasoning = if model_info.supports_reasoning_summaries {
|
||||
Some(Reasoning {
|
||||
effort: self.effort.or(default_reasoning_effort),
|
||||
summary: if self.summary == ReasoningSummaryConfig::None {
|
||||
None
|
||||
} else {
|
||||
Some(self.summary)
|
||||
},
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let include: Vec<String> = if reasoning.is_some() {
|
||||
vec!["reasoning.encrypted_content".to_string()]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let verbosity = if model_info.support_verbosity {
|
||||
self.config.model_verbosity.or(model_info.default_verbosity)
|
||||
} else {
|
||||
if self.config.model_verbosity.is_some() {
|
||||
warn!(
|
||||
"model_verbosity is set but ignored as the model does not support verbosity: {}",
|
||||
model_info.slug
|
||||
);
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
|
||||
let api_prompt = build_api_prompt(prompt, instructions.clone(), tools_json);
|
||||
let conversation_id = self.conversation_id.to_string();
|
||||
let session_source = self.session_source.clone();
|
||||
let auth_manager = self.state.auth_manager.clone();
|
||||
let api_prompt = self.build_responses_request(prompt)?;
|
||||
|
||||
let mut auth_recovery = auth_manager
|
||||
.as_ref()
|
||||
|
|
@ -259,47 +414,26 @@ impl ModelClient {
|
|||
None => None,
|
||||
};
|
||||
let api_provider = self
|
||||
.state
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry();
|
||||
let compression = if self
|
||||
.config
|
||||
.features
|
||||
.enabled(Feature::EnableRequestCompression)
|
||||
&& auth
|
||||
.as_ref()
|
||||
.is_some_and(|auth| auth.mode == AuthMode::ChatGPT)
|
||||
&& self.provider.is_openai()
|
||||
{
|
||||
Compression::Zstd
|
||||
} else {
|
||||
Compression::None
|
||||
};
|
||||
let compression = self.responses_request_compression(auth.as_ref());
|
||||
|
||||
let client = ApiResponsesClient::new(transport, api_provider, api_auth)
|
||||
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
|
||||
|
||||
let options = ApiResponsesOptions {
|
||||
reasoning: reasoning.clone(),
|
||||
include: include.clone(),
|
||||
prompt_cache_key: Some(conversation_id.clone()),
|
||||
text: text.clone(),
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id.clone()),
|
||||
session_source: Some(session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config),
|
||||
compression,
|
||||
};
|
||||
let options = self.build_responses_options(prompt, compression);
|
||||
|
||||
let stream_result = client
|
||||
.stream_prompt(&self.get_model(), &api_prompt, options)
|
||||
.stream_prompt(&self.state.model_info.slug, &api_prompt, options)
|
||||
.await;
|
||||
|
||||
match stream_result {
|
||||
Ok(stream) => {
|
||||
return Ok(map_response_stream(stream, self.otel_manager.clone()));
|
||||
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
|
||||
}
|
||||
Err(ApiError::Transport(TransportError::Http { status, .. }))
|
||||
if status == StatusCode::UNAUTHORIZED =>
|
||||
|
|
@ -312,106 +446,61 @@ impl ModelClient {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn get_provider(&self) -> ModelProviderInfo {
|
||||
self.provider.clone()
|
||||
}
|
||||
/// Streams a turn via the Responses API over WebSocket transport.
|
||||
async fn stream_responses_websocket(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
let auth_manager = self.state.auth_manager.clone();
|
||||
let api_prompt = self.build_responses_request(prompt)?;
|
||||
|
||||
pub fn get_otel_manager(&self) -> OtelManager {
|
||||
self.otel_manager.clone()
|
||||
}
|
||||
|
||||
pub fn get_session_source(&self) -> SessionSource {
|
||||
self.session_source.clone()
|
||||
}
|
||||
|
||||
/// Returns the currently configured model slug.
|
||||
pub fn get_model(&self) -> String {
|
||||
self.model_info.slug.clone()
|
||||
}
|
||||
|
||||
pub fn get_model_info(&self) -> ModelInfo {
|
||||
self.model_info.clone()
|
||||
}
|
||||
|
||||
/// Returns the current reasoning effort setting.
|
||||
pub fn get_reasoning_effort(&self) -> Option<ReasoningEffortConfig> {
|
||||
self.effort
|
||||
}
|
||||
|
||||
/// Returns the current reasoning summary setting.
|
||||
pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig {
|
||||
self.summary
|
||||
}
|
||||
|
||||
pub fn get_auth_manager(&self) -> Option<Arc<AuthManager>> {
|
||||
self.auth_manager.clone()
|
||||
}
|
||||
|
||||
/// Compacts the current conversation history using the Compact endpoint.
|
||||
///
|
||||
/// This is a unary call (no streaming) that returns a new list of
|
||||
/// `ResponseItem`s representing the compacted transcript.
|
||||
pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result<Vec<ResponseItem>> {
|
||||
if prompt.input.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let auth = match auth_manager.as_ref() {
|
||||
Some(manager) => manager.auth().await,
|
||||
None => None,
|
||||
};
|
||||
let api_provider = self
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let request_telemetry = self.build_request_telemetry();
|
||||
let client = ApiCompactClient::new(transport, api_provider, api_auth)
|
||||
.with_telemetry(Some(request_telemetry));
|
||||
|
||||
let instructions = prompt
|
||||
.get_full_instructions(&self.get_model_info())
|
||||
.into_owned();
|
||||
let payload = ApiCompactionInput {
|
||||
model: &self.get_model(),
|
||||
input: &prompt.input,
|
||||
instructions: &instructions,
|
||||
};
|
||||
|
||||
let mut extra_headers = ApiHeaderMap::new();
|
||||
if let SessionSource::SubAgent(sub) = &self.session_source {
|
||||
let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub {
|
||||
label.clone()
|
||||
} else {
|
||||
serde_json::to_value(sub)
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
|
||||
.unwrap_or_else(|| "other".to_string())
|
||||
let mut auth_recovery = auth_manager
|
||||
.as_ref()
|
||||
.map(super::auth::AuthManager::unauthorized_recovery);
|
||||
loop {
|
||||
let auth = match auth_manager.as_ref() {
|
||||
Some(manager) => manager.auth().await,
|
||||
None => None,
|
||||
};
|
||||
if let Ok(val) = HeaderValue::from_str(&subagent) {
|
||||
extra_headers.insert("x-openai-subagent", val);
|
||||
let api_provider = self
|
||||
.state
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
|
||||
let compression = self.responses_request_compression(auth.as_ref());
|
||||
|
||||
let options = self.build_responses_options(prompt, compression);
|
||||
let client = ApiWebSocketResponsesClient::new(api_provider, api_auth);
|
||||
|
||||
let stream_result = client
|
||||
.stream_prompt(&self.state.model_info.slug, &api_prompt, options)
|
||||
.await;
|
||||
|
||||
match stream_result {
|
||||
Ok(stream) => {
|
||||
return Ok(map_response_stream(stream, self.state.otel_manager.clone()));
|
||||
}
|
||||
Err(ApiError::Transport(TransportError::Http { status, .. }))
|
||||
if status == StatusCode::UNAUTHORIZED =>
|
||||
{
|
||||
handle_unauthorized(status, &mut auth_recovery).await?;
|
||||
continue;
|
||||
}
|
||||
Err(err) => return Err(map_api_error(err)),
|
||||
}
|
||||
}
|
||||
|
||||
client
|
||||
.compact_input(&payload, extra_headers)
|
||||
.await
|
||||
.map_err(map_api_error)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelClient {
|
||||
/// Builds request and SSE telemetry for streaming API calls (Chat/Responses).
|
||||
fn build_streaming_telemetry(&self) -> (Arc<dyn RequestTelemetry>, Arc<dyn SseTelemetry>) {
|
||||
let telemetry = Arc::new(ApiTelemetry::new(self.otel_manager.clone()));
|
||||
let telemetry = Arc::new(ApiTelemetry::new(self.state.otel_manager.clone()));
|
||||
let request_telemetry: Arc<dyn RequestTelemetry> = telemetry.clone();
|
||||
let sse_telemetry: Arc<dyn SseTelemetry> = telemetry;
|
||||
(request_telemetry, sse_telemetry)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelClient {
|
||||
/// Builds request telemetry for unary API calls (e.g., Compact endpoint).
|
||||
fn build_request_telemetry(&self) -> Arc<dyn RequestTelemetry> {
|
||||
let telemetry = Arc::new(ApiTelemetry::new(self.otel_manager.clone()));
|
||||
let telemetry = Arc::new(ApiTelemetry::new(self.state.otel_manager.clone()));
|
||||
let request_telemetry: Arc<dyn RequestTelemetry> = telemetry;
|
||||
request_telemetry
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ use tracing::warn;
|
|||
use crate::ModelProviderInfo;
|
||||
use crate::WireApi;
|
||||
use crate::client::ModelClient;
|
||||
use crate::client::ModelClientSession;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::compact::collect_user_messages;
|
||||
|
|
@ -2672,12 +2673,15 @@ async fn run_model_turn(
|
|||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
};
|
||||
|
||||
let client_session = turn_context.client.new_session();
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let err = match try_run_turn(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
&client_session,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&prompt,
|
||||
cancellation_token.child_token(),
|
||||
|
|
@ -2769,6 +2773,7 @@ async fn try_run_turn(
|
|||
router: Arc<ToolRouter>,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
client_session: &ModelClientSession,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
prompt: &Prompt,
|
||||
cancellation_token: CancellationToken,
|
||||
|
|
@ -2797,9 +2802,7 @@ async fn try_run_turn(
|
|||
);
|
||||
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
let mut stream = client_session
|
||||
.stream(prompt)
|
||||
.instrument(trace_span!("stream_request"))
|
||||
.or_cancel(&cancellation_token)
|
||||
|
|
|
|||
|
|
@ -297,7 +297,8 @@ async fn drain_to_completed(
|
|||
turn_context: &TurnContext,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||
let client_session = turn_context.client.new_session();
|
||||
let mut stream = client_session.stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
|
|
|
|||
|
|
@ -126,6 +126,7 @@ pub use codex_protocol::protocol;
|
|||
pub use codex_protocol::config_types as protocol_config_types;
|
||||
|
||||
pub use client::ModelClient;
|
||||
pub use client::ModelClientSession;
|
||||
pub use client_common::Prompt;
|
||||
pub use client_common::REVIEW_PROMPT;
|
||||
pub use client_common::ResponseEvent;
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ pub enum WireApi {
|
|||
/// The Responses API exposed by OpenAI at `/v1/responses`.
|
||||
Responses,
|
||||
|
||||
/// Experimental: Responses API over WebSocket transport.
|
||||
#[serde(rename = "responses_websocket")]
|
||||
ResponsesWebsocket,
|
||||
|
||||
/// Regular Chat Completions compatible with `/v1/chat/completions`.
|
||||
#[default]
|
||||
Chat,
|
||||
|
|
@ -156,6 +160,7 @@ impl ModelProviderInfo {
|
|||
query_params: self.query_params.clone(),
|
||||
wire: match self.wire_api {
|
||||
WireApi::Responses => ApiWireApi::Responses,
|
||||
WireApi::ResponsesWebsocket => ApiWireApi::Responses,
|
||||
WireApi::Chat => ApiWireApi::Chat,
|
||||
},
|
||||
headers,
|
||||
|
|
|
|||
|
|
@ -98,7 +98,8 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
|||
summary,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
);
|
||||
)
|
||||
.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = input;
|
||||
|
|
|
|||
|
|
@ -99,7 +99,8 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
|
|||
summary,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
);
|
||||
)
|
||||
.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
|
|
|
|||
|
|
@ -15,11 +15,13 @@ codex-core = { workspace = true, features = ["test-support"] }
|
|||
codex-protocol = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-cargo-bin = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
notify = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true, features = ["time"] }
|
||||
tokio = { workspace = true, features = ["net", "time"] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
walkdir = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
|
@ -5,7 +6,12 @@ use std::time::Duration;
|
|||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use serde_json::Value;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::Match;
|
||||
use wiremock::Mock;
|
||||
|
|
@ -199,6 +205,47 @@ impl ResponsesRequest {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WebSocketRequest {
|
||||
body: Value,
|
||||
}
|
||||
|
||||
impl WebSocketRequest {
|
||||
pub fn body_json(&self) -> Value {
|
||||
self.body.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WebSocketTestServer {
|
||||
uri: String,
|
||||
connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
|
||||
shutdown: oneshot::Sender<()>,
|
||||
task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl WebSocketTestServer {
|
||||
pub fn uri(&self) -> &str {
|
||||
&self.uri
|
||||
}
|
||||
|
||||
pub fn connections(&self) -> Vec<Vec<WebSocketRequest>> {
|
||||
self.connections.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn single_connection(&self) -> Vec<WebSocketRequest> {
|
||||
let connections = self.connections.lock().unwrap();
|
||||
if connections.len() != 1 {
|
||||
panic!("expected 1 connection, got {}", connections.len());
|
||||
}
|
||||
connections.first().cloned().unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn shutdown(self) {
|
||||
let _ = self.shutdown.send(());
|
||||
let _ = self.task.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelsMock {
|
||||
requests: Arc<Mutex<Vec<wiremock::Request>>>,
|
||||
|
|
@ -724,6 +771,91 @@ pub async fn start_mock_server() -> MockServer {
|
|||
server
|
||||
}
|
||||
|
||||
/// Starts a lightweight WebSocket server for `/v1/responses` tests.
|
||||
///
|
||||
/// Each connection consumes a queue of request/event sequences. For each
|
||||
/// request message, the server records the payload and streams the matching
|
||||
/// events as WebSocket text frames before moving to the next request.
|
||||
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind websocket server");
|
||||
let addr = listener.local_addr().expect("websocket server address");
|
||||
let uri = format!("ws://{addr}");
|
||||
let connections_log = Arc::new(Mutex::new(Vec::new()));
|
||||
let requests = Arc::clone(&connections_log);
|
||||
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
loop {
|
||||
let accept_res = tokio::select! {
|
||||
_ = &mut shutdown_rx => return,
|
||||
accept_res = listener.accept() => accept_res,
|
||||
};
|
||||
let (stream, _) = match accept_res {
|
||||
Ok(value) => value,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
|
||||
Ok(ws) => ws,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let connection_requests = {
|
||||
let mut pending = connections.lock().unwrap();
|
||||
pending.pop_front()
|
||||
};
|
||||
|
||||
let Some(connection_requests) = connection_requests else {
|
||||
let _ = ws_stream.close(None).await;
|
||||
continue;
|
||||
};
|
||||
|
||||
let mut connection_log = Vec::new();
|
||||
for request_events in connection_requests {
|
||||
let Some(Ok(message)) = ws_stream.next().await else {
|
||||
break;
|
||||
};
|
||||
if let Some(body) = parse_ws_request_body(message) {
|
||||
connection_log.push(WebSocketRequest { body });
|
||||
}
|
||||
|
||||
for event in &request_events {
|
||||
let Ok(payload) = serde_json::to_string(event) else {
|
||||
continue;
|
||||
};
|
||||
if ws_stream.send(Message::Text(payload)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
requests.lock().unwrap().push(connection_log);
|
||||
let _ = ws_stream.close(None).await;
|
||||
|
||||
if connections.lock().unwrap().is_empty() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
WebSocketTestServer {
|
||||
uri,
|
||||
connections: connections_log,
|
||||
shutdown: shutdown_tx,
|
||||
task,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ws_request_body(message: Message) -> Option<Value> {
|
||||
match message {
|
||||
Message::Text(text) => serde_json::from_str(&text).ok(),
|
||||
Message::Binary(bytes) => serde_json::from_slice(&bytes).ok(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FunctionCallResponseMocks {
|
||||
pub function_call: ResponseMock,
|
||||
|
|
|
|||
|
|
@ -91,7 +91,8 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
|||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
);
|
||||
)
|
||||
.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
|
|
@ -186,7 +187,8 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
|||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
);
|
||||
)
|
||||
.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
|
|
@ -279,7 +281,8 @@ async fn responses_respects_model_info_overrides_from_config() {
|
|||
summary,
|
||||
conversation_id,
|
||||
session_source,
|
||||
);
|
||||
)
|
||||
.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
|
|
|
|||
|
|
@ -1181,7 +1181,8 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
|||
summary,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
);
|
||||
)
|
||||
.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input.push(ResponseItem::Reasoning {
|
||||
|
|
|
|||
|
|
@ -71,3 +71,4 @@ mod user_notification;
|
|||
mod user_shell_cmd;
|
||||
mod view_image;
|
||||
mod web_search_cached;
|
||||
mod websocket;
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ async fn retries_on_early_close() {
|
|||
name: "openai".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// ModelClientSession will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
|
|
|
|||
112
codex-rs/core/tests/suite/websocket.rs
Normal file
112
codex-rs/core/tests/suite/websocket.rs
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
use codex_core::AuthManager;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ContentItem;
|
||||
use codex_core::ModelClient;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::Prompt;
|
||||
use codex_core::ResponseEvent;
|
||||
use codex_core::ResponseItem;
|
||||
use codex_core::WireApi;
|
||||
use codex_core::models_manager::manager::ModelsManager;
|
||||
use codex_core::protocol::SessionSource;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_protocol::ThreadId;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::start_websocket_server;
|
||||
use futures::StreamExt;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_streams_request() {
|
||||
let server = start_websocket_server(vec![vec![vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_completed("resp-1"),
|
||||
]]])
|
||||
.await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "mock-ws".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::ResponsesWebsocket,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
requires_openai_auth: false,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home).await;
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let model = ModelsManager::get_model_offline(config.model.as_deref());
|
||||
config.model = Some(model.clone());
|
||||
let config = Arc::new(config);
|
||||
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
|
||||
let conversation_id = ThreadId::new();
|
||||
let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
|
||||
let otel_manager = OtelManager::new(
|
||||
conversation_id,
|
||||
model.as_str(),
|
||||
model_info.slug.as_str(),
|
||||
None,
|
||||
Some("test@test.com".to_string()),
|
||||
auth_manager.get_auth_mode(),
|
||||
false,
|
||||
"test".to_string(),
|
||||
SessionSource::Exec,
|
||||
);
|
||||
|
||||
let client = ModelClient::new(
|
||||
Arc::clone(&config),
|
||||
None,
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
SessionSource::Exec,
|
||||
)
|
||||
.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}];
|
||||
|
||||
let mut stream = client
|
||||
.stream(&prompt)
|
||||
.await
|
||||
.expect("websocket stream failed");
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let connection = server.single_connection();
|
||||
assert_eq!(connection.len(), 1);
|
||||
let request = connection.first().cloned().unwrap();
|
||||
let body = request.body_json();
|
||||
assert_eq!(body["model"].as_str(), Some(model.as_str()));
|
||||
assert_eq!(body["stream"], serde_json::Value::Bool(true));
|
||||
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
|
@ -102,7 +102,7 @@ pub enum Op {
|
|||
/// Policy to use for tool calls such as `local_shell`.
|
||||
sandbox_policy: SandboxPolicy,
|
||||
|
||||
/// Must be a valid model slug for the [`crate::client::ModelClient`]
|
||||
/// Must be a valid model slug for the configured client session
|
||||
/// associated with this conversation.
|
||||
model: String,
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue