diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 800ef1ce7..34ea8c877 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -843,6 +843,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "codex-api" +version = "0.0.0" +dependencies = [ + "anyhow", + "assert_matches", + "async-trait", + "bytes", + "codex-client", + "codex-protocol", + "eventsource-stream", + "futures", + "http", + "pretty_assertions", + "regex-lite", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", + "tokio-test", + "tokio-util", + "tracing", +] + [[package]] name = "codex-app-server" version = "0.0.0" @@ -1029,6 +1053,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "codex-client" +version = "0.0.0" +dependencies = [ + "async-trait", + "bytes", + "eventsource-stream", + "futures", + "http", + "rand 0.9.2", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", +] + [[package]] name = "codex-cloud-tasks" version = "0.0.0" @@ -1096,9 +1137,9 @@ dependencies = [ "async-channel", "async-trait", "base64", - "bytes", "chardetng", "chrono", + "codex-api", "codex-app-server-protocol", "codex-apply-patch", "codex-arg0", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 7d57b62ec..053d79d7e 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -41,6 +41,8 @@ members = [ "utils/pty", "utils/readiness", "utils/string", + "codex-client", + "codex-api", ] resolver = "2" @@ -63,6 +65,8 @@ codex-apply-patch = { path = "apply-patch" } codex-arg0 = { path = "arg0" } codex-async-utils = { path = "async-utils" } codex-backend-client = { path = "backend-client" } +codex-api = { path = "codex-api" } +codex-client = { path = "codex-client" } codex-chatgpt = { path = "chatgpt" } codex-common = { path = "common" } codex-core = { path = "core" } @@ -171,6 +175,7 @@ reqwest = "0.12" rmcp = { version = "0.9.0", default-features = false } schemars = "0.8.22" seccompiler = "0.5.0" +sentry = "0.34.0" serde = "1" serde_json = "1" serde_with = "3.14" diff --git a/codex-rs/app-server/tests/suite/send_message.rs b/codex-rs/app-server/tests/suite/send_message.rs index 8d2b36af2..39b3a31a8 100644 --- a/codex-rs/app-server/tests/suite/send_message.rs +++ b/codex-rs/app-server/tests/suite/send_message.rs @@ -272,40 +272,45 @@ async fn read_raw_response_item( mcp: &mut McpProcess, conversation_id: ConversationId, ) -> ResponseItem { - let raw_notification: JSONRPCNotification = timeout( - DEFAULT_READ_TIMEOUT, - mcp.read_stream_until_notification_message("codex/event/raw_response_item"), - ) - .await - .expect("codex/event/raw_response_item notification timeout") - .expect("codex/event/raw_response_item notification resp"); + loop { + let raw_notification: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("codex/event/raw_response_item"), + ) + .await + .expect("codex/event/raw_response_item notification timeout") + .expect("codex/event/raw_response_item notification resp"); - let serde_json::Value::Object(params) = raw_notification - .params - .expect("codex/event/raw_response_item should have params") - else { - panic!("codex/event/raw_response_item should have params"); - }; + let serde_json::Value::Object(params) = raw_notification + .params + .expect("codex/event/raw_response_item should have params") + else { + panic!("codex/event/raw_response_item should have params"); + }; - let conversation_id_value = params - .get("conversationId") - .and_then(|value| value.as_str()) - .expect("raw response item should include conversationId"); + let conversation_id_value = params + .get("conversationId") + .and_then(|value| value.as_str()) + .expect("raw response item should include conversationId"); - assert_eq!( - conversation_id_value, - conversation_id.to_string(), - "raw response item conversation mismatch" - ); + assert_eq!( + conversation_id_value, + conversation_id.to_string(), + "raw response item conversation mismatch" + ); - let msg_value = params - .get("msg") - .cloned() - .expect("raw response item should include msg payload"); + let msg_value = params + .get("msg") + .cloned() + .expect("raw response item should include msg payload"); - let event: RawResponseItemEvent = - serde_json::from_value(msg_value).expect("deserialize raw response item"); - event.item + // Ghost snapshots are produced concurrently and may arrive before the model reply. + let event: RawResponseItemEvent = + serde_json::from_value(msg_value).expect("deserialize raw response item"); + if !matches!(event.item, ResponseItem::GhostSnapshot { .. }) { + return event.item; + } + } } fn assert_instructions_message(item: &ResponseItem) { diff --git a/codex-rs/client.md b/codex-rs/client.md new file mode 100644 index 000000000..4c9027ccb --- /dev/null +++ b/codex-rs/client.md @@ -0,0 +1,206 @@ +# Client Extraction Plan + +## Goals +- Split the HTTP transport/client code out of `codex-core` into a reusable crate that is agnostic of Codex/OpenAI business logic and API schemas. +- Create a separate API library crate that houses typed requests/responses for well-known APIs (Responses, Chat Completions, Compact) and plugs into the transport crate via minimal traits. +- Preserve current behaviour (auth headers, retries, SSE handling, rate-limit parsing, compaction, fixtures) while making the APIs symmetric and avoiding code duplication. +- Keep existing consumers (`codex-core`, tests, and tools) stable by providing a small compatibility layer during the transition. + +## Snapshot of Today +- `core/src/client.rs (ModelClient)` owns config/auth/session state, chooses wire API, builds payloads, drives retries, parses SSE, compaction, and rate-limit headers. +- `core/src/chat_completions.rs` implements the Chat Completions call + SSE parser + aggregation helper. +- `core/src/client_common.rs` holds `Prompt`, tool specs, shared request structs (`ResponsesApiRequest`, `TextControls`), and `ResponseEvent`/`ResponseStream`. +- `core/src/default_client.rs` wraps `reqwest` with Codex UA/originator defaults. +- `core/src/model_provider_info.rs` models providers (base URL, headers, env keys, retry/timeout tuning) and builds `CodexRequestBuilder`s. + - Current retry logic is co-located with API handling; streaming SSE parsing is duplicated across Responses/Chat. + +## Target Crates (with interfaces) + +- `codex-client` (generic transport) + - Owns the generic HTTP machinery: a `CodexHttpClient`/`CodexRequestBuilder`-style wrapper, retry/backoff hooks, streaming connector (SSE framing + idle timeout), header injection, and optional telemetry callbacks. + - Does **not** know about OpenAI/Codex-specific paths, headers, or error codes; it only exposes HTTP-level concepts (status, headers, bodies, connection errors). + - Minimal surface: + ```rust + pub trait HttpTransport { + fn execute(&self, req: Request) -> Result; + fn stream(&self, req: Request) -> Result; + } + + pub struct Request { + pub method: Method, + pub url: String, + pub headers: HeaderMap, + pub body: Option, + pub timeout: Option, + } + ``` + - Generic client traits (request/response/chunk are abstract over the transport): + ```rust + #[async_trait::async_trait] + pub trait UnaryClient { + async fn run(&self, req: Req) -> Result; + } + + #[async_trait::async_trait] + pub trait StreamClient { + async fn run(&self, req: Req) -> Result, TransportError>; + } + + pub struct RetryPolicy { + pub max_attempts: u64, + pub base_delay: Duration, + pub retry_on: RetryOn, // e.g., transport errors + 429/5xx + } + ``` + - `RetryOn` lives in `codex-client` and captures HTTP status classes and transport failures that qualify for retry. + - Implementations in `codex-api` plug in their own request types, parsers, and retry policies while reusing the transport’s backoff and error types. + - Planned runtime helper: + ```rust + pub async fn run_with_retry( + policy: RetryPolicy, + make_req: impl Fn() -> Request, + op: F, + ) -> Result + where + F: Fn(Request) -> Fut, + Fut: Future>, + { + for attempt in 0..=policy.max_attempts { + let req = make_req(); + match op(req).await { + Ok(resp) => return Ok(resp), + Err(err) if policy.retry_on.should_retry(&err, attempt) => { + tokio::time::sleep(backoff(policy.base_delay, attempt + 1)).await; + } + Err(err) => return Err(err), + } + } + Err(TransportError::RetryLimit) + } + ``` + - Unary clients wrap `transport.execute` with this helper and then deserialize. + - Stream clients wrap the **initial** `transport.stream` call with this helper. Mid-stream disconnects are surfaced as `StreamError`s; automatic resume/reconnect can be added later on top of this primitive if we introduce cursor support. + - Common helpers: `retry::backoff(attempt)`, `errors::{TransportError, StreamError}`. + - Streaming utility (SSE framing only): + ```rust + pub fn sse_stream( + bytes: S, + idle_timeout: Duration, + tx: mpsc::Sender>, + telemetry: Option>, + ) + where + S: Stream> + Unpin + Send + 'static; + ``` + - `sse_stream` is responsible for timeouts, connection-level errors, and emitting raw `data:` chunks as UTF-8 strings; parsing those strings into structured events is done in `codex-api`. + +- `codex-api` (OpenAI/Codex API library) + - Owns typed models for Responses/Chat/Compact plus shared helpers (`Prompt`, tool specs, text controls, `ResponsesApiRequest`, etc.). + - Knows about OpenAI/Codex semantics: + - URL shapes (`/v1/responses`, `/v1/chat/completions`, `/responses/compact`). + - Provider configuration (`WireApi`, base URLs, query params, per-provider retry knobs). + - Rate-limit headers (`x-codex-*`) and their mapping into `RateLimitSnapshot` / `CreditsSnapshot`. + - Error body formats (`{ error: { type, code, message, plan_type, resets_at } }`) and how they become API errors (context window exceeded, quota/usage limit, etc.). + - SSE event names (`response.output_item.done`, `response.completed`, `response.failed`, etc.) and their mapping into high-level events. + - Provides a provider abstraction (conceptually similar to `ModelProviderInfo`): + ```rust + pub struct Provider { + pub name: String, + pub base_url: String, + pub wire: WireApi, // Responses | Chat + pub headers: HeaderMap, + pub retry: RetryConfig, + pub stream_idle_timeout: Duration, + } + + pub trait AuthProvider { + /// Returns a bearer token to use for this request (if any). + /// Implementations are expected to be cheap and to surface already-refreshed tokens; + /// higher layers (`codex-core`) remain responsible for token refresh flows. + fn bearer_token(&self) -> Option; + + /// Optional ChatGPT account id header for Chat mode. + fn account_id(&self) -> Option; + } + ``` + - Ready-made clients built on `HttpTransport`: + ```rust + pub struct ResponsesClient { /* ... */ } + impl ResponsesClient { + pub async fn stream(&self, prompt: &Prompt) -> ApiResult>; + pub async fn compact(&self, prompt: &Prompt) -> ApiResult>; + } + + pub struct ChatClient { /* ... */ } + impl ChatClient { + pub async fn stream(&self, prompt: &Prompt) -> ApiResult>; + } + + pub struct CompactClient { /* ... */ } + impl CompactClient { + pub async fn compact(&self, prompt: &Prompt) -> ApiResult>; + } + ``` + - Streaming events unified across wire APIs (this can closely mirror `ResponseEvent` today, and we may type-alias one to the other during migration): + ```rust + pub enum ApiEvent { + Created, + OutputItemAdded(ResponseItem), + OutputItemDone(ResponseItem), + OutputTextDelta(String), + ReasoningContentDelta { delta: String, content_index: i64 }, + ReasoningSummaryDelta { delta: String, summary_index: i64 }, + RateLimits(RateLimitSnapshot), + Completed { response_id: String, token_usage: Option }, + } + ``` + - Error layering: + - `codex-client`: defines `TransportError` / `StreamError` (status codes, IO, timeouts). + - `codex-api`: defines `ApiError` that wraps `TransportError` plus API-specific errors parsed from bodies and headers. + - `codex-core`: maps `ApiError` into existing `CodexErr` variants so downstream callers remain unchanged. + - Aggregation strategies (today’s `AggregateStreamExt`) live here as adapters (`Aggregated`, `Streaming`) that transform `ResponseStream` into the higher-level views used by `codex-core`. + +## Implementation Steps + +1. **Create crates**: add `codex-client` and `codex-api` (names keep the `codex-` prefix). Stub lib files with feature flags/tests wired into the workspace; wire them into `Cargo.toml`. +2. **Extract API-level SSE + rate limits into `codex-api`**: + - Move the Responses SSE parser (`process_sse`), rate-limit parsing, and related tests from `core/src/client.rs` into `codex-api`, keeping the behavior identical. + - Introduce `ApiEvent` (initially equivalent to `ResponseEvent`) and `ApiError`, and adjust the parser to emit those. + - Provide test-only helpers for fixture streams (replacement for `CODEX_RS_SSE_FIXTURE`) in `codex-api`. +3. **Lift transport layer into `codex-client`**: + - Move `CodexHttpClient`/`CodexRequestBuilder`, UA/originator plumbing, and backoff helpers from `core/src/default_client.rs` into `codex-client` (or a thin wrapper on top of it). + - Introduce `HttpTransport`, `Request`, `RetryPolicy`, `RetryOn`, and `run_with_retry` as described above. + - Keep sandbox/no-proxy toggles behind injected configuration so `codex-client` stays generic and does not depend on Codex-specific env vars. +4. **Model provider abstraction in `codex-api`**: + - Relocate `ModelProviderInfo` (base URL, env/header resolution, retry knobs, wire API enum) into `codex-api`, expressed in terms of `Provider` and `AuthProvider`. + - Ensure provider logic handles: + - URL building for Responses/Chat/Compact (including Azure special cases). + - Static and env-based headers. + - Per-provider retry and idle-timeout settings that map cleanly into `RetryPolicy`/`RetryOn`. +5. **API crate wiring**: + - Move `Prompt`, tool specs, `ResponsesApiRequest`, `TextControls`, and `ResponseEvent/ResponseStream` into `codex-api` under modules (`common`, `responses`, `chat`, `compact`), keeping public types stable or re-exported through `codex-core` as needed. + - Rebuild Responses and Chat clients on top of `HttpTransport` + `StreamClient`, reusing shared retry + SSE helpers; keep aggregation adapters as reusable strategies instead of `ModelClient`-local logic. + - Implement Compact on top of `UnaryClient` and the unary `execute` path with JSON deserialization, sharing the same retry policy. + - Keep request builders symmetric: each client prepares a `Request`, attaches headers/auth via `AuthProvider`, and plugs in its parser (streaming clients) or deserializer (unary) while sharing retry/backoff configuration derived from `Provider`. +6. **Core integration layer**: + - Replace `core::ModelClient` internals with thin adapters that construct `codex-api` clients using `Config`, `AuthManager`, and `OtelEventManager`. + - Keep the public `ModelClient` API and `ResponseEvent`/`ResponseStream` types stable by re-exporting `codex-api` types or providing type aliases. + - Preserve existing auth flows (including ChatGPT token refresh) inside `codex-core` or a thin adapter, using `AuthProvider` to surface bearer tokens to `codex-api` and handling 401/refresh semantics at this layer. +7. **Tests/migration**: + - Move unit tests for SSE parsing, retry/backoff decisions, and provider/header behavior into the new crates; keep integration tests in `core` using the compatibility layer. + - Update fixtures to be consumed via test-only adapters in `codex-api`. + - Run targeted `just fmt`, `just fix -p` for the touched crates, and scoped `cargo test -p codex-client`, `-p codex-api`, and existing `codex-core` suites. + +## Design Decisions + +- **UA construction** + - `codex-client` exposes an optional UA suffix/provider hook (tiny feature) and remains unaware of the CLI; `codex-core` / the CLI compute the full UA (including `terminal::user_agent()`) and pass the suffix or builder down. +- **Config vs provider** + - Most configuration stays in `codex-core`. `codex-api::Provider` only contains what is strictly required for HTTP (base URLs, query params, retry/timeout knobs, wire API), while higher-level knobs (reasoning defaults, verbosity flags, etc.) remain core concerns. +- **Auth flow ownership** + - Auth flows (including ChatGPT token refresh) remain in `codex-core`. `AuthProvider` simply exposes already-fresh tokens/account IDs; 401 handling and refresh retries stay in the existing auth layer. +- **Error enums** + - `codex-client` continues to define `TransportError` / `StreamError`. `codex-api` defines an `ApiError` (deriving `thiserror::Error`) that wraps `TransportError` and API-specific failures, and `codex-core` maps `ApiError` into existing `CodexErr` variants for callers. +- **Streaming reconnection semantics** + - For now, mid-stream SSE failures are surfaced as errors and only the initial connection is retried via `run_with_retry`. We will revisit mid-stream reconnect/resume once the underlying APIs support cursor/idempotent event semantics. + diff --git a/codex-rs/codex-api/Cargo.toml b/codex-rs/codex-api/Cargo.toml new file mode 100644 index 000000000..f79416c96 --- /dev/null +++ b/codex-rs/codex-api/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "codex-api" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +async-trait = { workspace = true } +bytes = { workspace = true } +codex-client = { workspace = true } +codex-protocol = { workspace = true } +futures = { workspace = true } +http = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] } +tracing = { workspace = true } +eventsource-stream = { workspace = true } +regex-lite = { workspace = true } +tokio-util = { workspace = true, features = ["codec"] } + +[dev-dependencies] +anyhow = { workspace = true } +assert_matches = { workspace = true } +pretty_assertions = { workspace = true } +tokio-test = { workspace = true } + +[lints] +workspace = true diff --git a/codex-rs/codex-api/README.md b/codex-rs/codex-api/README.md new file mode 100644 index 000000000..98db0bec6 --- /dev/null +++ b/codex-rs/codex-api/README.md @@ -0,0 +1,32 @@ +# codex-api + +Typed clients for Codex/OpenAI APIs built on top of the generic transport in `codex-client`. + +- Hosts the request/response models and prompt helpers for Responses, Chat Completions, and Compact APIs. +- Owns provider configuration (base URLs, headers, query params), auth header injection, retry tuning, and stream idle settings. +- Parses SSE streams into `ResponseEvent`/`ResponseStream`, including rate-limit snapshots and API-specific error mapping. +- Serves as the wire-level layer consumed by `codex-core`; higher layers handle auth refresh and business logic. + +## Core interface + +The public interface of this crate is intentionally small and uniform: + +- **Prompted endpoints (Chat + Responses)** + - Input: a single `Prompt` plus endpoint-specific options. + - `Prompt` (re-exported as `codex_api::Prompt`) carries: + - `instructions: String` – the fully-resolved system prompt for this turn. + - `input: Vec` – conversation history and user/tool messages. + - `tools: Vec` – JSON tools compatible with the target API. + - `parallel_tool_calls: bool`. + - `output_schema: Option` – used to build `text.format` when present. + - Output: a `ResponseStream` of `ResponseEvent` (both re-exported from `common`). + +- **Compaction endpoint** + - Input: `CompactionInput<'a>` (re-exported as `codex_api::CompactionInput`): + - `model: &str`. + - `input: &[ResponseItem]` – history to compact. + - `instructions: &str` – fully-resolved compaction instructions. + - Output: `Vec`. + - `CompactClient::compact_input(&CompactionInput, extra_headers)` wraps the JSON encoding and retry/telemetry wiring. + +All HTTP details (URLs, headers, retry/backoff policies, SSE framing) are encapsulated in `codex-api` and `codex-client`. Callers construct prompts/inputs using protocol types and work with typed streams of `ResponseEvent` or compacted `ResponseItem` values. diff --git a/codex-rs/codex-api/src/auth.rs b/codex-rs/codex-api/src/auth.rs new file mode 100644 index 000000000..6c26963cb --- /dev/null +++ b/codex-rs/codex-api/src/auth.rs @@ -0,0 +1,27 @@ +use codex_client::Request; + +/// Provides bearer and account identity information for API requests. +/// +/// Implementations should be cheap and non-blocking; any asynchronous +/// refresh or I/O should be handled by higher layers before requests +/// reach this interface. +pub trait AuthProvider: Send + Sync { + fn bearer_token(&self) -> Option; + fn account_id(&self) -> Option { + None + } +} + +pub(crate) fn add_auth_headers(auth: &A, mut req: Request) -> Request { + if let Some(token) = auth.bearer_token() + && let Ok(header) = format!("Bearer {token}").parse() + { + let _ = req.headers.insert(http::header::AUTHORIZATION, header); + } + if let Some(account_id) = auth.account_id() + && let Ok(header) = account_id.parse() + { + let _ = req.headers.insert("ChatGPT-Account-ID", header); + } + req +} diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs new file mode 100644 index 000000000..addab02dc --- /dev/null +++ b/codex-rs/codex-api/src/common.rs @@ -0,0 +1,167 @@ +use crate::error::ApiError; +use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; +use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; +use codex_protocol::config_types::Verbosity as VerbosityConfig; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::RateLimitSnapshot; +use codex_protocol::protocol::TokenUsage; +use futures::Stream; +use serde::Serialize; +use serde_json::Value; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; +use tokio::sync::mpsc; + +/// Canonical prompt input for Chat and Responses endpoints. +#[derive(Debug, Clone)] +pub struct Prompt { + /// Fully-resolved system instructions for this turn. + pub instructions: String, + /// Conversation history and user/tool messages. + pub input: Vec, + /// JSON-encoded tool definitions compatible with the target API. + // TODO(jif) have a proper type here + pub tools: Vec, + /// Whether parallel tool calls are permitted. + pub parallel_tool_calls: bool, + /// Optional output schema used to build the `text.format` controls. + pub output_schema: Option, +} + +/// Canonical input payload for the compaction endpoint. +#[derive(Debug, Clone, Serialize)] +pub struct CompactionInput<'a> { + pub model: &'a str, + pub input: &'a [ResponseItem], + pub instructions: &'a str, +} + +#[derive(Debug)] +pub enum ResponseEvent { + Created, + OutputItemDone(ResponseItem), + OutputItemAdded(ResponseItem), + Completed { + response_id: String, + token_usage: Option, + }, + OutputTextDelta(String), + ReasoningSummaryDelta { + delta: String, + summary_index: i64, + }, + ReasoningContentDelta { + delta: String, + content_index: i64, + }, + ReasoningSummaryPartAdded { + summary_index: i64, + }, + RateLimits(RateLimitSnapshot), +} + +#[derive(Debug, Serialize, Clone)] +pub struct Reasoning { + #[serde(skip_serializing_if = "Option::is_none")] + pub effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +#[derive(Debug, Serialize, Default, Clone)] +#[serde(rename_all = "snake_case")] +pub enum TextFormatType { + #[default] + JsonSchema, +} + +#[derive(Debug, Serialize, Default, Clone)] +pub struct TextFormat { + /// Format type used by the OpenAI text controls. + pub r#type: TextFormatType, + /// When true, the server is expected to strictly validate responses. + pub strict: bool, + /// JSON schema for the desired output. + pub schema: Value, + /// Friendly name for the format, used in telemetry/debugging. + pub name: String, +} + +/// Controls the `text` field for the Responses API, combining verbosity and +/// optional JSON schema output formatting. +#[derive(Debug, Serialize, Default, Clone)] +pub struct TextControls { + #[serde(skip_serializing_if = "Option::is_none")] + pub verbosity: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, +} + +#[derive(Debug, Serialize, Default, Clone)] +#[serde(rename_all = "lowercase")] +pub enum OpenAiVerbosity { + Low, + #[default] + Medium, + High, +} + +impl From for OpenAiVerbosity { + fn from(v: VerbosityConfig) -> Self { + match v { + VerbosityConfig::Low => OpenAiVerbosity::Low, + VerbosityConfig::Medium => OpenAiVerbosity::Medium, + VerbosityConfig::High => OpenAiVerbosity::High, + } + } +} + +#[derive(Debug, Serialize)] +pub struct ResponsesApiRequest<'a> { + pub model: &'a str, + pub instructions: &'a str, + pub input: &'a [ResponseItem], + pub tools: &'a [serde_json::Value], + pub tool_choice: &'static str, + pub parallel_tool_calls: bool, + pub reasoning: Option, + pub store: bool, + pub stream: bool, + pub include: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, +} + +pub fn create_text_param_for_request( + verbosity: Option, + output_schema: &Option, +) -> Option { + if verbosity.is_none() && output_schema.is_none() { + return None; + } + + Some(TextControls { + verbosity: verbosity.map(std::convert::Into::into), + format: output_schema.as_ref().map(|schema| TextFormat { + r#type: TextFormatType::JsonSchema, + strict: true, + schema: schema.clone(), + name: "codex_output_schema".to_string(), + }), + }) +} + +pub struct ResponseStream { + pub rx_event: mpsc::Receiver>, +} + +impl Stream for ResponseStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx_event.poll_recv(cx) + } +} diff --git a/codex-rs/codex-api/src/endpoint/chat.rs b/codex-rs/codex-api/src/endpoint/chat.rs new file mode 100644 index 000000000..4ad133dda --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/chat.rs @@ -0,0 +1,266 @@ +use crate::ChatRequest; +use crate::auth::AuthProvider; +use crate::common::Prompt as ApiPrompt; +use crate::common::ResponseEvent; +use crate::common::ResponseStream; +use crate::endpoint::streaming::StreamingClient; +use crate::error::ApiError; +use crate::provider::Provider; +use crate::provider::WireApi; +use crate::sse::chat::spawn_chat_stream; +use crate::telemetry::SseTelemetry; +use codex_client::HttpTransport; +use codex_client::RequestTelemetry; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ReasoningItemContent; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::SessionSource; +use futures::Stream; +use http::HeaderMap; +use serde_json::Value; +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; + +pub struct ChatClient { + streaming: StreamingClient, +} + +impl ChatClient { + pub fn new(transport: T, provider: Provider, auth: A) -> Self { + Self { + streaming: StreamingClient::new(transport, provider, auth), + } + } + + pub fn with_telemetry( + self, + request: Option>, + sse: Option>, + ) -> Self { + Self { + streaming: self.streaming.with_telemetry(request, sse), + } + } + + pub async fn stream_request(&self, request: ChatRequest) -> Result { + self.stream(request.body, request.headers).await + } + + pub async fn stream_prompt( + &self, + model: &str, + prompt: &ApiPrompt, + conversation_id: Option, + session_source: Option, + ) -> Result { + use crate::requests::ChatRequestBuilder; + + let request = + ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools) + .conversation_id(conversation_id) + .session_source(session_source) + .build(self.streaming.provider())?; + + self.stream_request(request).await + } + + fn path(&self) -> &'static str { + match self.streaming.provider().wire { + WireApi::Chat => "chat/completions", + _ => "responses", + } + } + + pub async fn stream( + &self, + body: Value, + extra_headers: HeaderMap, + ) -> Result { + self.streaming + .stream(self.path(), body, extra_headers, spawn_chat_stream) + .await + } +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AggregateMode { + AggregatedOnly, + Streaming, +} + +/// Stream adapter that merges token deltas into a single assistant message per turn. +pub struct AggregatedStream { + inner: ResponseStream, + cumulative: String, + cumulative_reasoning: String, + pending: VecDeque, + mode: AggregateMode, +} + +impl Stream for AggregatedStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let Some(ev) = this.pending.pop_front() { + return Poll::Ready(Some(Ok(ev))); + } + + loop { + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { + let is_assistant_message = matches!( + &item, + ResponseItem::Message { role, .. } if role == "assistant" + ); + + if is_assistant_message { + match this.mode { + AggregateMode::AggregatedOnly => { + if this.cumulative.is_empty() + && let ResponseItem::Message { content, .. } = &item + && let Some(text) = content.iter().find_map(|c| match c { + ContentItem::OutputText { text } => Some(text), + _ => None, + }) + { + this.cumulative.push_str(text); + } + continue; + } + AggregateMode::Streaming => { + if this.cumulative.is_empty() { + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( + item, + )))); + } else { + continue; + } + } + } + } + + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); + } + Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { + return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); + } + Poll::Ready(Some(Ok(ResponseEvent::Completed { + response_id, + token_usage, + }))) => { + let mut emitted_any = false; + + if !this.cumulative_reasoning.is_empty() { + let aggregated_reasoning = ResponseItem::Reasoning { + id: String::new(), + summary: Vec::new(), + content: Some(vec![ReasoningItemContent::ReasoningText { + text: std::mem::take(&mut this.cumulative_reasoning), + }]), + encrypted_content: None, + }; + this.pending + .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning)); + emitted_any = true; + } + + if !this.cumulative.is_empty() { + let aggregated_message = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: std::mem::take(&mut this.cumulative), + }], + }; + this.pending + .push_back(ResponseEvent::OutputItemDone(aggregated_message)); + emitted_any = true; + } + + if emitted_any { + this.pending.push_back(ResponseEvent::Completed { + response_id: response_id.clone(), + token_usage: token_usage.clone(), + }); + if let Some(ev) = this.pending.pop_front() { + return Poll::Ready(Some(Ok(ev))); + } + } + + return Poll::Ready(Some(Ok(ResponseEvent::Completed { + response_id, + token_usage, + }))); + } + Poll::Ready(Some(Ok(ResponseEvent::Created))) => { + continue; + } + Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => { + this.cumulative.push_str(&delta); + if matches!(this.mode, AggregateMode::Streaming) { + return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))); + } else { + continue; + } + } + Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta { + delta, + content_index, + }))) => { + this.cumulative_reasoning.push_str(&delta); + if matches!(this.mode, AggregateMode::Streaming) { + return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta { + delta, + content_index, + }))); + } else { + continue; + } + } + Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue, + Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => { + continue; + } + Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => { + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))); + } + } + } + } +} + +pub trait AggregateStreamExt { + fn aggregate(self) -> AggregatedStream; + + fn streaming_mode(self) -> ResponseStream; +} + +impl AggregateStreamExt for ResponseStream { + fn aggregate(self) -> AggregatedStream { + AggregatedStream::new(self, AggregateMode::AggregatedOnly) + } + + fn streaming_mode(self) -> ResponseStream { + self + } +} + +impl AggregatedStream { + fn new(inner: ResponseStream, mode: AggregateMode) -> Self { + AggregatedStream { + inner, + cumulative: String::new(), + cumulative_reasoning: String::new(), + pending: VecDeque::new(), + mode, + } + } +} diff --git a/codex-rs/codex-api/src/endpoint/compact.rs b/codex-rs/codex-api/src/endpoint/compact.rs new file mode 100644 index 000000000..2b02ebd0f --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/compact.rs @@ -0,0 +1,162 @@ +use crate::auth::AuthProvider; +use crate::auth::add_auth_headers; +use crate::common::CompactionInput; +use crate::error::ApiError; +use crate::provider::Provider; +use crate::provider::WireApi; +use crate::telemetry::run_with_request_telemetry; +use codex_client::HttpTransport; +use codex_client::RequestTelemetry; +use codex_protocol::models::ResponseItem; +use http::HeaderMap; +use http::Method; +use serde::Deserialize; +use serde_json::to_value; +use std::sync::Arc; + +pub struct CompactClient { + transport: T, + provider: Provider, + auth: A, + request_telemetry: Option>, +} + +impl CompactClient { + pub fn new(transport: T, provider: Provider, auth: A) -> Self { + Self { + transport, + provider, + auth, + request_telemetry: None, + } + } + + pub fn with_telemetry(mut self, request: Option>) -> Self { + self.request_telemetry = request; + self + } + + fn path(&self) -> Result<&'static str, ApiError> { + match self.provider.wire { + WireApi::Compact | WireApi::Responses => Ok("responses/compact"), + WireApi::Chat => Err(ApiError::Stream( + "compact endpoint requires responses wire api".to_string(), + )), + } + } + + pub async fn compact( + &self, + body: serde_json::Value, + extra_headers: HeaderMap, + ) -> Result, ApiError> { + let path = self.path()?; + let builder = || { + let mut req = self.provider.build_request(Method::POST, path); + req.headers.extend(extra_headers.clone()); + req.body = Some(body.clone()); + add_auth_headers(&self.auth, req) + }; + + let resp = run_with_request_telemetry( + self.provider.retry.to_policy(), + self.request_telemetry.clone(), + builder, + |req| self.transport.execute(req), + ) + .await?; + let parsed: CompactHistoryResponse = + serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?; + Ok(parsed.output) + } + + pub async fn compact_input( + &self, + input: &CompactionInput<'_>, + extra_headers: HeaderMap, + ) -> Result, ApiError> { + let body = to_value(input) + .map_err(|e| ApiError::Stream(format!("failed to encode compaction input: {e}")))?; + self.compact(body, extra_headers).await + } +} + +#[derive(Debug, Deserialize)] +struct CompactHistoryResponse { + output: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::RetryConfig; + use async_trait::async_trait; + use codex_client::Request; + use codex_client::Response; + use codex_client::StreamResponse; + use codex_client::TransportError; + use http::HeaderMap; + use std::time::Duration; + + #[derive(Clone, Default)] + struct DummyTransport; + + #[async_trait] + impl HttpTransport for DummyTransport { + async fn execute(&self, _req: Request) -> Result { + Err(TransportError::Build("execute should not run".to_string())) + } + + async fn stream(&self, _req: Request) -> Result { + Err(TransportError::Build("stream should not run".to_string())) + } + } + + #[derive(Clone, Default)] + struct DummyAuth; + + impl AuthProvider for DummyAuth { + fn bearer_token(&self) -> Option { + None + } + } + + fn provider(wire: WireApi) -> Provider { + Provider { + name: "test".to_string(), + base_url: "https://example.com/v1".to_string(), + query_params: None, + wire, + headers: HeaderMap::new(), + retry: RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: true, + retry_transport: true, + }, + stream_idle_timeout: Duration::from_secs(1), + } + } + + #[tokio::test] + async fn errors_when_wire_is_chat() { + let client = CompactClient::new(DummyTransport, provider(WireApi::Chat), DummyAuth); + let input = CompactionInput { + model: "gpt-test", + input: &[], + instructions: "inst", + }; + let err = client + .compact_input(&input, HeaderMap::new()) + .await + .expect_err("expected wire mismatch to fail"); + + match err { + ApiError::Stream(msg) => { + assert_eq!(msg, "compact endpoint requires responses wire api"); + } + other => panic!("unexpected error: {other:?}"), + } + } +} diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs new file mode 100644 index 000000000..104b4c264 --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -0,0 +1,4 @@ +pub mod chat; +pub mod compact; +pub mod responses; +mod streaming; diff --git a/codex-rs/codex-api/src/endpoint/responses.rs b/codex-rs/codex-api/src/endpoint/responses.rs new file mode 100644 index 000000000..d3a314d76 --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/responses.rs @@ -0,0 +1,107 @@ +use crate::auth::AuthProvider; +use crate::common::Prompt as ApiPrompt; +use crate::common::Reasoning; +use crate::common::ResponseStream; +use crate::common::TextControls; +use crate::endpoint::streaming::StreamingClient; +use crate::error::ApiError; +use crate::provider::Provider; +use crate::provider::WireApi; +use crate::requests::ResponsesRequest; +use crate::requests::ResponsesRequestBuilder; +use crate::sse::spawn_response_stream; +use crate::telemetry::SseTelemetry; +use codex_client::HttpTransport; +use codex_client::RequestTelemetry; +use codex_protocol::protocol::SessionSource; +use http::HeaderMap; +use serde_json::Value; +use std::sync::Arc; + +pub struct ResponsesClient { + streaming: StreamingClient, +} + +#[derive(Default)] +pub struct ResponsesOptions { + pub reasoning: Option, + pub include: Vec, + pub prompt_cache_key: Option, + pub text: Option, + pub store_override: Option, + pub conversation_id: Option, + pub session_source: Option, +} + +impl ResponsesClient { + pub fn new(transport: T, provider: Provider, auth: A) -> Self { + Self { + streaming: StreamingClient::new(transport, provider, auth), + } + } + + pub fn with_telemetry( + self, + request: Option>, + sse: Option>, + ) -> Self { + Self { + streaming: self.streaming.with_telemetry(request, sse), + } + } + + pub async fn stream_request( + &self, + request: ResponsesRequest, + ) -> Result { + self.stream(request.body, request.headers).await + } + + pub async fn stream_prompt( + &self, + model: &str, + prompt: &ApiPrompt, + options: ResponsesOptions, + ) -> Result { + let ResponsesOptions { + reasoning, + include, + prompt_cache_key, + text, + store_override, + conversation_id, + session_source, + } = options; + + let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input) + .tools(&prompt.tools) + .parallel_tool_calls(prompt.parallel_tool_calls) + .reasoning(reasoning) + .include(include) + .prompt_cache_key(prompt_cache_key) + .text(text) + .conversation(conversation_id) + .session_source(session_source) + .store_override(store_override) + .build(self.streaming.provider())?; + + self.stream_request(request).await + } + + fn path(&self) -> &'static str { + match self.streaming.provider().wire { + WireApi::Responses | WireApi::Compact => "responses", + WireApi::Chat => "chat/completions", + } + } + + pub async fn stream( + &self, + body: Value, + extra_headers: HeaderMap, + ) -> Result { + self.streaming + .stream(self.path(), body, extra_headers, spawn_response_stream) + .await + } +} diff --git a/codex-rs/codex-api/src/endpoint/streaming.rs b/codex-rs/codex-api/src/endpoint/streaming.rs new file mode 100644 index 000000000..156d4084b --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/streaming.rs @@ -0,0 +1,82 @@ +use crate::auth::AuthProvider; +use crate::auth::add_auth_headers; +use crate::common::ResponseStream; +use crate::error::ApiError; +use crate::provider::Provider; +use crate::telemetry::SseTelemetry; +use crate::telemetry::run_with_request_telemetry; +use codex_client::HttpTransport; +use codex_client::RequestTelemetry; +use codex_client::StreamResponse; +use http::HeaderMap; +use http::Method; +use serde_json::Value; +use std::sync::Arc; +use std::time::Duration; + +pub(crate) struct StreamingClient { + transport: T, + provider: Provider, + auth: A, + request_telemetry: Option>, + sse_telemetry: Option>, +} + +impl StreamingClient { + pub(crate) fn new(transport: T, provider: Provider, auth: A) -> Self { + Self { + transport, + provider, + auth, + request_telemetry: None, + sse_telemetry: None, + } + } + + pub(crate) fn with_telemetry( + mut self, + request: Option>, + sse: Option>, + ) -> Self { + self.request_telemetry = request; + self.sse_telemetry = sse; + self + } + + pub(crate) fn provider(&self) -> &Provider { + &self.provider + } + + pub(crate) async fn stream( + &self, + path: &str, + body: Value, + extra_headers: HeaderMap, + spawner: fn(StreamResponse, Duration, Option>) -> ResponseStream, + ) -> Result { + let builder = || { + let mut req = self.provider.build_request(Method::POST, path); + req.headers.extend(extra_headers.clone()); + req.headers.insert( + http::header::ACCEPT, + http::HeaderValue::from_static("text/event-stream"), + ); + req.body = Some(body.clone()); + add_auth_headers(&self.auth, req) + }; + + let stream_response = run_with_request_telemetry( + self.provider.retry.to_policy(), + self.request_telemetry.clone(), + builder, + |req| self.transport.stream(req), + ) + .await?; + + Ok(spawner( + stream_response, + self.provider.stream_idle_timeout, + self.sse_telemetry.clone(), + )) + } +} diff --git a/codex-rs/codex-api/src/error.rs b/codex-rs/codex-api/src/error.rs new file mode 100644 index 000000000..60118e872 --- /dev/null +++ b/codex-rs/codex-api/src/error.rs @@ -0,0 +1,34 @@ +use crate::rate_limits::RateLimitError; +use codex_client::TransportError; +use http::StatusCode; +use std::time::Duration; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ApiError { + #[error(transparent)] + Transport(#[from] TransportError), + #[error("api error {status}: {message}")] + Api { status: StatusCode, message: String }, + #[error("stream error: {0}")] + Stream(String), + #[error("context window exceeded")] + ContextWindowExceeded, + #[error("quota exceeded")] + QuotaExceeded, + #[error("usage not included")] + UsageNotIncluded, + #[error("retryable error: {message}")] + Retryable { + message: String, + delay: Option, + }, + #[error("rate limit: {0}")] + RateLimit(String), +} + +impl From for ApiError { + fn from(err: RateLimitError) -> Self { + Self::RateLimit(err.to_string()) + } +} diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs new file mode 100644 index 000000000..acde4b458 --- /dev/null +++ b/codex-rs/codex-api/src/lib.rs @@ -0,0 +1,35 @@ +pub mod auth; +pub mod common; +pub mod endpoint; +pub mod error; +pub mod provider; +pub mod rate_limits; +pub mod requests; +pub mod sse; +pub mod telemetry; + +pub use codex_client::RequestTelemetry; +pub use codex_client::ReqwestTransport; +pub use codex_client::TransportError; + +pub use crate::auth::AuthProvider; +pub use crate::common::CompactionInput; +pub use crate::common::Prompt; +pub use crate::common::ResponseEvent; +pub use crate::common::ResponseStream; +pub use crate::common::ResponsesApiRequest; +pub use crate::common::create_text_param_for_request; +pub use crate::endpoint::chat::AggregateStreamExt; +pub use crate::endpoint::chat::ChatClient; +pub use crate::endpoint::compact::CompactClient; +pub use crate::endpoint::responses::ResponsesClient; +pub use crate::endpoint::responses::ResponsesOptions; +pub use crate::error::ApiError; +pub use crate::provider::Provider; +pub use crate::provider::WireApi; +pub use crate::requests::ChatRequest; +pub use crate::requests::ChatRequestBuilder; +pub use crate::requests::ResponsesRequest; +pub use crate::requests::ResponsesRequestBuilder; +pub use crate::sse::stream_from_fixture; +pub use crate::telemetry::SseTelemetry; diff --git a/codex-rs/codex-api/src/provider.rs b/codex-rs/codex-api/src/provider.rs new file mode 100644 index 000000000..8bd5fc909 --- /dev/null +++ b/codex-rs/codex-api/src/provider.rs @@ -0,0 +1,118 @@ +use codex_client::Request; +use codex_client::RetryOn; +use codex_client::RetryPolicy; +use http::Method; +use http::header::HeaderMap; +use std::collections::HashMap; +use std::time::Duration; + +/// Wire-level APIs supported by a `Provider`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WireApi { + Responses, + Chat, + Compact, +} + +/// High-level retry configuration for a provider. +/// +/// This is converted into a `RetryPolicy` used by `codex-client` to drive +/// transport-level retries for both unary and streaming calls. +#[derive(Debug, Clone)] +pub struct RetryConfig { + pub max_attempts: u64, + pub base_delay: Duration, + pub retry_429: bool, + pub retry_5xx: bool, + pub retry_transport: bool, +} + +impl RetryConfig { + pub fn to_policy(&self) -> RetryPolicy { + RetryPolicy { + max_attempts: self.max_attempts, + base_delay: self.base_delay, + retry_on: RetryOn { + retry_429: self.retry_429, + retry_5xx: self.retry_5xx, + retry_transport: self.retry_transport, + }, + } + } +} + +/// HTTP endpoint configuration used to talk to a concrete API deployment. +/// +/// Encapsulates base URL, default headers, query params, retry policy, and +/// stream idle timeout, plus helper methods for building requests. +#[derive(Debug, Clone)] +pub struct Provider { + pub name: String, + pub base_url: String, + pub query_params: Option>, + pub wire: WireApi, + pub headers: HeaderMap, + pub retry: RetryConfig, + pub stream_idle_timeout: Duration, +} + +impl Provider { + pub fn url_for_path(&self, path: &str) -> String { + let base = self.base_url.trim_end_matches('/'); + let path = path.trim_start_matches('/'); + let mut url = if path.is_empty() { + base.to_string() + } else { + format!("{base}/{path}") + }; + + if let Some(params) = &self.query_params + && !params.is_empty() + { + let qs = params + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join("&"); + url.push('?'); + url.push_str(&qs); + } + + url + } + + pub fn build_request(&self, method: Method, path: &str) -> Request { + Request { + method, + url: self.url_for_path(path), + headers: self.headers.clone(), + body: None, + timeout: None, + } + } + + pub fn is_azure_responses_endpoint(&self) -> bool { + if self.wire != WireApi::Responses { + return false; + } + + if self.name.eq_ignore_ascii_case("azure") { + return true; + } + + self.base_url.to_ascii_lowercase().contains("openai.azure.") + || matches_azure_responses_base_url(&self.base_url) + } +} + +fn matches_azure_responses_base_url(base_url: &str) -> bool { + const AZURE_MARKERS: [&str; 5] = [ + "cognitiveservices.azure.", + "aoai.azure.", + "azure-api.", + "azurefd.", + "windows.net/openai", + ]; + let base = base_url.to_ascii_lowercase(); + AZURE_MARKERS.iter().any(|marker| base.contains(marker)) +} diff --git a/codex-rs/codex-api/src/rate_limits.rs b/codex-rs/codex-api/src/rate_limits.rs new file mode 100644 index 000000000..69092063f --- /dev/null +++ b/codex-rs/codex-api/src/rate_limits.rs @@ -0,0 +1,105 @@ +use codex_protocol::protocol::CreditsSnapshot; +use codex_protocol::protocol::RateLimitSnapshot; +use codex_protocol::protocol::RateLimitWindow; +use http::HeaderMap; +use std::fmt::Display; + +#[derive(Debug)] +pub struct RateLimitError { + pub message: String, +} + +impl Display for RateLimitError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +/// Parses the bespoke Codex rate-limit headers into a `RateLimitSnapshot`. +pub fn parse_rate_limit(headers: &HeaderMap) -> Option { + let primary = parse_rate_limit_window( + headers, + "x-codex-primary-used-percent", + "x-codex-primary-window-minutes", + "x-codex-primary-reset-at", + ); + + let secondary = parse_rate_limit_window( + headers, + "x-codex-secondary-used-percent", + "x-codex-secondary-window-minutes", + "x-codex-secondary-reset-at", + ); + + let credits = parse_credits_snapshot(headers); + + Some(RateLimitSnapshot { + primary, + secondary, + credits, + }) +} + +fn parse_rate_limit_window( + headers: &HeaderMap, + used_percent_header: &str, + window_minutes_header: &str, + resets_at_header: &str, +) -> Option { + let used_percent: Option = parse_header_f64(headers, used_percent_header); + + used_percent.and_then(|used_percent| { + let window_minutes = parse_header_i64(headers, window_minutes_header); + let resets_at = parse_header_i64(headers, resets_at_header); + + let has_data = used_percent != 0.0 + || window_minutes.is_some_and(|minutes| minutes != 0) + || resets_at.is_some(); + + has_data.then_some(RateLimitWindow { + used_percent, + window_minutes, + resets_at, + }) + }) +} + +fn parse_credits_snapshot(headers: &HeaderMap) -> Option { + let has_credits = parse_header_bool(headers, "x-codex-credits-has-credits")?; + let unlimited = parse_header_bool(headers, "x-codex-credits-unlimited")?; + let balance = parse_header_str(headers, "x-codex-credits-balance") + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(std::string::ToString::to_string); + Some(CreditsSnapshot { + has_credits, + unlimited, + balance, + }) +} + +fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option { + parse_header_str(headers, name)? + .parse::() + .ok() + .filter(|v| v.is_finite()) +} + +fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option { + parse_header_str(headers, name)?.parse::().ok() +} + +fn parse_header_bool(headers: &HeaderMap, name: &str) -> Option { + let raw = parse_header_str(headers, name)?; + if raw.eq_ignore_ascii_case("true") || raw == "1" { + Some(true) + } else if raw.eq_ignore_ascii_case("false") || raw == "0" { + Some(false) + } else { + None + } +} + +fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { + headers.get(name)?.to_str().ok() +} diff --git a/codex-rs/codex-api/src/requests/chat.rs b/codex-rs/codex-api/src/requests/chat.rs new file mode 100644 index 000000000..c1ba89f37 --- /dev/null +++ b/codex-rs/codex-api/src/requests/chat.rs @@ -0,0 +1,388 @@ +use crate::error::ApiError; +use crate::provider::Provider; +use crate::requests::headers::build_conversation_headers; +use crate::requests::headers::insert_header; +use crate::requests::headers::subagent_header; +use codex_protocol::models::ContentItem; +use codex_protocol::models::FunctionCallOutputContentItem; +use codex_protocol::models::ReasoningItemContent; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::SessionSource; +use http::HeaderMap; +use serde_json::Value; +use serde_json::json; +use std::collections::HashMap; + +/// Assembled request body plus headers for Chat Completions streaming calls. +pub struct ChatRequest { + pub body: Value, + pub headers: HeaderMap, +} + +pub struct ChatRequestBuilder<'a> { + model: &'a str, + instructions: &'a str, + input: &'a [ResponseItem], + tools: &'a [Value], + conversation_id: Option, + session_source: Option, +} + +impl<'a> ChatRequestBuilder<'a> { + pub fn new( + model: &'a str, + instructions: &'a str, + input: &'a [ResponseItem], + tools: &'a [Value], + ) -> Self { + Self { + model, + instructions, + input, + tools, + conversation_id: None, + session_source: None, + } + } + + pub fn conversation_id(mut self, id: Option) -> Self { + self.conversation_id = id; + self + } + + pub fn session_source(mut self, source: Option) -> Self { + self.session_source = source; + self + } + + pub fn build(self, _provider: &Provider) -> Result { + let mut messages = Vec::::new(); + messages.push(json!({"role": "system", "content": self.instructions})); + + let input = self.input; + let mut reasoning_by_anchor_index: HashMap = HashMap::new(); + let mut last_emitted_role: Option<&str> = None; + for item in input { + match item { + ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()), + ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => { + last_emitted_role = Some("assistant") + } + ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"), + ResponseItem::Reasoning { .. } | ResponseItem::Other => {} + ResponseItem::CustomToolCall { .. } => {} + ResponseItem::CustomToolCallOutput { .. } => {} + ResponseItem::WebSearchCall { .. } => {} + ResponseItem::GhostSnapshot { .. } => {} + ResponseItem::CompactionSummary { .. } => {} + } + } + + let mut last_user_index: Option = None; + for (idx, item) in input.iter().enumerate() { + if let ResponseItem::Message { role, .. } = item + && role == "user" + { + last_user_index = Some(idx); + } + } + + if !matches!(last_emitted_role, Some("user")) { + for (idx, item) in input.iter().enumerate() { + if let Some(u_idx) = last_user_index + && idx <= u_idx + { + continue; + } + + if let ResponseItem::Reasoning { + content: Some(items), + .. + } = item + { + let mut text = String::new(); + for entry in items { + match entry { + ReasoningItemContent::ReasoningText { text: segment } + | ReasoningItemContent::Text { text: segment } => { + text.push_str(segment) + } + } + } + if text.trim().is_empty() { + continue; + } + + let mut attached = false; + if idx > 0 + && let ResponseItem::Message { role, .. } = &input[idx - 1] + && role == "assistant" + { + reasoning_by_anchor_index + .entry(idx - 1) + .and_modify(|v| v.push_str(&text)) + .or_insert(text.clone()); + attached = true; + } + + if !attached && idx + 1 < input.len() { + match &input[idx + 1] { + ResponseItem::FunctionCall { .. } + | ResponseItem::LocalShellCall { .. } => { + reasoning_by_anchor_index + .entry(idx + 1) + .and_modify(|v| v.push_str(&text)) + .or_insert(text.clone()); + } + ResponseItem::Message { role, .. } if role == "assistant" => { + reasoning_by_anchor_index + .entry(idx + 1) + .and_modify(|v| v.push_str(&text)) + .or_insert(text.clone()); + } + _ => {} + } + } + } + } + } + + let mut last_assistant_text: Option = None; + + for (idx, item) in input.iter().enumerate() { + match item { + ResponseItem::Message { role, content, .. } => { + let mut text = String::new(); + let mut items: Vec = Vec::new(); + let mut saw_image = false; + + for c in content { + match c { + ContentItem::InputText { text: t } + | ContentItem::OutputText { text: t } => { + text.push_str(t); + items.push(json!({"type":"text","text": t})); + } + ContentItem::InputImage { image_url } => { + saw_image = true; + items.push( + json!({"type":"image_url","image_url": {"url": image_url}}), + ); + } + } + } + + if role == "assistant" { + if let Some(prev) = &last_assistant_text + && prev == &text + { + continue; + } + last_assistant_text = Some(text.clone()); + } + + let content_value = if role == "assistant" { + json!(text) + } else if saw_image { + json!(items) + } else { + json!(text) + }; + + let mut msg = json!({"role": role, "content": content_value}); + if role == "assistant" + && let Some(reasoning) = reasoning_by_anchor_index.get(&idx) + && let Some(obj) = msg.as_object_mut() + { + obj.insert("reasoning".to_string(), json!(reasoning)); + } + messages.push(msg); + } + ResponseItem::FunctionCall { + name, + arguments, + call_id, + .. + } => { + let mut msg = json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": arguments, + } + }] + }); + if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) + && let Some(obj) = msg.as_object_mut() + { + obj.insert("reasoning".to_string(), json!(reasoning)); + } + messages.push(msg); + } + ResponseItem::LocalShellCall { + id, + call_id: _, + status, + action, + } => { + let mut msg = json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": id.clone().unwrap_or_default(), + "type": "local_shell_call", + "status": status, + "action": action, + }] + }); + if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) + && let Some(obj) = msg.as_object_mut() + { + obj.insert("reasoning".to_string(), json!(reasoning)); + } + messages.push(msg); + } + ResponseItem::FunctionCallOutput { call_id, output } => { + let content_value = if let Some(items) = &output.content_items { + let mapped: Vec = items + .iter() + .map(|it| match it { + FunctionCallOutputContentItem::InputText { text } => { + json!({"type":"text","text": text}) + } + FunctionCallOutputContentItem::InputImage { image_url } => { + json!({"type":"image_url","image_url": {"url": image_url}}) + } + }) + .collect(); + json!(mapped) + } else { + json!(output.content) + }; + + messages.push(json!({ + "role": "tool", + "tool_call_id": call_id, + "content": content_value, + })); + } + ResponseItem::CustomToolCall { + id, + call_id: _, + name, + input, + status: _, + } => { + messages.push(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": id, + "type": "custom", + "custom": { + "name": name, + "input": input, + } + }] + })); + } + ResponseItem::CustomToolCallOutput { call_id, output } => { + messages.push(json!({ + "role": "tool", + "tool_call_id": call_id, + "content": output, + })); + } + ResponseItem::GhostSnapshot { .. } => { + continue; + } + ResponseItem::Reasoning { .. } + | ResponseItem::WebSearchCall { .. } + | ResponseItem::Other + | ResponseItem::CompactionSummary { .. } => { + continue; + } + } + } + + let payload = json!({ + "model": self.model, + "messages": messages, + "stream": true, + "tools": self.tools, + }); + + let mut headers = build_conversation_headers(self.conversation_id); + if let Some(subagent) = subagent_header(&self.session_source) { + insert_header(&mut headers, "x-openai-subagent", &subagent); + } + + Ok(ChatRequest { + body: payload, + headers, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::RetryConfig; + use crate::provider::WireApi; + use codex_protocol::protocol::SessionSource; + use codex_protocol::protocol::SubAgentSource; + use http::HeaderValue; + use pretty_assertions::assert_eq; + use std::time::Duration; + + fn provider() -> Provider { + Provider { + name: "openai".to_string(), + base_url: "https://api.openai.com/v1".to_string(), + query_params: None, + wire: WireApi::Chat, + headers: HeaderMap::new(), + retry: RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(10), + retry_429: false, + retry_5xx: true, + retry_transport: true, + }, + stream_idle_timeout: Duration::from_secs(1), + } + } + + #[test] + fn attaches_conversation_and_subagent_headers() { + let prompt_input = vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "hi".to_string(), + }], + }]; + let req = ChatRequestBuilder::new("gpt-test", "inst", &prompt_input, &[]) + .conversation_id(Some("conv-1".into())) + .session_source(Some(SessionSource::SubAgent(SubAgentSource::Review))) + .build(&provider()) + .expect("request"); + + assert_eq!( + req.headers.get("conversation_id"), + Some(&HeaderValue::from_static("conv-1")) + ); + assert_eq!( + req.headers.get("session_id"), + Some(&HeaderValue::from_static("conv-1")) + ); + assert_eq!( + req.headers.get("x-openai-subagent"), + Some(&HeaderValue::from_static("review")) + ); + } +} diff --git a/codex-rs/codex-api/src/requests/headers.rs b/codex-rs/codex-api/src/requests/headers.rs new file mode 100644 index 000000000..4d8a17d18 --- /dev/null +++ b/codex-rs/codex-api/src/requests/headers.rs @@ -0,0 +1,36 @@ +use codex_protocol::protocol::SessionSource; +use http::HeaderMap; +use http::HeaderValue; + +pub(crate) fn build_conversation_headers(conversation_id: Option) -> HeaderMap { + let mut headers = HeaderMap::new(); + if let Some(id) = conversation_id { + insert_header(&mut headers, "conversation_id", &id); + insert_header(&mut headers, "session_id", &id); + } + headers +} + +pub(crate) fn subagent_header(source: &Option) -> Option { + let SessionSource::SubAgent(sub) = source.as_ref()? else { + return None; + }; + match sub { + codex_protocol::protocol::SubAgentSource::Other(label) => Some(label.clone()), + other => Some( + serde_json::to_value(other) + .ok() + .and_then(|v| v.as_str().map(std::string::ToString::to_string)) + .unwrap_or_else(|| "other".to_string()), + ), + } +} + +pub(crate) fn insert_header(headers: &mut HeaderMap, name: &str, value: &str) { + if let (Ok(header_name), Ok(header_value)) = ( + name.parse::(), + HeaderValue::from_str(value), + ) { + headers.insert(header_name, header_value); + } +} diff --git a/codex-rs/codex-api/src/requests/mod.rs b/codex-rs/codex-api/src/requests/mod.rs new file mode 100644 index 000000000..f0ab23a25 --- /dev/null +++ b/codex-rs/codex-api/src/requests/mod.rs @@ -0,0 +1,8 @@ +pub mod chat; +pub(crate) mod headers; +pub mod responses; + +pub use chat::ChatRequest; +pub use chat::ChatRequestBuilder; +pub use responses::ResponsesRequest; +pub use responses::ResponsesRequestBuilder; diff --git a/codex-rs/codex-api/src/requests/responses.rs b/codex-rs/codex-api/src/requests/responses.rs new file mode 100644 index 000000000..543b79bbe --- /dev/null +++ b/codex-rs/codex-api/src/requests/responses.rs @@ -0,0 +1,247 @@ +use crate::common::Reasoning; +use crate::common::ResponsesApiRequest; +use crate::common::TextControls; +use crate::error::ApiError; +use crate::provider::Provider; +use crate::requests::headers::build_conversation_headers; +use crate::requests::headers::insert_header; +use crate::requests::headers::subagent_header; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::SessionSource; +use http::HeaderMap; +use serde_json::Value; + +/// Assembled request body plus headers for a Responses stream request. +pub struct ResponsesRequest { + pub body: Value, + pub headers: HeaderMap, +} + +#[derive(Default)] +pub struct ResponsesRequestBuilder<'a> { + model: Option<&'a str>, + instructions: Option<&'a str>, + input: Option<&'a [ResponseItem]>, + tools: Option<&'a [Value]>, + parallel_tool_calls: bool, + reasoning: Option, + include: Vec, + prompt_cache_key: Option, + text: Option, + conversation_id: Option, + session_source: Option, + store_override: Option, + headers: HeaderMap, +} + +impl<'a> ResponsesRequestBuilder<'a> { + pub fn new(model: &'a str, instructions: &'a str, input: &'a [ResponseItem]) -> Self { + Self { + model: Some(model), + instructions: Some(instructions), + input: Some(input), + ..Default::default() + } + } + + pub fn tools(mut self, tools: &'a [Value]) -> Self { + self.tools = Some(tools); + self + } + + pub fn parallel_tool_calls(mut self, enabled: bool) -> Self { + self.parallel_tool_calls = enabled; + self + } + + pub fn reasoning(mut self, reasoning: Option) -> Self { + self.reasoning = reasoning; + self + } + + pub fn include(mut self, include: Vec) -> Self { + self.include = include; + self + } + + pub fn prompt_cache_key(mut self, key: Option) -> Self { + self.prompt_cache_key = key; + self + } + + pub fn text(mut self, text: Option) -> Self { + self.text = text; + self + } + + pub fn conversation(mut self, conversation_id: Option) -> Self { + self.conversation_id = conversation_id; + self + } + + pub fn session_source(mut self, source: Option) -> Self { + self.session_source = source; + self + } + + pub fn store_override(mut self, store: Option) -> Self { + self.store_override = store; + self + } + + pub fn extra_headers(mut self, headers: HeaderMap) -> Self { + self.headers = headers; + self + } + + pub fn build(self, provider: &Provider) -> Result { + let model = self + .model + .ok_or_else(|| ApiError::Stream("missing model for responses request".into()))?; + let instructions = self + .instructions + .ok_or_else(|| ApiError::Stream("missing instructions for responses request".into()))?; + let input = self + .input + .ok_or_else(|| ApiError::Stream("missing input for responses request".into()))?; + let tools = self.tools.unwrap_or_default(); + + let store = self + .store_override + .unwrap_or_else(|| provider.is_azure_responses_endpoint()); + + let req = ResponsesApiRequest { + model, + instructions, + input, + tools, + tool_choice: "auto", + parallel_tool_calls: self.parallel_tool_calls, + reasoning: self.reasoning, + store, + stream: true, + include: self.include, + prompt_cache_key: self.prompt_cache_key, + text: self.text, + }; + + let mut body = serde_json::to_value(&req) + .map_err(|e| ApiError::Stream(format!("failed to encode responses request: {e}")))?; + + if store && provider.is_azure_responses_endpoint() { + attach_item_ids(&mut body, input); + } + + let mut headers = self.headers; + headers.extend(build_conversation_headers(self.conversation_id)); + if let Some(subagent) = subagent_header(&self.session_source) { + insert_header(&mut headers, "x-openai-subagent", &subagent); + } + + Ok(ResponsesRequest { body, headers }) + } +} + +fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) { + let Some(input_value) = payload_json.get_mut("input") else { + return; + }; + let Value::Array(items) = input_value else { + return; + }; + + for (value, item) in items.iter_mut().zip(original_items.iter()) { + if let ResponseItem::Reasoning { id, .. } + | ResponseItem::Message { id: Some(id), .. } + | ResponseItem::WebSearchCall { id: Some(id), .. } + | ResponseItem::FunctionCall { id: Some(id), .. } + | ResponseItem::LocalShellCall { id: Some(id), .. } + | ResponseItem::CustomToolCall { id: Some(id), .. } = item + { + if id.is_empty() { + continue; + } + + if let Some(obj) = value.as_object_mut() { + obj.insert("id".to_string(), Value::String(id.clone())); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::RetryConfig; + use crate::provider::WireApi; + use codex_protocol::protocol::SubAgentSource; + use http::HeaderValue; + use pretty_assertions::assert_eq; + use std::time::Duration; + + fn provider(name: &str, base_url: &str) -> Provider { + Provider { + name: name.to_string(), + base_url: base_url.to_string(), + query_params: None, + wire: WireApi::Responses, + headers: HeaderMap::new(), + retry: RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(50), + retry_429: false, + retry_5xx: true, + retry_transport: true, + }, + stream_idle_timeout: Duration::from_secs(5), + } + } + + #[test] + fn azure_default_store_attaches_ids_and_headers() { + let provider = provider("azure", "https://example.openai.azure.com/v1"); + let input = vec![ + ResponseItem::Message { + id: Some("m1".into()), + role: "assistant".into(), + content: Vec::new(), + }, + ResponseItem::Message { + id: None, + role: "assistant".into(), + content: Vec::new(), + }, + ]; + + let request = ResponsesRequestBuilder::new("gpt-test", "inst", &input) + .conversation(Some("conv-1".into())) + .session_source(Some(SessionSource::SubAgent(SubAgentSource::Review))) + .build(&provider) + .expect("request"); + + assert_eq!(request.body.get("store"), Some(&Value::Bool(true))); + + let ids: Vec> = request + .body + .get("input") + .and_then(|v| v.as_array()) + .into_iter() + .flatten() + .map(|item| item.get("id").and_then(|v| v.as_str().map(str::to_string))) + .collect(); + assert_eq!(ids, vec![Some("m1".to_string()), None]); + + assert_eq!( + request.headers.get("conversation_id"), + Some(&HeaderValue::from_static("conv-1")) + ); + assert_eq!( + request.headers.get("session_id"), + Some(&HeaderValue::from_static("conv-1")) + ); + assert_eq!( + request.headers.get("x-openai-subagent"), + Some(&HeaderValue::from_static("review")) + ); + } +} diff --git a/codex-rs/codex-api/src/sse/chat.rs b/codex-rs/codex-api/src/sse/chat.rs new file mode 100644 index 000000000..7f50bb634 --- /dev/null +++ b/codex-rs/codex-api/src/sse/chat.rs @@ -0,0 +1,504 @@ +use crate::common::ResponseEvent; +use crate::common::ResponseStream; +use crate::error::ApiError; +use crate::telemetry::SseTelemetry; +use codex_client::StreamResponse; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ReasoningItemContent; +use codex_protocol::models::ResponseItem; +use eventsource_stream::Eventsource; +use futures::Stream; +use futures::StreamExt; +use std::collections::HashMap; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::Instant; +use tokio::time::timeout; +use tracing::debug; +use tracing::trace; + +pub(crate) fn spawn_chat_stream( + stream_response: StreamResponse, + idle_timeout: Duration, + telemetry: Option>, +) -> ResponseStream { + let (tx_event, rx_event) = mpsc::channel::>(1600); + tokio::spawn(async move { + process_chat_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await; + }); + ResponseStream { rx_event } +} + +pub async fn process_chat_sse( + stream: S, + tx_event: mpsc::Sender>, + idle_timeout: Duration, + telemetry: Option>, +) where + S: Stream> + Unpin, +{ + let mut stream = stream.eventsource(); + + #[derive(Default, Debug)] + struct ToolCallState { + name: Option, + arguments: String, + } + + let mut tool_calls: HashMap = HashMap::new(); + let mut tool_call_order: Vec = Vec::new(); + let mut assistant_item: Option = None; + let mut reasoning_item: Option = None; + let mut completed_sent = false; + + loop { + let start = Instant::now(); + let response = timeout(idle_timeout, stream.next()).await; + if let Some(t) = telemetry.as_ref() { + t.on_sse_poll(&response, start.elapsed()); + } + let sse = match response { + Ok(Some(Ok(sse))) => sse, + Ok(Some(Err(e))) => { + let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await; + return; + } + Ok(None) => { + if let Some(reasoning) = reasoning_item { + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemDone(reasoning))) + .await; + } + + if let Some(assistant) = assistant_item { + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemDone(assistant))) + .await; + } + if !completed_sent { + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + token_usage: None, + })) + .await; + } + return; + } + Err(_) => { + let _ = tx_event + .send(Err(ApiError::Stream("idle timeout waiting for SSE".into()))) + .await; + return; + } + }; + + trace!("SSE event: {}", sse.data); + + if sse.data.trim().is_empty() { + continue; + } + + let value: serde_json::Value = match serde_json::from_str(&sse.data) { + Ok(val) => val, + Err(err) => { + debug!( + "Failed to parse ChatCompletions SSE event: {err}, data: {}", + &sse.data + ); + continue; + } + }; + + let Some(choices) = value.get("choices").and_then(|c| c.as_array()) else { + continue; + }; + + for choice in choices { + if let Some(delta) = choice.get("delta") { + if let Some(reasoning) = delta.get("reasoning") { + if let Some(text) = reasoning.as_str() { + append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()) + .await; + } else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) { + append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()) + .await; + } else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) { + append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()) + .await; + } + } + + if let Some(content) = delta.get("content") { + if content.is_array() { + for item in content.as_array().unwrap_or(&vec![]) { + if let Some(text) = item.get("text").and_then(|t| t.as_str()) { + append_assistant_text( + &tx_event, + &mut assistant_item, + text.to_string(), + ) + .await; + } + } + } else if let Some(text) = content.as_str() { + append_assistant_text(&tx_event, &mut assistant_item, text.to_string()) + .await; + } + } + + if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) { + for tool_call in tool_call_values { + let id = tool_call + .get("id") + .and_then(|i| i.as_str()) + .map(str::to_string) + .unwrap_or_else(|| format!("tool-call-{}", tool_call_order.len())); + + let call_state = tool_calls.entry(id.clone()).or_default(); + if !tool_call_order.contains(&id) { + tool_call_order.push(id.clone()); + } + + if let Some(func) = tool_call.get("function") { + if let Some(fname) = func.get("name").and_then(|n| n.as_str()) { + call_state.name = Some(fname.to_string()); + } + if let Some(arguments) = func.get("arguments").and_then(|a| a.as_str()) + { + call_state.arguments.push_str(arguments); + } + } + } + } + } + + if let Some(message) = choice.get("message") + && let Some(reasoning) = message.get("reasoning") + { + if let Some(text) = reasoning.as_str() { + append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await; + } else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) { + append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await; + } else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) { + append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await; + } + } + + let finish_reason = choice.get("finish_reason").and_then(|r| r.as_str()); + if finish_reason == Some("stop") { + if let Some(reasoning) = reasoning_item.take() { + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemDone(reasoning))) + .await; + } + + if let Some(assistant) = assistant_item.take() { + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemDone(assistant))) + .await; + } + if !completed_sent { + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + token_usage: None, + })) + .await; + completed_sent = true; + } + continue; + } + + if finish_reason == Some("length") { + let _ = tx_event.send(Err(ApiError::ContextWindowExceeded)).await; + return; + } + + if finish_reason == Some("tool_calls") { + if let Some(reasoning) = reasoning_item.take() { + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemDone(reasoning))) + .await; + } + + for call_id in tool_call_order.drain(..) { + let state = tool_calls.remove(&call_id).unwrap_or_default(); + let item = ResponseItem::FunctionCall { + id: None, + name: state.name.unwrap_or_default(), + arguments: state.arguments, + call_id: call_id.clone(), + }; + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + } + } + } +} + +async fn append_assistant_text( + tx_event: &mpsc::Sender>, + assistant_item: &mut Option, + text: String, +) { + if assistant_item.is_none() { + let item = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![], + }; + *assistant_item = Some(item.clone()); + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemAdded(item))) + .await; + } + + if let Some(ResponseItem::Message { content, .. }) = assistant_item { + content.push(ContentItem::OutputText { text: text.clone() }); + let _ = tx_event + .send(Ok(ResponseEvent::OutputTextDelta(text.clone()))) + .await; + } +} + +async fn append_reasoning_text( + tx_event: &mpsc::Sender>, + reasoning_item: &mut Option, + text: String, +) { + if reasoning_item.is_none() { + let item = ResponseItem::Reasoning { + id: String::new(), + summary: Vec::new(), + content: Some(vec![]), + encrypted_content: None, + }; + *reasoning_item = Some(item.clone()); + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemAdded(item))) + .await; + } + + if let Some(ResponseItem::Reasoning { + content: Some(content), + .. + }) = reasoning_item + { + let content_index = content.len() as i64; + content.push(ReasoningItemContent::ReasoningText { text: text.clone() }); + + let _ = tx_event + .send(Ok(ResponseEvent::ReasoningContentDelta { + delta: text.clone(), + content_index, + })) + .await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use assert_matches::assert_matches; + use codex_protocol::models::ResponseItem; + use futures::TryStreamExt; + use serde_json::json; + use tokio::sync::mpsc; + use tokio_util::io::ReaderStream; + + fn build_body(events: &[serde_json::Value]) -> String { + let mut body = String::new(); + for e in events { + body.push_str(&format!("event: message\ndata: {e}\n\n")); + } + body + } + + async fn collect_events(body: &str) -> Vec { + let reader = ReaderStream::new(std::io::Cursor::new(body.to_string())) + .map_err(|err| codex_client::TransportError::Network(err.to_string())); + let (tx, mut rx) = mpsc::channel::>(16); + tokio::spawn(process_chat_sse( + reader, + tx, + Duration::from_millis(1000), + None, + )); + + let mut out = Vec::new(); + while let Some(ev) = rx.recv().await { + out.push(ev.expect("stream error")); + } + out + } + + #[tokio::test] + async fn emits_multiple_tool_calls() { + let delta_a = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "id": "call_a", + "function": { "name": "do_a", "arguments": "{\"foo\":1}" } + }] + } + }] + }); + + let delta_b = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "id": "call_b", + "function": { "name": "do_b", "arguments": "{\"bar\":2}" } + }] + } + }] + }); + + let finish = json!({ + "choices": [{ + "finish_reason": "tool_calls" + }] + }); + + let body = build_body(&[delta_a, delta_b, finish]); + let events = collect_events(&body).await; + assert_eq!(events.len(), 3); + + assert_matches!( + &events[0], + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }) + if call_id == "call_a" && name == "do_a" && arguments == "{\"foo\":1}" + ); + assert_matches!( + &events[1], + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }) + if call_id == "call_b" && name == "do_b" && arguments == "{\"bar\":2}" + ); + assert_matches!(events[2], ResponseEvent::Completed { .. }); + } + + #[tokio::test] + async fn concatenates_tool_call_arguments_across_deltas() { + let delta_name = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "id": "call_a", + "function": { "name": "do_a" } + }] + } + }] + }); + + let delta_args_1 = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "id": "call_a", + "function": { "arguments": "{ \"foo\":" } + }] + } + }] + }); + + let delta_args_2 = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "id": "call_a", + "function": { "arguments": "1}" } + }] + } + }] + }); + + let finish = json!({ + "choices": [{ + "finish_reason": "tool_calls" + }] + }); + + let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]); + let events = collect_events(&body).await; + assert_matches!( + &events[..], + [ + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }), + ResponseEvent::Completed { .. } + ] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}" + ); + } + + #[tokio::test] + async fn emits_tool_calls_even_when_content_and_reasoning_present() { + let delta_content_and_tools = json!({ + "choices": [{ + "delta": { + "content": [{"text": "hi"}], + "reasoning": "because", + "tool_calls": [{ + "id": "call_a", + "function": { "name": "do_a", "arguments": "{}" } + }] + } + }] + }); + + let finish = json!({ + "choices": [{ + "finish_reason": "tool_calls" + }] + }); + + let body = build_body(&[delta_content_and_tools, finish]); + let events = collect_events(&body).await; + + assert_matches!( + &events[..], + [ + ResponseEvent::OutputItemAdded(ResponseItem::Reasoning { .. }), + ResponseEvent::ReasoningContentDelta { .. }, + ResponseEvent::OutputItemAdded(ResponseItem::Message { .. }), + ResponseEvent::OutputTextDelta(delta), + ResponseEvent::OutputItemDone(ResponseItem::Reasoning { .. }), + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, .. }), + ResponseEvent::OutputItemDone(ResponseItem::Message { .. }), + ResponseEvent::Completed { .. } + ] if delta == "hi" && call_id == "call_a" && name == "do_a" + ); + } + + #[tokio::test] + async fn drops_partial_tool_calls_on_stop_finish_reason() { + let delta_tool = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "id": "call_a", + "function": { "name": "do_a", "arguments": "{}" } + }] + } + }] + }); + + let finish_stop = json!({ + "choices": [{ + "finish_reason": "stop" + }] + }); + + let body = build_body(&[delta_tool, finish_stop]); + let events = collect_events(&body).await; + + assert!(!events.iter().any(|ev| { + matches!( + ev, + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. }) + ) + })); + assert_matches!(events.last(), Some(ResponseEvent::Completed { .. })); + } +} diff --git a/codex-rs/codex-api/src/sse/mod.rs b/codex-rs/codex-api/src/sse/mod.rs new file mode 100644 index 000000000..e3ab770c4 --- /dev/null +++ b/codex-rs/codex-api/src/sse/mod.rs @@ -0,0 +1,6 @@ +pub mod chat; +pub mod responses; + +pub use responses::process_sse; +pub use responses::spawn_response_stream; +pub use responses::stream_from_fixture; diff --git a/codex-rs/codex-api/src/sse/responses.rs b/codex-rs/codex-api/src/sse/responses.rs new file mode 100644 index 000000000..5dbec7b77 --- /dev/null +++ b/codex-rs/codex-api/src/sse/responses.rs @@ -0,0 +1,672 @@ +use crate::common::ResponseEvent; +use crate::common::ResponseStream; +use crate::error::ApiError; +use crate::rate_limits::parse_rate_limit; +use crate::telemetry::SseTelemetry; +use codex_client::ByteStream; +use codex_client::StreamResponse; +use codex_client::TransportError; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::TokenUsage; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use futures::TryStreamExt; +use serde::Deserialize; +use serde_json::Value; +use std::io::BufRead; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::Instant; +use tokio::time::timeout; +use tokio_util::io::ReaderStream; +use tracing::debug; +use tracing::trace; + +/// Streams SSE events from an on-disk fixture for tests. +pub fn stream_from_fixture( + path: impl AsRef, + idle_timeout: Duration, +) -> Result { + let file = + std::fs::File::open(path.as_ref()).map_err(|err| ApiError::Stream(err.to_string()))?; + let mut content = String::new(); + for line in std::io::BufReader::new(file).lines() { + let line = line.map_err(|err| ApiError::Stream(err.to_string()))?; + content.push_str(&line); + content.push_str("\n\n"); + } + + let reader = std::io::Cursor::new(content); + let stream = ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string())); + let (tx_event, rx_event) = mpsc::channel::>(1600); + tokio::spawn(process_sse(Box::pin(stream), tx_event, idle_timeout, None)); + Ok(ResponseStream { rx_event }) +} + +pub fn spawn_response_stream( + stream_response: StreamResponse, + idle_timeout: Duration, + telemetry: Option>, +) -> ResponseStream { + let rate_limits = parse_rate_limit(&stream_response.headers); + let (tx_event, rx_event) = mpsc::channel::>(1600); + tokio::spawn(async move { + if let Some(snapshot) = rate_limits { + let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await; + } + process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await; + }); + + ResponseStream { rx_event } +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct Error { + r#type: Option, + code: Option, + message: Option, + plan_type: Option, + resets_at: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct ResponseCompleted { + id: String, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ResponseCompletedUsage { + input_tokens: i64, + input_tokens_details: Option, + output_tokens: i64, + output_tokens_details: Option, + total_tokens: i64, +} + +impl From for TokenUsage { + fn from(val: ResponseCompletedUsage) -> Self { + TokenUsage { + input_tokens: val.input_tokens, + cached_input_tokens: val + .input_tokens_details + .map(|d| d.cached_tokens) + .unwrap_or(0), + output_tokens: val.output_tokens, + reasoning_output_tokens: val + .output_tokens_details + .map(|d| d.reasoning_tokens) + .unwrap_or(0), + total_tokens: val.total_tokens, + } + } +} + +#[derive(Debug, Deserialize)] +struct ResponseCompletedInputTokensDetails { + cached_tokens: i64, +} + +#[derive(Debug, Deserialize)] +struct ResponseCompletedOutputTokensDetails { + reasoning_tokens: i64, +} + +#[derive(Deserialize, Debug)] +struct SseEvent { + #[serde(rename = "type")] + kind: String, + response: Option, + item: Option, + delta: Option, + summary_index: Option, + content_index: Option, +} + +pub async fn process_sse( + stream: ByteStream, + tx_event: mpsc::Sender>, + idle_timeout: Duration, + telemetry: Option>, +) { + let mut stream = stream.eventsource(); + let mut response_completed: Option = None; + let mut response_error: Option = None; + + loop { + let start = Instant::now(); + let response = timeout(idle_timeout, stream.next()).await; + if let Some(t) = telemetry.as_ref() { + t.on_sse_poll(&response, start.elapsed()); + } + let sse = match response { + Ok(Some(Ok(sse))) => sse, + Ok(Some(Err(e))) => { + debug!("SSE Error: {e:#}"); + let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await; + return; + } + Ok(None) => { + match response_completed.take() { + Some(ResponseCompleted { id, usage }) => { + let event = ResponseEvent::Completed { + response_id: id, + token_usage: usage.map(Into::into), + }; + let _ = tx_event.send(Ok(event)).await; + } + None => { + let error = response_error.unwrap_or(ApiError::Stream( + "stream closed before response.completed".into(), + )); + let _ = tx_event.send(Err(error)).await; + } + } + return; + } + Err(_) => { + let _ = tx_event + .send(Err(ApiError::Stream("idle timeout waiting for SSE".into()))) + .await; + return; + } + }; + + let raw = sse.data.clone(); + trace!("SSE event: {raw}"); + + let event: SseEvent = match serde_json::from_str(&sse.data) { + Ok(event) => event, + Err(e) => { + debug!("Failed to parse SSE event: {e}, data: {}", &sse.data); + continue; + } + }; + + match event.kind.as_str() { + "response.output_item.done" => { + let Some(item_val) = event.item else { continue }; + let Ok(item) = serde_json::from_value::(item_val) else { + debug!("failed to parse ResponseItem from output_item.done"); + continue; + }; + + let event = ResponseEvent::OutputItemDone(item); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + "response.output_text.delta" => { + if let Some(delta) = event.delta { + let event = ResponseEvent::OutputTextDelta(delta); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + "response.reasoning_summary_text.delta" => { + if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) { + let event = ResponseEvent::ReasoningSummaryDelta { + delta, + summary_index, + }; + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + "response.reasoning_text.delta" => { + if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) { + let event = ResponseEvent::ReasoningContentDelta { + delta, + content_index, + }; + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + "response.created" => { + if event.response.is_some() { + let _ = tx_event.send(Ok(ResponseEvent::Created {})).await; + } + } + "response.failed" => { + if let Some(resp_val) = event.response { + response_error = + Some(ApiError::Stream("response.failed event received".into())); + + if let Some(error) = resp_val.get("error") + && let Ok(error) = serde_json::from_value::(error.clone()) + { + if is_context_window_error(&error) { + response_error = Some(ApiError::ContextWindowExceeded); + } else if is_quota_exceeded_error(&error) { + response_error = Some(ApiError::QuotaExceeded); + } else if is_usage_not_included(&error) { + response_error = Some(ApiError::UsageNotIncluded); + } else { + let delay = try_parse_retry_after(&error); + let message = error.message.clone().unwrap_or_default(); + response_error = Some(ApiError::Retryable { message, delay }); + } + } + } + } + "response.completed" => { + if let Some(resp_val) = event.response { + match serde_json::from_value::(resp_val) { + Ok(r) => { + response_completed = Some(r); + } + Err(e) => { + let error = format!("failed to parse ResponseCompleted: {e}"); + debug!(error); + response_error = Some(ApiError::Stream(error)); + continue; + } + }; + }; + } + "response.output_item.added" => { + let Some(item_val) = event.item else { continue }; + let Ok(item) = serde_json::from_value::(item_val) else { + debug!("failed to parse ResponseItem from output_item.done"); + continue; + }; + + let event = ResponseEvent::OutputItemAdded(item); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + "response.reasoning_summary_part.added" => { + if let Some(summary_index) = event.summary_index { + let event = ResponseEvent::ReasoningSummaryPartAdded { summary_index }; + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + _ => {} + } + } +} + +fn try_parse_retry_after(err: &Error) -> Option { + if err.code.as_deref() != Some("rate_limit_exceeded") { + return None; + } + + let re = rate_limit_regex(); + if let Some(message) = &err.message + && let Some(captures) = re.captures(message) + { + let seconds = captures.get(1); + let unit = captures.get(2); + + if let (Some(value), Some(unit)) = (seconds, unit) { + let value = value.as_str().parse::().ok()?; + let unit = unit.as_str().to_ascii_lowercase(); + + if unit == "s" || unit.starts_with("second") { + return Some(Duration::from_secs_f64(value)); + } else if unit == "ms" { + return Some(Duration::from_millis(value as u64)); + } + } + } + None +} + +fn is_context_window_error(error: &Error) -> bool { + error.code.as_deref() == Some("context_length_exceeded") +} + +fn is_quota_exceeded_error(error: &Error) -> bool { + error.code.as_deref() == Some("insufficient_quota") +} + +fn is_usage_not_included(error: &Error) -> bool { + error.code.as_deref() == Some("usage_not_included") +} + +fn rate_limit_regex() -> &'static regex_lite::Regex { + static RE: std::sync::OnceLock = std::sync::OnceLock::new(); + #[expect(clippy::unwrap_used)] + RE.get_or_init(|| { + regex_lite::Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap() + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use assert_matches::assert_matches; + use codex_protocol::models::ResponseItem; + use pretty_assertions::assert_eq; + use serde_json::json; + use tokio::sync::mpsc; + use tokio_test::io::Builder as IoBuilder; + + async fn collect_events(chunks: &[&[u8]]) -> Vec> { + let mut builder = IoBuilder::new(); + for chunk in chunks { + builder.read(chunk); + } + + let reader = builder.build(); + let stream = + ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string())); + let (tx, mut rx) = mpsc::channel::>(16); + tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None)); + + let mut events = Vec::new(); + while let Some(ev) = rx.recv().await { + events.push(ev); + } + events + } + + async fn run_sse(events: Vec) -> Vec { + let mut body = String::new(); + for e in events { + let kind = e + .get("type") + .and_then(|v| v.as_str()) + .expect("fixture event missing type"); + if e.as_object().map(|o| o.len() == 1).unwrap_or(false) { + body.push_str(&format!("event: {kind}\n\n")); + } else { + body.push_str(&format!("event: {kind}\ndata: {e}\n\n")); + } + } + + let (tx, mut rx) = mpsc::channel::>(8); + let stream = ReaderStream::new(std::io::Cursor::new(body)) + .map_err(|err| TransportError::Network(err.to_string())); + tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None)); + + let mut out = Vec::new(); + while let Some(ev) = rx.recv().await { + out.push(ev.expect("channel closed")); + } + out + } + + fn idle_timeout() -> Duration { + Duration::from_millis(1000) + } + + #[tokio::test] + async fn parses_items_and_completed() { + let item1 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello"}] + } + }) + .to_string(); + + let item2 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "World"}] + } + }) + .to_string(); + + let completed = json!({ + "type": "response.completed", + "response": { "id": "resp1" } + }) + .to_string(); + + let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); + let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n"); + let sse3 = format!("event: response.completed\ndata: {completed}\n\n"); + + let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await; + + assert_eq!(events.len(), 3); + + assert_matches!( + &events[0], + Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) + if role == "assistant" + ); + + assert_matches!( + &events[1], + Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) + if role == "assistant" + ); + + match &events[2] { + Ok(ResponseEvent::Completed { + response_id, + token_usage, + }) => { + assert_eq!(response_id, "resp1"); + assert!(token_usage.is_none()); + } + other => panic!("unexpected third event: {other:?}"), + } + } + + #[tokio::test] + async fn error_when_missing_completed() { + let item1 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello"}] + } + }) + .to_string(); + + let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 2); + + assert_matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))); + + match &events[1] { + Err(ApiError::Stream(msg)) => { + assert_eq!(msg, "stream closed before response.completed") + } + other => panic!("unexpected second event: {other:?}"), + } + } + + #[tokio::test] + async fn error_when_error_event() { + let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#; + + let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 1); + + match &events[0] { + Err(ApiError::Retryable { message, delay }) => { + assert_eq!( + message, + "Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more." + ); + assert_eq!(*delay, Some(Duration::from_secs_f64(11.054))); + } + other => panic!("unexpected second event: {other:?}"), + } + } + + #[tokio::test] + async fn context_window_error_is_fatal() { + let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#; + + let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 1); + + assert_matches!(events[0], Err(ApiError::ContextWindowExceeded)); + } + + #[tokio::test] + async fn context_window_error_with_newline_is_fatal() { + let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#; + + let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 1); + + assert_matches!(events[0], Err(ApiError::ContextWindowExceeded)); + } + + #[tokio::test] + async fn quota_exceeded_error_is_fatal() { + let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_quota","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"insufficient_quota","message":"You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors."},"incomplete_details":null}}"#; + + let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 1); + + assert_matches!(events[0], Err(ApiError::QuotaExceeded)); + } + + #[tokio::test] + async fn table_driven_event_kinds() { + struct TestCase { + name: &'static str, + event: serde_json::Value, + expect_first: fn(&ResponseEvent) -> bool, + expected_len: usize, + } + + fn is_created(ev: &ResponseEvent) -> bool { + matches!(ev, ResponseEvent::Created) + } + fn is_output(ev: &ResponseEvent) -> bool { + matches!(ev, ResponseEvent::OutputItemDone(_)) + } + fn is_completed(ev: &ResponseEvent) -> bool { + matches!(ev, ResponseEvent::Completed { .. }) + } + + let completed = json!({ + "type": "response.completed", + "response": { + "id": "c", + "usage": { + "input_tokens": 0, + "input_tokens_details": null, + "output_tokens": 0, + "output_tokens_details": null, + "total_tokens": 0 + }, + "output": [] + } + }); + + let cases = vec![ + TestCase { + name: "created", + event: json!({"type": "response.created", "response": {}}), + expect_first: is_created, + expected_len: 2, + }, + TestCase { + name: "output_item.done", + event: json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "hi"} + ] + } + }), + expect_first: is_output, + expected_len: 2, + }, + TestCase { + name: "unknown", + event: json!({"type": "response.new_tool_event"}), + expect_first: is_completed, + expected_len: 1, + }, + ]; + + for case in cases { + let mut evs = vec![case.event]; + evs.push(completed.clone()); + + let out = run_sse(evs).await; + assert_eq!(out.len(), case.expected_len, "case {}", case.name); + assert!( + (case.expect_first)(&out[0]), + "first event mismatch in case {}", + case.name + ); + } + } + + #[test] + fn test_try_parse_retry_after() { + let err = Error { + r#type: None, + message: Some("Rate limit reached for gpt-5.1 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()), + code: Some("rate_limit_exceeded".to_string()), + plan_type: None, + resets_at: None, + }; + + let delay = try_parse_retry_after(&err); + assert_eq!(delay, Some(Duration::from_millis(28))); + } + + #[test] + fn test_try_parse_retry_after_no_delay() { + let err = Error { + r#type: None, + message: Some("Rate limit reached for gpt-5.1 in organization on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()), + code: Some("rate_limit_exceeded".to_string()), + plan_type: None, + resets_at: None, + }; + let delay = try_parse_retry_after(&err); + assert_eq!(delay, Some(Duration::from_secs_f64(1.898))); + } + + #[test] + fn test_try_parse_retry_after_azure() { + let err = Error { + r#type: None, + message: Some("Rate limit exceeded. Try again in 35 seconds.".to_string()), + code: Some("rate_limit_exceeded".to_string()), + plan_type: None, + resets_at: None, + }; + let delay = try_parse_retry_after(&err); + assert_eq!(delay, Some(Duration::from_secs(35))); + } +} diff --git a/codex-rs/codex-api/src/telemetry.rs b/codex-rs/codex-api/src/telemetry.rs new file mode 100644 index 000000000..d6a38b2af --- /dev/null +++ b/codex-rs/codex-api/src/telemetry.rs @@ -0,0 +1,84 @@ +use codex_client::Request; +use codex_client::RequestTelemetry; +use codex_client::Response; +use codex_client::RetryPolicy; +use codex_client::StreamResponse; +use codex_client::TransportError; +use codex_client::run_with_retry; +use http::StatusCode; +use std::future::Future; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::Instant; + +/// Generic telemetry. +pub trait SseTelemetry: Send + Sync { + fn on_sse_poll( + &self, + result: &Result< + Option< + Result< + eventsource_stream::Event, + eventsource_stream::EventStreamError, + >, + >, + tokio::time::error::Elapsed, + >, + duration: Duration, + ); +} + +pub(crate) trait WithStatus { + fn status(&self) -> StatusCode; +} + +fn http_status(err: &TransportError) -> Option { + match err { + TransportError::Http { status, .. } => Some(*status), + _ => None, + } +} + +impl WithStatus for Response { + fn status(&self) -> StatusCode { + self.status + } +} + +impl WithStatus for StreamResponse { + fn status(&self) -> StatusCode { + self.status + } +} + +pub(crate) async fn run_with_request_telemetry( + policy: RetryPolicy, + telemetry: Option>, + make_request: impl FnMut() -> Request, + send: F, +) -> Result +where + T: WithStatus, + F: Clone + Fn(Request) -> Fut, + Fut: Future>, +{ + // Wraps `run_with_retry` to attach per-attempt request telemetry for both + // unary and streaming HTTP calls. + run_with_retry(policy, make_request, move |req, attempt| { + let telemetry = telemetry.clone(); + let send = send.clone(); + async move { + let start = Instant::now(); + let result = send(req).await; + if let Some(t) = telemetry.as_ref() { + let (status, err) = match &result { + Ok(resp) => (Some(resp.status()), None), + Err(err) => (http_status(err), Some(err)), + }; + t.on_request(attempt, status, err, start.elapsed()); + } + result + } + }) + .await +} diff --git a/codex-rs/codex-api/tests/clients.rs b/codex-rs/codex-api/tests/clients.rs new file mode 100644 index 000000000..3dafaf74f --- /dev/null +++ b/codex-rs/codex-api/tests/clients.rs @@ -0,0 +1,315 @@ +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use bytes::Bytes; +use codex_api::AuthProvider; +use codex_api::ChatClient; +use codex_api::Provider; +use codex_api::ResponsesClient; +use codex_api::ResponsesOptions; +use codex_api::WireApi; +use codex_client::HttpTransport; +use codex_client::Request; +use codex_client::Response; +use codex_client::StreamResponse; +use codex_client::TransportError; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use http::HeaderMap; +use http::StatusCode; +use pretty_assertions::assert_eq; +use serde_json::Value; + +fn assert_path_ends_with(requests: &[Request], suffix: &str) { + assert_eq!(requests.len(), 1); + let url = &requests[0].url; + assert!( + url.ends_with(suffix), + "expected url to end with {suffix}, got {url}" + ); +} + +#[derive(Debug, Default, Clone)] +struct RecordingState { + stream_requests: Arc>>, +} + +impl RecordingState { + fn record(&self, req: Request) { + let mut guard = self + .stream_requests + .lock() + .unwrap_or_else(|err| panic!("mutex poisoned: {err}")); + guard.push(req); + } + + fn take_stream_requests(&self) -> Vec { + let mut guard = self + .stream_requests + .lock() + .unwrap_or_else(|err| panic!("mutex poisoned: {err}")); + std::mem::take(&mut *guard) + } +} + +#[derive(Clone)] +struct RecordingTransport { + state: RecordingState, +} + +impl RecordingTransport { + fn new(state: RecordingState) -> Self { + Self { state } + } +} + +#[async_trait] +impl HttpTransport for RecordingTransport { + async fn execute(&self, _req: Request) -> Result { + Err(TransportError::Build("execute should not run".to_string())) + } + + async fn stream(&self, req: Request) -> Result { + self.state.record(req); + + let stream = futures::stream::iter(Vec::>::new()); + Ok(StreamResponse { + status: StatusCode::OK, + headers: HeaderMap::new(), + bytes: Box::pin(stream), + }) + } +} + +#[derive(Clone, Default)] +struct NoAuth; + +impl AuthProvider for NoAuth { + fn bearer_token(&self) -> Option { + None + } +} + +#[derive(Clone)] +struct StaticAuth { + token: String, + account_id: String, +} + +impl StaticAuth { + fn new(token: &str, account_id: &str) -> Self { + Self { + token: token.to_string(), + account_id: account_id.to_string(), + } + } +} + +impl AuthProvider for StaticAuth { + fn bearer_token(&self) -> Option { + Some(self.token.clone()) + } + + fn account_id(&self) -> Option { + Some(self.account_id.clone()) + } +} + +fn provider(name: &str, wire: WireApi) -> Provider { + Provider { + name: name.to_string(), + base_url: "https://example.com/v1".to_string(), + query_params: None, + wire, + headers: HeaderMap::new(), + retry: codex_api::provider::RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: false, + retry_transport: true, + }, + stream_idle_timeout: Duration::from_millis(10), + } +} + +#[derive(Clone)] +struct FlakyTransport { + state: Arc>, +} + +impl Default for FlakyTransport { + fn default() -> Self { + Self::new() + } +} + +impl FlakyTransport { + fn new() -> Self { + Self { + state: Arc::new(Mutex::new(0)), + } + } + + fn attempts(&self) -> i64 { + *self + .state + .lock() + .unwrap_or_else(|err| panic!("mutex poisoned: {err}")) + } +} + +#[async_trait] +impl HttpTransport for FlakyTransport { + async fn execute(&self, _req: Request) -> Result { + Err(TransportError::Build("execute should not run".to_string())) + } + + async fn stream(&self, _req: Request) -> Result { + let mut attempts = self + .state + .lock() + .unwrap_or_else(|err| panic!("mutex poisoned: {err}")); + *attempts += 1; + + if *attempts == 1 { + return Err(TransportError::Network("first attempt fails".to_string())); + } + + let stream = futures::stream::iter(vec![Ok(Bytes::from( + r#"event: message +data: {"id":"resp-1","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"hi"}]}]} + +"#, + ))]); + + Ok(StreamResponse { + status: StatusCode::OK, + headers: HeaderMap::new(), + bytes: Box::pin(stream), + }) + } +} + +#[tokio::test] +async fn chat_client_uses_chat_completions_path_for_chat_wire() -> Result<()> { + let state = RecordingState::default(); + let transport = RecordingTransport::new(state.clone()); + let client = ChatClient::new(transport, provider("openai", WireApi::Chat), NoAuth); + + let body = serde_json::json!({ "echo": true }); + let _stream = client.stream(body, HeaderMap::new()).await?; + + let requests = state.take_stream_requests(); + assert_path_ends_with(&requests, "/chat/completions"); + Ok(()) +} + +#[tokio::test] +async fn chat_client_uses_responses_path_for_responses_wire() -> Result<()> { + let state = RecordingState::default(); + let transport = RecordingTransport::new(state.clone()); + let client = ChatClient::new(transport, provider("openai", WireApi::Responses), NoAuth); + + let body = serde_json::json!({ "echo": true }); + let _stream = client.stream(body, HeaderMap::new()).await?; + + let requests = state.take_stream_requests(); + assert_path_ends_with(&requests, "/responses"); + Ok(()) +} + +#[tokio::test] +async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()> { + let state = RecordingState::default(); + let transport = RecordingTransport::new(state.clone()); + let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth); + + let body = serde_json::json!({ "echo": true }); + let _stream = client.stream(body, HeaderMap::new()).await?; + + let requests = state.take_stream_requests(); + assert_path_ends_with(&requests, "/responses"); + Ok(()) +} + +#[tokio::test] +async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> { + let state = RecordingState::default(); + let transport = RecordingTransport::new(state.clone()); + let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth); + + let body = serde_json::json!({ "echo": true }); + let _stream = client.stream(body, HeaderMap::new()).await?; + + let requests = state.take_stream_requests(); + assert_path_ends_with(&requests, "/chat/completions"); + Ok(()) +} + +#[tokio::test] +async fn streaming_client_adds_auth_headers() -> Result<()> { + let state = RecordingState::default(); + let transport = RecordingTransport::new(state.clone()); + let auth = StaticAuth::new("secret-token", "acct-1"); + let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), auth); + + let body = serde_json::json!({ "model": "gpt-test" }); + let _stream = client.stream(body, HeaderMap::new()).await?; + + let requests = state.take_stream_requests(); + assert_eq!(requests.len(), 1); + let req = &requests[0]; + + let auth_header = req.headers.get(http::header::AUTHORIZATION); + assert!(auth_header.is_some(), "missing auth header"); + assert_eq!( + auth_header.unwrap().to_str().ok(), + Some("Bearer secret-token") + ); + + let account_header = req.headers.get("ChatGPT-Account-ID"); + assert!(account_header.is_some(), "missing account header"); + assert_eq!(account_header.unwrap().to_str().ok(), Some("acct-1")); + + let accept_header = req.headers.get(http::header::ACCEPT); + assert!(accept_header.is_some(), "missing Accept header"); + assert_eq!( + accept_header.unwrap().to_str().ok(), + Some("text/event-stream") + ); + Ok(()) +} + +#[tokio::test] +async fn streaming_client_retries_on_transport_error() -> Result<()> { + let transport = FlakyTransport::new(); + + let mut provider = provider("openai", WireApi::Responses); + provider.retry.max_attempts = 2; + + let client = ResponsesClient::new(transport.clone(), provider, NoAuth); + + let prompt = codex_api::Prompt { + instructions: "Say hi".to_string(), + input: vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "hi".to_string(), + }], + }], + tools: Vec::::new(), + parallel_tool_calls: false, + output_schema: None, + }; + + let options = ResponsesOptions::default(); + + let _stream = client.stream_prompt("gpt-test", &prompt, options).await?; + assert_eq!(transport.attempts(), 2); + Ok(()) +} diff --git a/codex-rs/codex-api/tests/sse_end_to_end.rs b/codex-rs/codex-api/tests/sse_end_to_end.rs new file mode 100644 index 000000000..b91cf3a5d --- /dev/null +++ b/codex-rs/codex-api/tests/sse_end_to_end.rs @@ -0,0 +1,229 @@ +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use bytes::Bytes; +use codex_api::AggregateStreamExt; +use codex_api::AuthProvider; +use codex_api::Provider; +use codex_api::ResponseEvent; +use codex_api::ResponsesClient; +use codex_api::WireApi; +use codex_client::HttpTransport; +use codex_client::Request; +use codex_client::Response; +use codex_client::StreamResponse; +use codex_client::TransportError; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use futures::StreamExt; +use http::HeaderMap; +use http::StatusCode; +use pretty_assertions::assert_eq; +use serde_json::Value; + +#[derive(Clone)] +struct FixtureSseTransport { + body: String, +} + +impl FixtureSseTransport { + fn new(body: String) -> Self { + Self { body } + } +} + +#[async_trait] +impl HttpTransport for FixtureSseTransport { + async fn execute(&self, _req: Request) -> Result { + Err(TransportError::Build("execute should not run".to_string())) + } + + async fn stream(&self, _req: Request) -> Result { + let stream = futures::stream::iter(vec![Ok::(Bytes::from( + self.body.clone(), + ))]); + Ok(StreamResponse { + status: StatusCode::OK, + headers: HeaderMap::new(), + bytes: Box::pin(stream), + }) + } +} + +#[derive(Clone, Default)] +struct NoAuth; + +impl AuthProvider for NoAuth { + fn bearer_token(&self) -> Option { + None + } +} + +fn provider(name: &str, wire: WireApi) -> Provider { + Provider { + name: name.to_string(), + base_url: "https://example.com/v1".to_string(), + query_params: None, + wire, + headers: HeaderMap::new(), + retry: codex_api::provider::RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: false, + retry_transport: true, + }, + stream_idle_timeout: Duration::from_millis(50), + } +} + +fn build_responses_body(events: Vec) -> String { + let mut body = String::new(); + for e in events { + let kind = e + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or_else(|| panic!("fixture event missing type in SSE fixture: {e}")); + if e.as_object().map(|o| o.len() == 1).unwrap_or(false) { + body.push_str(&format!("event: {kind}\n\n")); + } else { + body.push_str(&format!("event: {kind}\ndata: {e}\n\n")); + } + } + body +} + +#[tokio::test] +async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> { + let item1 = serde_json::json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello"}] + } + }); + + let item2 = serde_json::json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "World"}] + } + }); + + let completed = serde_json::json!({ + "type": "response.completed", + "response": { "id": "resp1" } + }); + + let body = build_responses_body(vec![item1, item2, completed]); + let transport = FixtureSseTransport::new(body); + let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth); + + let mut stream = client + .stream(serde_json::json!({"echo": true}), HeaderMap::new()) + .await?; + + let mut events = Vec::new(); + while let Some(ev) = stream.next().await { + events.push(ev?); + } + + let events: Vec = events + .into_iter() + .filter(|ev| !matches!(ev, ResponseEvent::RateLimits(_))) + .collect(); + + assert_eq!(events.len(), 3); + + match &events[0] { + ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }) => { + assert_eq!(role, "assistant"); + } + other => panic!("unexpected first event: {other:?}"), + } + + match &events[1] { + ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }) => { + assert_eq!(role, "assistant"); + } + other => panic!("unexpected second event: {other:?}"), + } + + match &events[2] { + ResponseEvent::Completed { + response_id, + token_usage, + } => { + assert_eq!(response_id, "resp1"); + assert!(token_usage.is_none()); + } + other => panic!("unexpected third event: {other:?}"), + } + + Ok(()) +} + +#[tokio::test] +async fn responses_stream_aggregates_output_text_deltas() -> Result<()> { + let delta1 = serde_json::json!({ + "type": "response.output_text.delta", + "delta": "Hello, " + }); + + let delta2 = serde_json::json!({ + "type": "response.output_text.delta", + "delta": "world" + }); + + let completed = serde_json::json!({ + "type": "response.completed", + "response": { "id": "resp-agg" } + }); + + let body = build_responses_body(vec![delta1, delta2, completed]); + let transport = FixtureSseTransport::new(body); + let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth); + + let stream = client + .stream(serde_json::json!({"echo": true}), HeaderMap::new()) + .await?; + + let mut stream = stream.aggregate(); + let mut events = Vec::new(); + while let Some(ev) = stream.next().await { + events.push(ev?); + } + + let events: Vec = events + .into_iter() + .filter(|ev| !matches!(ev, ResponseEvent::RateLimits(_))) + .collect(); + + assert_eq!(events.len(), 2); + + match &events[0] { + ResponseEvent::OutputItemDone(ResponseItem::Message { content, .. }) => { + let mut aggregated = String::new(); + for item in content { + if let ContentItem::OutputText { text } = item { + aggregated.push_str(text); + } + } + assert_eq!(aggregated, "Hello, world"); + } + other => panic!("unexpected first event: {other:?}"), + } + + match &events[1] { + ResponseEvent::Completed { response_id, .. } => { + assert_eq!(response_id, "resp-agg"); + } + other => panic!("unexpected second event: {other:?}"), + } + + Ok(()) +} diff --git a/codex-rs/codex-client/Cargo.toml b/codex-rs/codex-client/Cargo.toml new file mode 100644 index 000000000..2defe12fd --- /dev/null +++ b/codex-rs/codex-client/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "codex-client" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +async-trait = { workspace = true } +bytes = { workspace = true } +futures = { workspace = true } +http = { workspace = true } +reqwest = { workspace = true, features = ["json", "stream"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "time", "sync"] } +rand = { workspace = true } +eventsource-stream = { workspace = true } + +[lints] +workspace = true diff --git a/codex-rs/codex-client/README.md b/codex-rs/codex-client/README.md new file mode 100644 index 000000000..045ee7b34 --- /dev/null +++ b/codex-rs/codex-client/README.md @@ -0,0 +1,8 @@ +# codex-client + +Generic transport layer that wraps HTTP requests, retries, and streaming primitives without any Codex/OpenAI awareness. + +- Defines `HttpTransport` and a default `ReqwestTransport` plus thin `Request`/`Response` types. +- Provides retry utilities (`RetryPolicy`, `RetryOn`, `run_with_retry`, `backoff`) that callers plug into for unary and streaming calls. +- Supplies the `sse_stream` helper to turn byte streams into raw SSE `data:` frames with idle timeouts and surfaced stream errors. +- Consumed by higher-level crates like `codex-api`; it stays neutral on endpoints, headers, or API-specific error shapes. diff --git a/codex-rs/codex-client/src/error.rs b/codex-rs/codex-client/src/error.rs new file mode 100644 index 000000000..086b91a50 --- /dev/null +++ b/codex-rs/codex-client/src/error.rs @@ -0,0 +1,29 @@ +use http::HeaderMap; +use http::StatusCode; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum TransportError { + #[error("http {status}: {body:?}")] + Http { + status: StatusCode, + headers: Option, + body: Option, + }, + #[error("retry limit reached")] + RetryLimit, + #[error("timeout")] + Timeout, + #[error("network error: {0}")] + Network(String), + #[error("request build error: {0}")] + Build(String), +} + +#[derive(Debug, Error)] +pub enum StreamError { + #[error("stream failed: {0}")] + Stream(String), + #[error("timeout")] + Timeout, +} diff --git a/codex-rs/codex-client/src/lib.rs b/codex-rs/codex-client/src/lib.rs new file mode 100644 index 000000000..3ac00a21a --- /dev/null +++ b/codex-rs/codex-client/src/lib.rs @@ -0,0 +1,21 @@ +mod error; +mod request; +mod retry; +mod sse; +mod telemetry; +mod transport; + +pub use crate::error::StreamError; +pub use crate::error::TransportError; +pub use crate::request::Request; +pub use crate::request::Response; +pub use crate::retry::RetryOn; +pub use crate::retry::RetryPolicy; +pub use crate::retry::backoff; +pub use crate::retry::run_with_retry; +pub use crate::sse::sse_stream; +pub use crate::telemetry::RequestTelemetry; +pub use crate::transport::ByteStream; +pub use crate::transport::HttpTransport; +pub use crate::transport::ReqwestTransport; +pub use crate::transport::StreamResponse; diff --git a/codex-rs/codex-client/src/request.rs b/codex-rs/codex-client/src/request.rs new file mode 100644 index 000000000..f3d205de9 --- /dev/null +++ b/codex-rs/codex-client/src/request.rs @@ -0,0 +1,39 @@ +use bytes::Bytes; +use http::Method; +use reqwest::header::HeaderMap; +use serde::Serialize; +use serde_json::Value; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub struct Request { + pub method: Method, + pub url: String, + pub headers: HeaderMap, + pub body: Option, + pub timeout: Option, +} + +impl Request { + pub fn new(method: Method, url: String) -> Self { + Self { + method, + url, + headers: HeaderMap::new(), + body: None, + timeout: None, + } + } + + pub fn with_json(mut self, body: &T) -> Self { + self.body = serde_json::to_value(body).ok(); + self + } +} + +#[derive(Debug, Clone)] +pub struct Response { + pub status: http::StatusCode, + pub headers: HeaderMap, + pub body: Bytes, +} diff --git a/codex-rs/codex-client/src/retry.rs b/codex-rs/codex-client/src/retry.rs new file mode 100644 index 000000000..c7bdd34b1 --- /dev/null +++ b/codex-rs/codex-client/src/retry.rs @@ -0,0 +1,73 @@ +use crate::error::TransportError; +use crate::request::Request; +use rand::Rng; +use std::future::Future; +use std::time::Duration; +use tokio::time::sleep; + +#[derive(Debug, Clone)] +pub struct RetryPolicy { + pub max_attempts: u64, + pub base_delay: Duration, + pub retry_on: RetryOn, +} + +#[derive(Debug, Clone)] +pub struct RetryOn { + pub retry_429: bool, + pub retry_5xx: bool, + pub retry_transport: bool, +} + +impl RetryOn { + pub fn should_retry(&self, err: &TransportError, attempt: u64, max_attempts: u64) -> bool { + if attempt >= max_attempts { + return false; + } + match err { + TransportError::Http { status, .. } => { + (self.retry_429 && status.as_u16() == 429) + || (self.retry_5xx && status.is_server_error()) + } + TransportError::Timeout | TransportError::Network(_) => self.retry_transport, + _ => false, + } + } +} + +pub fn backoff(base: Duration, attempt: u64) -> Duration { + if attempt == 0 { + return base; + } + let exp = 2u64.saturating_pow(attempt as u32 - 1); + let millis = base.as_millis() as u64; + let raw = millis.saturating_mul(exp); + let jitter: f64 = rand::rng().random_range(0.9..1.1); + Duration::from_millis((raw as f64 * jitter) as u64) +} + +pub async fn run_with_retry( + policy: RetryPolicy, + mut make_req: impl FnMut() -> Request, + op: F, +) -> Result +where + F: Fn(Request, u64) -> Fut, + Fut: Future>, +{ + for attempt in 0..=policy.max_attempts { + let req = make_req(); + match op(req, attempt).await { + Ok(resp) => return Ok(resp), + Err(err) + if policy + .retry_on + .should_retry(&err, attempt, policy.max_attempts) => + { + sleep(backoff(policy.base_delay, attempt + 1)).await; + } + Err(err) => return Err(err), + } + } + Err(TransportError::RetryLimit) +} diff --git a/codex-rs/codex-client/src/sse.rs b/codex-rs/codex-client/src/sse.rs new file mode 100644 index 000000000..f3aba3a2c --- /dev/null +++ b/codex-rs/codex-client/src/sse.rs @@ -0,0 +1,48 @@ +use crate::error::StreamError; +use crate::transport::ByteStream; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use tokio::sync::mpsc; +use tokio::time::Duration; +use tokio::time::timeout; + +/// Minimal SSE helper that forwards raw `data:` frames as UTF-8 strings. +/// +/// Errors and idle timeouts are sent as `Err(StreamError)` before the task exits. +pub fn sse_stream( + stream: ByteStream, + idle_timeout: Duration, + tx: mpsc::Sender>, +) { + tokio::spawn(async move { + let mut stream = stream + .map(|res| res.map_err(|e| StreamError::Stream(e.to_string()))) + .eventsource(); + + loop { + match timeout(idle_timeout, stream.next()).await { + Ok(Some(Ok(ev))) => { + if tx.send(Ok(ev.data.clone())).await.is_err() { + return; + } + } + Ok(Some(Err(e))) => { + let _ = tx.send(Err(StreamError::Stream(e.to_string()))).await; + return; + } + Ok(None) => { + let _ = tx + .send(Err(StreamError::Stream( + "stream closed before completion".into(), + ))) + .await; + return; + } + Err(_) => { + let _ = tx.send(Err(StreamError::Timeout)).await; + return; + } + } + } + }); +} diff --git a/codex-rs/codex-client/src/telemetry.rs b/codex-rs/codex-client/src/telemetry.rs new file mode 100644 index 000000000..457d47f4f --- /dev/null +++ b/codex-rs/codex-client/src/telemetry.rs @@ -0,0 +1,14 @@ +use crate::error::TransportError; +use http::StatusCode; +use std::time::Duration; + +/// API specific telemetry. +pub trait RequestTelemetry: Send + Sync { + fn on_request( + &self, + attempt: u64, + status: Option, + error: Option<&TransportError>, + duration: Duration, + ); +} diff --git a/codex-rs/codex-client/src/transport.rs b/codex-rs/codex-client/src/transport.rs new file mode 100644 index 000000000..f64e72fb6 --- /dev/null +++ b/codex-rs/codex-client/src/transport.rs @@ -0,0 +1,107 @@ +use crate::error::TransportError; +use crate::request::Request; +use crate::request::Response; +use async_trait::async_trait; +use bytes::Bytes; +use futures::StreamExt; +use futures::stream::BoxStream; +use http::HeaderMap; +use http::Method; +use http::StatusCode; + +pub type ByteStream = BoxStream<'static, Result>; + +pub struct StreamResponse { + pub status: StatusCode, + pub headers: HeaderMap, + pub bytes: ByteStream, +} + +#[async_trait] +pub trait HttpTransport: Send + Sync { + async fn execute(&self, req: Request) -> Result; + async fn stream(&self, req: Request) -> Result; +} + +#[derive(Clone, Debug)] +pub struct ReqwestTransport { + client: reqwest::Client, +} + +impl ReqwestTransport { + pub fn new(client: reqwest::Client) -> Self { + Self { client } + } + + fn build(&self, req: Request) -> Result { + let mut builder = self + .client + .request( + Method::from_bytes(req.method.as_str().as_bytes()).unwrap_or(Method::GET), + &req.url, + ) + .headers(req.headers); + if let Some(timeout) = req.timeout { + builder = builder.timeout(timeout); + } + if let Some(body) = req.body { + builder = builder.json(&body); + } + Ok(builder) + } + + fn map_error(err: reqwest::Error) -> TransportError { + if err.is_timeout() { + TransportError::Timeout + } else { + TransportError::Network(err.to_string()) + } + } +} + +#[async_trait] +impl HttpTransport for ReqwestTransport { + async fn execute(&self, req: Request) -> Result { + let builder = self.build(req)?; + let resp = builder.send().await.map_err(Self::map_error)?; + let status = resp.status(); + let headers = resp.headers().clone(); + let bytes = resp.bytes().await.map_err(Self::map_error)?; + if !status.is_success() { + let body = String::from_utf8(bytes.to_vec()).ok(); + return Err(TransportError::Http { + status, + headers: Some(headers), + body, + }); + } + Ok(Response { + status, + headers, + body: bytes, + }) + } + + async fn stream(&self, req: Request) -> Result { + let builder = self.build(req)?; + let resp = builder.send().await.map_err(Self::map_error)?; + let status = resp.status(); + let headers = resp.headers().clone(); + if !status.is_success() { + let body = resp.text().await.ok(); + return Err(TransportError::Http { + status, + headers: Some(headers), + body, + }); + } + let stream = resp + .bytes_stream() + .map(|result| result.map_err(Self::map_error)); + Ok(StreamResponse { + status, + headers, + bytes: Box::pin(stream), + }) + } +} diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index d6db71973..705bab57d 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -18,12 +18,12 @@ askama = { workspace = true } async-channel = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } -bytes = { workspace = true } chrono = { workspace = true, features = ["serde"] } chardetng = { workspace = true } codex-app-server-protocol = { workspace = true } codex-apply-patch = { workspace = true } codex-async-utils = { workspace = true } +codex-api = { workspace = true } codex-execpolicy = { workspace = true } codex-file-search = { workspace = true } codex-git = { workspace = true } diff --git a/codex-rs/core/src/api_bridge.rs b/codex-rs/core/src/api_bridge.rs new file mode 100644 index 000000000..b9f802ae6 --- /dev/null +++ b/codex-rs/core/src/api_bridge.rs @@ -0,0 +1,154 @@ +use chrono::DateTime; +use chrono::Utc; +use codex_api::AuthProvider as ApiAuthProvider; +use codex_api::TransportError; +use codex_api::error::ApiError; +use codex_api::rate_limits::parse_rate_limit; +use http::HeaderMap; +use serde::Deserialize; + +use crate::auth::CodexAuth; +use crate::error::CodexErr; +use crate::error::RetryLimitReachedError; +use crate::error::UnexpectedResponseError; +use crate::error::UsageLimitReachedError; +use crate::model_provider_info::ModelProviderInfo; +use crate::token_data::PlanType; + +pub(crate) fn map_api_error(err: ApiError) -> CodexErr { + match err { + ApiError::ContextWindowExceeded => CodexErr::ContextWindowExceeded, + ApiError::QuotaExceeded => CodexErr::QuotaExceeded, + ApiError::UsageNotIncluded => CodexErr::UsageNotIncluded, + ApiError::Retryable { message, delay } => CodexErr::Stream(message, delay), + ApiError::Stream(msg) => CodexErr::Stream(msg, None), + ApiError::Api { status, message } => CodexErr::UnexpectedStatus(UnexpectedResponseError { + status, + body: message, + request_id: None, + }), + ApiError::Transport(transport) => match transport { + TransportError::Http { + status, + headers, + body, + } => { + if status == http::StatusCode::INTERNAL_SERVER_ERROR { + CodexErr::InternalServerError + } else if status == http::StatusCode::TOO_MANY_REQUESTS { + if let Some(body) = body + && let Ok(err) = serde_json::from_str::(&body) + { + if err.error.error_type.as_deref() == Some("usage_limit_reached") { + let rate_limits = headers.as_ref().and_then(parse_rate_limit); + let resets_at = err + .error + .resets_at + .and_then(|seconds| DateTime::::from_timestamp(seconds, 0)); + return CodexErr::UsageLimitReached(UsageLimitReachedError { + plan_type: err.error.plan_type, + resets_at, + rate_limits, + }); + } else if err.error.error_type.as_deref() == Some("usage_not_included") { + return CodexErr::UsageNotIncluded; + } + } + + CodexErr::RetryLimit(RetryLimitReachedError { + status, + request_id: extract_request_id(headers.as_ref()), + }) + } else { + CodexErr::UnexpectedStatus(UnexpectedResponseError { + status, + body: body.unwrap_or_default(), + request_id: extract_request_id(headers.as_ref()), + }) + } + } + TransportError::RetryLimit => CodexErr::RetryLimit(RetryLimitReachedError { + status: http::StatusCode::INTERNAL_SERVER_ERROR, + request_id: None, + }), + TransportError::Timeout => CodexErr::Timeout, + TransportError::Network(msg) | TransportError::Build(msg) => { + CodexErr::Stream(msg, None) + } + }, + ApiError::RateLimit(msg) => CodexErr::Stream(msg, None), + } +} + +fn extract_request_id(headers: Option<&HeaderMap>) -> Option { + headers.and_then(|map| { + ["cf-ray", "x-request-id", "x-oai-request-id"] + .iter() + .find_map(|name| { + map.get(*name) + .and_then(|v| v.to_str().ok()) + .map(str::to_string) + }) + }) +} + +pub(crate) async fn auth_provider_from_auth( + auth: Option, + provider: &ModelProviderInfo, +) -> crate::error::Result { + if let Some(api_key) = provider.api_key()? { + return Ok(CoreAuthProvider { + token: Some(api_key), + account_id: None, + }); + } + + if let Some(token) = provider.experimental_bearer_token.clone() { + return Ok(CoreAuthProvider { + token: Some(token), + account_id: None, + }); + } + + if let Some(auth) = auth { + let token = auth.get_token().await?; + Ok(CoreAuthProvider { + token: Some(token), + account_id: auth.get_account_id(), + }) + } else { + Ok(CoreAuthProvider { + token: None, + account_id: None, + }) + } +} + +#[derive(Debug, Deserialize)] +struct UsageErrorResponse { + error: UsageErrorBody, +} + +#[derive(Debug, Deserialize)] +struct UsageErrorBody { + #[serde(rename = "type")] + error_type: Option, + plan_type: Option, + resets_at: Option, +} + +#[derive(Clone, Default)] +pub(crate) struct CoreAuthProvider { + token: Option, + account_id: Option, +} + +impl ApiAuthProvider for CoreAuthProvider { + fn bearer_token(&self) -> Option { + self.token.clone() + } + + fn account_id(&self) -> Option { + self.account_id.clone() + } +} diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs deleted file mode 100644 index 785a4d4ce..000000000 --- a/codex-rs/core/src/chat_completions.rs +++ /dev/null @@ -1,981 +0,0 @@ -use std::time::Duration; - -use crate::ModelProviderInfo; -use crate::client_common::Prompt; -use crate::client_common::ResponseEvent; -use crate::client_common::ResponseStream; -use crate::default_client::CodexHttpClient; -use crate::error::CodexErr; -use crate::error::ConnectionFailedError; -use crate::error::ResponseStreamFailed; -use crate::error::Result; -use crate::error::RetryLimitReachedError; -use crate::error::UnexpectedResponseError; -use crate::model_family::ModelFamily; -use crate::tools::spec::create_tools_json_for_chat_completions_api; -use crate::util::backoff; -use bytes::Bytes; -use codex_otel::otel_event_manager::OtelEventManager; -use codex_protocol::models::ContentItem; -use codex_protocol::models::FunctionCallOutputContentItem; -use codex_protocol::models::ReasoningItemContent; -use codex_protocol::models::ResponseItem; -use codex_protocol::protocol::SessionSource; -use codex_protocol::protocol::SubAgentSource; -use eventsource_stream::Eventsource; -use futures::Stream; -use futures::StreamExt; -use futures::TryStreamExt; -use reqwest::StatusCode; -use serde_json::json; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; -use tokio::sync::mpsc; -use tokio::time::timeout; -use tracing::debug; -use tracing::trace; - -/// Implementation for the classic Chat Completions API. -pub(crate) async fn stream_chat_completions( - prompt: &Prompt, - model_family: &ModelFamily, - client: &CodexHttpClient, - provider: &ModelProviderInfo, - otel_event_manager: &OtelEventManager, - session_source: &SessionSource, -) -> Result { - if prompt.output_schema.is_some() { - return Err(CodexErr::UnsupportedOperation( - "output_schema is not supported for Chat Completions API".to_string(), - )); - } - - // Build messages array - let mut messages = Vec::::new(); - - let full_instructions = prompt.get_full_instructions(model_family); - messages.push(json!({"role": "system", "content": full_instructions})); - - let input = prompt.get_formatted_input(); - - // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user. - // - If the last emitted message is a user message, drop all reasoning. - // - Otherwise, for each Reasoning item after the last user message, attach it - // to the immediate previous assistant message (stop turns) or the immediate - // next assistant anchor (tool-call turns: function/local shell call, or assistant message). - let mut reasoning_by_anchor_index: std::collections::HashMap = - std::collections::HashMap::new(); - - // Determine the last role that would be emitted to Chat Completions. - let mut last_emitted_role: Option<&str> = None; - for item in &input { - match item { - ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()), - ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => { - last_emitted_role = Some("assistant") - } - ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"), - ResponseItem::Reasoning { .. } | ResponseItem::Other => {} - ResponseItem::CustomToolCall { .. } => {} - ResponseItem::CustomToolCallOutput { .. } => {} - ResponseItem::WebSearchCall { .. } => {} - ResponseItem::GhostSnapshot { .. } => {} - ResponseItem::CompactionSummary { .. } => {} - } - } - - // Find the last user message index in the input. - let mut last_user_index: Option = None; - for (idx, item) in input.iter().enumerate() { - if let ResponseItem::Message { role, .. } = item - && role == "user" - { - last_user_index = Some(idx); - } - } - - // Attach reasoning only if the conversation does not end with a user message. - if !matches!(last_emitted_role, Some("user")) { - for (idx, item) in input.iter().enumerate() { - // Only consider reasoning that appears after the last user message. - if let Some(u_idx) = last_user_index - && idx <= u_idx - { - continue; - } - - if let ResponseItem::Reasoning { - content: Some(items), - .. - } = item - { - let mut text = String::new(); - for entry in items { - match entry { - ReasoningItemContent::ReasoningText { text: segment } - | ReasoningItemContent::Text { text: segment } => text.push_str(segment), - } - } - if text.trim().is_empty() { - continue; - } - - // Prefer immediate previous assistant message (stop turns) - let mut attached = false; - if idx > 0 - && let ResponseItem::Message { role, .. } = &input[idx - 1] - && role == "assistant" - { - reasoning_by_anchor_index - .entry(idx - 1) - .and_modify(|v| v.push_str(&text)) - .or_insert(text.clone()); - attached = true; - } - - // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message) - if !attached && idx + 1 < input.len() { - match &input[idx + 1] { - ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => { - reasoning_by_anchor_index - .entry(idx + 1) - .and_modify(|v| v.push_str(&text)) - .or_insert(text.clone()); - } - ResponseItem::Message { role, .. } if role == "assistant" => { - reasoning_by_anchor_index - .entry(idx + 1) - .and_modify(|v| v.push_str(&text)) - .or_insert(text.clone()); - } - _ => {} - } - } - } - } - } - - // Track last assistant text we emitted to avoid duplicate assistant messages - // in the outbound Chat Completions payload (can happen if a final - // aggregated assistant message was recorded alongside an earlier partial). - let mut last_assistant_text: Option = None; - - for (idx, item) in input.iter().enumerate() { - match item { - ResponseItem::Message { role, content, .. } => { - // Build content either as a plain string (typical for assistant text) - // or as an array of content items when images are present (user/tool multimodal). - let mut text = String::new(); - let mut items: Vec = Vec::new(); - let mut saw_image = false; - - for c in content { - match c { - ContentItem::InputText { text: t } - | ContentItem::OutputText { text: t } => { - text.push_str(t); - items.push(json!({"type":"text","text": t})); - } - ContentItem::InputImage { image_url } => { - saw_image = true; - items.push(json!({"type":"image_url","image_url": {"url": image_url}})); - } - } - } - - // Skip exact-duplicate assistant messages. - if role == "assistant" { - if let Some(prev) = &last_assistant_text - && prev == &text - { - continue; - } - last_assistant_text = Some(text.clone()); - } - - // For assistant messages, always send a plain string for compatibility. - // For user messages, if an image is present, send an array of content items. - let content_value = if role == "assistant" { - json!(text) - } else if saw_image { - json!(items) - } else { - json!(text) - }; - - let mut msg = json!({"role": role, "content": content_value}); - if role == "assistant" - && let Some(reasoning) = reasoning_by_anchor_index.get(&idx) - && let Some(obj) = msg.as_object_mut() - { - obj.insert("reasoning".to_string(), json!(reasoning)); - } - messages.push(msg); - } - ResponseItem::FunctionCall { - name, - arguments, - call_id, - .. - } => { - let mut msg = json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": call_id, - "type": "function", - "function": { - "name": name, - "arguments": arguments, - } - }] - }); - if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) - && let Some(obj) = msg.as_object_mut() - { - obj.insert("reasoning".to_string(), json!(reasoning)); - } - messages.push(msg); - } - ResponseItem::LocalShellCall { - id, - call_id: _, - status, - action, - } => { - // Confirm with API team. - let mut msg = json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": id.clone().unwrap_or_else(|| "".to_string()), - "type": "local_shell_call", - "status": status, - "action": action, - }] - }); - if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) - && let Some(obj) = msg.as_object_mut() - { - obj.insert("reasoning".to_string(), json!(reasoning)); - } - messages.push(msg); - } - ResponseItem::FunctionCallOutput { call_id, output } => { - // Prefer structured content items when available (e.g., images) - // otherwise fall back to the legacy plain-string content. - let content_value = if let Some(items) = &output.content_items { - let mapped: Vec = items - .iter() - .map(|it| match it { - FunctionCallOutputContentItem::InputText { text } => { - json!({"type":"text","text": text}) - } - FunctionCallOutputContentItem::InputImage { image_url } => { - json!({"type":"image_url","image_url": {"url": image_url}}) - } - }) - .collect(); - json!(mapped) - } else { - json!(output.content) - }; - - messages.push(json!({ - "role": "tool", - "tool_call_id": call_id, - "content": content_value, - })); - } - ResponseItem::CustomToolCall { - id, - call_id: _, - name, - input, - status: _, - } => { - messages.push(json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": id, - "type": "custom", - "custom": { - "name": name, - "input": input, - } - }] - })); - } - ResponseItem::CustomToolCallOutput { call_id, output } => { - messages.push(json!({ - "role": "tool", - "tool_call_id": call_id, - "content": output, - })); - } - ResponseItem::GhostSnapshot { .. } => { - // Ghost snapshots annotate history but are not sent to the model. - continue; - } - ResponseItem::Reasoning { .. } - | ResponseItem::WebSearchCall { .. } - | ResponseItem::Other - | ResponseItem::CompactionSummary { .. } => { - // Omit these items from the conversation history. - continue; - } - } - } - - let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?; - let payload = json!({ - "model": model_family.slug, - "messages": messages, - "stream": true, - "tools": tools_json, - }); - - debug!( - "POST to {}: {}", - provider.get_full_url(&None), - payload.to_string() - ); - - let mut attempt = 0; - let max_retries = provider.request_max_retries(); - loop { - attempt += 1; - - let mut req_builder = provider.create_request_builder(client, &None).await?; - - // Include subagent header only for subagent sessions. - if let SessionSource::SubAgent(sub) = session_source.clone() { - let subagent = if let SubAgentSource::Other(label) = sub { - label - } else { - serde_json::to_value(&sub) - .ok() - .and_then(|v| v.as_str().map(std::string::ToString::to_string)) - .unwrap_or_else(|| "other".to_string()) - }; - req_builder = req_builder.header("x-openai-subagent", subagent); - } - - let res = otel_event_manager - .log_request(attempt, || { - req_builder - .header(reqwest::header::ACCEPT, "text/event-stream") - .json(&payload) - .send() - }) - .await; - - match res { - Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(1600); - let stream = resp.bytes_stream().map_err(|e| { - CodexErr::ResponseStreamFailed(ResponseStreamFailed { - source: e, - request_id: None, - }) - }); - tokio::spawn(process_chat_sse( - stream, - tx_event, - provider.stream_idle_timeout(), - otel_event_manager.clone(), - )); - return Ok(ResponseStream { rx_event }); - } - Ok(res) => { - let status = res.status(); - if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) { - let body = (res.text().await).unwrap_or_default(); - return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError { - status, - body, - request_id: None, - })); - } - - if attempt > max_retries { - return Err(CodexErr::RetryLimit(RetryLimitReachedError { - status, - request_id: None, - })); - } - - let retry_after_secs = res - .headers() - .get(reqwest::header::RETRY_AFTER) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()); - - let delay = retry_after_secs - .map(|s| Duration::from_millis(s * 1_000)) - .unwrap_or_else(|| backoff(attempt)); - tokio::time::sleep(delay).await; - } - Err(e) => { - if attempt > max_retries { - return Err(CodexErr::ConnectionFailed(ConnectionFailedError { - source: e, - })); - } - let delay = backoff(attempt); - tokio::time::sleep(delay).await; - } - } - } -} - -async fn append_assistant_text( - tx_event: &mpsc::Sender>, - assistant_item: &mut Option, - text: String, -) { - if assistant_item.is_none() { - let item = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![], - }; - *assistant_item = Some(item.clone()); - let _ = tx_event - .send(Ok(ResponseEvent::OutputItemAdded(item))) - .await; - } - - if let Some(ResponseItem::Message { content, .. }) = assistant_item { - content.push(ContentItem::OutputText { text: text.clone() }); - let _ = tx_event - .send(Ok(ResponseEvent::OutputTextDelta(text.clone()))) - .await; - } -} - -async fn append_reasoning_text( - tx_event: &mpsc::Sender>, - reasoning_item: &mut Option, - text: String, -) { - if reasoning_item.is_none() { - let item = ResponseItem::Reasoning { - id: String::new(), - summary: Vec::new(), - content: Some(vec![]), - encrypted_content: None, - }; - *reasoning_item = Some(item.clone()); - let _ = tx_event - .send(Ok(ResponseEvent::OutputItemAdded(item))) - .await; - } - - if let Some(ResponseItem::Reasoning { - content: Some(content), - .. - }) = reasoning_item - { - let content_index = content.len() as i64; - content.push(ReasoningItemContent::ReasoningText { text: text.clone() }); - - let _ = tx_event - .send(Ok(ResponseEvent::ReasoningContentDelta { - delta: text.clone(), - content_index, - })) - .await; - } -} -/// Lightweight SSE processor for the Chat Completions streaming format. The -/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest -/// of the pipeline can stay agnostic of the underlying wire format. -async fn process_chat_sse( - stream: S, - tx_event: mpsc::Sender>, - idle_timeout: Duration, - otel_event_manager: OtelEventManager, -) where - S: Stream> + Unpin, -{ - let mut stream = stream.eventsource(); - - // State to accumulate a function call across streaming chunks. - // OpenAI may split the `arguments` string over multiple `delta` events - // until the chunk whose `finish_reason` is `tool_calls` is emitted. We - // keep collecting the pieces here and forward a single - // `ResponseItem::FunctionCall` once the call is complete. - #[derive(Default)] - struct FunctionCallState { - name: Option, - arguments: String, - call_id: Option, - active: bool, - } - - let mut fn_call_state = FunctionCallState::default(); - let mut assistant_item: Option = None; - let mut reasoning_item: Option = None; - - loop { - let start = std::time::Instant::now(); - let response = timeout(idle_timeout, stream.next()).await; - let duration = start.elapsed(); - otel_event_manager.log_sse_event(&response, duration); - - let sse = match response { - Ok(Some(Ok(ev))) => ev, - Ok(Some(Err(e))) => { - let _ = tx_event - .send(Err(CodexErr::Stream(e.to_string(), None))) - .await; - return; - } - Ok(None) => { - // Stream closed gracefully – emit Completed with dummy id. - let _ = tx_event - .send(Ok(ResponseEvent::Completed { - response_id: String::new(), - token_usage: None, - })) - .await; - return; - } - Err(_) => { - let _ = tx_event - .send(Err(CodexErr::Stream( - "idle timeout waiting for SSE".into(), - None, - ))) - .await; - return; - } - }; - - // OpenAI Chat streaming sends a literal string "[DONE]" when finished. - if sse.data.trim() == "[DONE]" { - // Emit any finalized items before closing so downstream consumers receive - // terminal events for both assistant content and raw reasoning. - if let Some(item) = assistant_item { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - - if let Some(item) = reasoning_item { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - - let _ = tx_event - .send(Ok(ResponseEvent::Completed { - response_id: String::new(), - token_usage: None, - })) - .await; - return; - } - - // Parse JSON chunk - let chunk: serde_json::Value = match serde_json::from_str(&sse.data) { - Ok(v) => v, - Err(_) => continue, - }; - trace!("chat_completions received SSE chunk: {chunk:?}"); - - let choice_opt = chunk.get("choices").and_then(|c| c.get(0)); - - if let Some(choice) = choice_opt { - // Handle assistant content tokens as streaming deltas. - if let Some(content) = choice - .get("delta") - .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()) - && !content.is_empty() - { - append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await; - } - - // Forward any reasoning/thinking deltas if present. - // Some providers stream `reasoning` as a plain string while others - // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`). - if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) { - let mut maybe_text = reasoning_val - .as_str() - .map(str::to_string) - .filter(|s| !s.is_empty()); - - if maybe_text.is_none() && reasoning_val.is_object() { - if let Some(s) = reasoning_val - .get("text") - .and_then(|t| t.as_str()) - .filter(|s| !s.is_empty()) - { - maybe_text = Some(s.to_string()); - } else if let Some(s) = reasoning_val - .get("content") - .and_then(|t| t.as_str()) - .filter(|s| !s.is_empty()) - { - maybe_text = Some(s.to_string()); - } - } - - if let Some(reasoning) = maybe_text { - // Accumulate so we can emit a terminal Reasoning item at the end. - append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await; - } - } - - // Some providers only include reasoning on the final message object. - if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning")) - { - // Accept either a plain string or an object with { text | content } - if let Some(s) = message_reasoning.as_str() { - if !s.is_empty() { - append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await; - } - } else if let Some(obj) = message_reasoning.as_object() - && let Some(s) = obj - .get("text") - .and_then(|v| v.as_str()) - .or_else(|| obj.get("content").and_then(|v| v.as_str())) - && !s.is_empty() - { - append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await; - } - } - - // Handle streaming function / tool calls. - if let Some(tool_calls) = choice - .get("delta") - .and_then(|d| d.get("tool_calls")) - .and_then(|tc| tc.as_array()) - && let Some(tool_call) = tool_calls.first() - { - // Mark that we have an active function call in progress. - fn_call_state.active = true; - - // Extract call_id if present. - if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) { - fn_call_state.call_id.get_or_insert_with(|| id.to_string()); - } - - // Extract function details if present. - if let Some(function) = tool_call.get("function") { - if let Some(name) = function.get("name").and_then(|n| n.as_str()) { - fn_call_state.name.get_or_insert_with(|| name.to_string()); - } - - if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str()) - { - fn_call_state.arguments.push_str(args_fragment); - } - } - } - - // Emit end-of-turn when finish_reason signals completion. - if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) - && !finish_reason.is_empty() - { - match finish_reason { - "tool_calls" if fn_call_state.active => { - // First, flush the terminal raw reasoning so UIs can finalize - // the reasoning stream before any exec/tool events begin. - if let Some(item) = reasoning_item.take() { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - - // Then emit the FunctionCall response item. - let item = ResponseItem::FunctionCall { - id: None, - name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()), - arguments: fn_call_state.arguments.clone(), - call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new), - }; - - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - "stop" => { - // Regular turn without tool-call. Emit the final assistant message - // as a single OutputItemDone so non-delta consumers see the result. - if let Some(item) = assistant_item.take() { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - // Also emit a terminal Reasoning item so UIs can finalize raw reasoning. - if let Some(item) = reasoning_item.take() { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - } - _ => {} - } - - // Emit Completed regardless of reason so the agent can advance. - let _ = tx_event - .send(Ok(ResponseEvent::Completed { - response_id: String::new(), - token_usage: None, - })) - .await; - - // Prepare for potential next turn (should not happen in same stream). - // fn_call_state = FunctionCallState::default(); - - return; // End processing for this SSE stream. - } - } - } -} - -/// Optional client-side aggregation helper -/// -/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from -/// [`process_chat_sse`] into a *running* assistant message, **suppressing the -/// per-token deltas**. The stream stays silent while the model is thinking -/// and only emits two events per turn: -/// -/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message -/// (fully concatenated). -/// 2. The original `ResponseEvent::Completed` right after it. -/// -/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers. -/// -/// The adapter is intentionally *lossless*: callers who do **not** opt in via -/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified -/// events. -#[derive(Copy, Clone, Eq, PartialEq)] -enum AggregateMode { - AggregatedOnly, - Streaming, -} -pub(crate) struct AggregatedChatStream { - inner: S, - cumulative: String, - cumulative_reasoning: String, - pending: std::collections::VecDeque, - mode: AggregateMode, -} - -impl Stream for AggregatedChatStream -where - S: Stream> + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - // First, flush any buffered events from the previous call. - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - - loop { - match Pin::new(&mut this.inner).poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { - // If this is an incremental assistant message chunk, accumulate but - // do NOT emit yet. Forward any other item (e.g. FunctionCall) right - // away so downstream consumers see it. - - let is_assistant_message = matches!( - &item, - codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant" - ); - - if is_assistant_message { - match this.mode { - AggregateMode::AggregatedOnly => { - // Only use the final assistant message if we have not - // seen any deltas; otherwise, deltas already built the - // cumulative text and this would duplicate it. - if this.cumulative.is_empty() - && let codex_protocol::models::ResponseItem::Message { - content, - .. - } = &item - && let Some(text) = content.iter().find_map(|c| match c { - codex_protocol::models::ContentItem::OutputText { - text, - } => Some(text), - _ => None, - }) - { - this.cumulative.push_str(text); - } - // Swallow assistant message here; emit on Completed. - continue; - } - AggregateMode::Streaming => { - // In streaming mode, if we have not seen any deltas, forward - // the final assistant message directly. If deltas were seen, - // suppress the final message to avoid duplication. - if this.cumulative.is_empty() { - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( - item, - )))); - } else { - continue; - } - } - } - } - - // Not an assistant message – forward immediately. - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); - } - Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); - } - Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - }))) => { - // Build any aggregated items in the correct order: Reasoning first, then Message. - let mut emitted_any = false; - - if !this.cumulative_reasoning.is_empty() - && matches!(this.mode, AggregateMode::AggregatedOnly) - { - let aggregated_reasoning = - codex_protocol::models::ResponseItem::Reasoning { - id: String::new(), - summary: Vec::new(), - content: Some(vec![ - codex_protocol::models::ReasoningItemContent::ReasoningText { - text: std::mem::take(&mut this.cumulative_reasoning), - }, - ]), - encrypted_content: None, - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning)); - emitted_any = true; - } - - // Always emit the final aggregated assistant message when any - // content deltas have been observed. In AggregatedOnly mode this - // is the sole assistant output; in Streaming mode this finalizes - // the streamed deltas into a terminal OutputItemDone so callers - // can persist/render the message once per turn. - if !this.cumulative.is_empty() { - let aggregated_message = codex_protocol::models::ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![codex_protocol::models::ContentItem::OutputText { - text: std::mem::take(&mut this.cumulative), - }], - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_message)); - emitted_any = true; - } - - // Always emit Completed last when anything was aggregated. - if emitted_any { - this.pending.push_back(ResponseEvent::Completed { - response_id: response_id.clone(), - token_usage: token_usage.clone(), - }); - // Return the first pending event now. - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - } - - // Nothing aggregated – forward Completed directly. - return Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - }))); - } - Poll::Ready(Some(Ok(ResponseEvent::Created))) => { - // These events are exclusive to the Responses API and - // will never appear in a Chat Completions stream. - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => { - // Always accumulate deltas so we can emit a final OutputItemDone at Completed. - this.cumulative.push_str(&delta); - if matches!(this.mode, AggregateMode::Streaming) { - // In streaming mode, also forward the delta immediately. - return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))); - } else { - continue; - } - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta { - delta, - content_index, - }))) => { - // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed. - this.cumulative_reasoning.push_str(&delta); - if matches!(this.mode, AggregateMode::Streaming) { - // In streaming mode, also forward the delta immediately. - return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta { - delta, - content_index, - }))); - } else { - continue; - } - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => { - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => { - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))); - } - } - } - } -} - -/// Extension trait that activates aggregation on any stream of [`ResponseEvent`]. -pub(crate) trait AggregateStreamExt: Stream> + Sized { - /// Returns a new stream that emits **only** the final assistant message - /// per turn instead of every incremental delta. The produced - /// `ResponseEvent` sequence for a typical text turn looks like: - /// - /// ```ignore - /// OutputItemDone() - /// Completed - /// ``` - /// - /// No other `OutputItemDone` events will be seen by the caller. - /// - /// Usage: - /// - /// ```ignore - /// let agg_stream = client.stream(&prompt).await?.aggregate(); - /// while let Some(event) = agg_stream.next().await { - /// // event now contains cumulative text - /// } - /// ``` - fn aggregate(self) -> AggregatedChatStream { - AggregatedChatStream::new(self, AggregateMode::AggregatedOnly) - } -} - -impl AggregateStreamExt for T where T: Stream> + Sized {} - -impl AggregatedChatStream { - fn new(inner: S, mode: AggregateMode) -> Self { - AggregatedChatStream { - inner, - cumulative: String::new(), - cumulative_reasoning: String::new(), - pending: std::collections::VecDeque::new(), - mode, - } - } - - pub(crate) fn streaming_mode(inner: S) -> Self { - Self::new(inner, AggregateMode::Streaming) - } -} diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 68bd30f4e..82839522c 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -1,12 +1,22 @@ -use std::io::BufRead; -use std::path::Path; use std::sync::Arc; -use std::sync::OnceLock; -use std::time::Duration; -use bytes::Bytes; -use chrono::DateTime; -use chrono::Utc; +use crate::api_bridge::auth_provider_from_auth; +use crate::api_bridge::map_api_error; +use codex_api::AggregateStreamExt; +use codex_api::ChatClient as ApiChatClient; +use codex_api::CompactClient as ApiCompactClient; +use codex_api::CompactionInput as ApiCompactionInput; +use codex_api::Prompt as ApiPrompt; +use codex_api::RequestTelemetry; +use codex_api::ReqwestTransport; +use codex_api::ResponseStream as ApiResponseStream; +use codex_api::ResponsesClient as ApiResponsesClient; +use codex_api::ResponsesOptions as ApiResponsesOptions; +use codex_api::SseTelemetry; +use codex_api::TransportError; +use codex_api::common::Reasoning; +use codex_api::create_text_param_for_request; +use codex_api::error::ApiError; use codex_app_server_protocol::AuthMode; use codex_otel::otel_event_manager::OtelEventManager; use codex_protocol::ConversationId; @@ -14,90 +24,40 @@ use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::SessionSource; -use eventsource_stream::Eventsource; -use futures::prelude::*; -use regex_lite::Regex; +use eventsource_stream::Event; +use eventsource_stream::EventStreamError; +use futures::StreamExt; +use http::HeaderMap as ApiHeaderMap; +use http::HeaderValue; +use http::StatusCode as HttpStatusCode; use reqwest::StatusCode; -use reqwest::header::HeaderMap; -use serde::Deserialize; -use serde::Serialize; use serde_json::Value; +use std::time::Duration; use tokio::sync::mpsc; -use tokio::time::timeout; -use tokio_util::io::ReaderStream; -use tracing::debug; -use tracing::enabled; -use tracing::trace; use tracing::warn; use crate::AuthManager; -use crate::auth::CodexAuth; use crate::auth::RefreshTokenError; -use crate::chat_completions::AggregateStreamExt; -use crate::chat_completions::stream_chat_completions; use crate::client_common::Prompt; -use crate::client_common::Reasoning; use crate::client_common::ResponseEvent; use crate::client_common::ResponseStream; -use crate::client_common::ResponsesApiRequest; -use crate::client_common::create_text_param_for_request; use crate::config::Config; -use crate::default_client::CodexHttpClient; -use crate::default_client::create_client; +use crate::default_client::build_reqwest_client; use crate::error::CodexErr; -use crate::error::ConnectionFailedError; -use crate::error::ResponseStreamFailed; use crate::error::Result; -use crate::error::RetryLimitReachedError; -use crate::error::UnexpectedResponseError; -use crate::error::UsageLimitReachedError; use crate::flags::CODEX_RS_SSE_FIXTURE; use crate::model_family::ModelFamily; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; use crate::openai_model_info::get_model_info; -use crate::protocol::CreditsSnapshot; -use crate::protocol::RateLimitSnapshot; -use crate::protocol::RateLimitWindow; -use crate::protocol::TokenUsage; -use crate::token_data::PlanType; +use crate::tools::spec::create_tools_json_for_chat_completions_api; use crate::tools::spec::create_tools_json_for_responses_api; -use crate::util::backoff; - -#[derive(Debug, Deserialize)] -struct ErrorResponse { - error: Error, -} - -#[derive(Debug, Deserialize)] -struct Error { - r#type: Option, - code: Option, - message: Option, - - // Optional fields available on "usage_limit_reached" and "usage_not_included" errors - plan_type: Option, - resets_at: Option, -} - -#[derive(Debug, Serialize)] -struct CompactHistoryRequest<'a> { - model: &'a str, - input: &'a [ResponseItem], - instructions: &'a str, -} - -#[derive(Debug, Deserialize)] -struct CompactHistoryResponse { - output: Vec, -} #[derive(Debug, Clone)] pub struct ModelClient { config: Arc, auth_manager: Option>, otel_event_manager: OtelEventManager, - client: CodexHttpClient, provider: ModelProviderInfo, conversation_id: ConversationId, effort: Option, @@ -117,13 +77,10 @@ impl ModelClient { conversation_id: ConversationId, session_source: SessionSource, ) -> Self { - let client = create_client(); - Self { config, auth_manager, otel_event_manager, - client, provider, conversation_id, effort, @@ -154,65 +111,102 @@ impl ModelClient { &self.provider } + /// Streams a single model turn using either the Responses or Chat + /// Completions wire API, depending on the configured provider. + /// + /// For Chat providers, the underlying stream is optionally aggregated + /// based on the `show_raw_agent_reasoning` flag in the config. pub async fn stream(&self, prompt: &Prompt) -> Result { match self.provider.wire_api { - WireApi::Responses => self.stream_responses(prompt).await, + WireApi::Responses => self.stream_responses_api(prompt).await, WireApi::Chat => { - // Create the raw streaming connection first. - let response_stream = stream_chat_completions( - prompt, - &self.config.model_family, - &self.client, - &self.provider, - &self.otel_event_manager, - &self.session_source, - ) - .await?; + let api_stream = self.stream_chat_completions(prompt).await?; - // Wrap it with the aggregation adapter so callers see *only* - // the final assistant message per turn (matching the - // behaviour of the Responses API). - let mut aggregated = if self.config.show_raw_agent_reasoning { - crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream) + if self.config.show_raw_agent_reasoning { + Ok(map_response_stream( + api_stream.streaming_mode(), + self.otel_event_manager.clone(), + )) } else { - response_stream.aggregate() - }; - - // Bridge the aggregated stream back into a standard - // `ResponseStream` by forwarding events through a channel. - let (tx, rx) = mpsc::channel::>(16); - - tokio::spawn(async move { - use futures::StreamExt; - while let Some(ev) = aggregated.next().await { - // Exit early if receiver hung up. - if tx.send(ev).await.is_err() { - break; - } - } - }); - - Ok(ResponseStream { rx_event: rx }) + Ok(map_response_stream( + api_stream.aggregate(), + self.otel_event_manager.clone(), + )) + } } } } - /// Implementation for the OpenAI *Responses* experimental API. - async fn stream_responses(&self, prompt: &Prompt) -> Result { - if let Some(path) = &*CODEX_RS_SSE_FIXTURE { - // short circuit for tests - warn!(path, "Streaming from fixture"); - return stream_from_fixture( - path, - self.provider.clone(), - self.otel_event_manager.clone(), - ) - .await; + /// Streams a turn via the OpenAI Chat Completions API. + /// + /// This path is only used when the provider is configured with + /// `WireApi::Chat`; it does not support `output_schema` today. + async fn stream_chat_completions(&self, prompt: &Prompt) -> Result { + if prompt.output_schema.is_some() { + return Err(CodexErr::UnsupportedOperation( + "output_schema is not supported for Chat Completions API".to_string(), + )); } let auth_manager = self.auth_manager.clone(); + let instructions = prompt + .get_full_instructions(&self.config.model_family) + .into_owned(); + let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?; + let api_prompt = build_api_prompt(prompt, instructions, tools_json); + let conversation_id = self.conversation_id.to_string(); + let session_source = self.session_source.clone(); - let full_instructions = prompt.get_full_instructions(&self.config.model_family); + let mut refreshed = false; + loop { + let auth = auth_manager.as_ref().and_then(|m| m.auth()); + let api_provider = self + .provider + .to_api_provider(auth.as_ref().map(|a| a.mode))?; + let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?; + let transport = ReqwestTransport::new(build_reqwest_client()); + let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry(); + let client = ApiChatClient::new(transport, api_provider, api_auth) + .with_telemetry(Some(request_telemetry), Some(sse_telemetry)); + + let stream_result = client + .stream_prompt( + &self.config.model, + &api_prompt, + Some(conversation_id.clone()), + Some(session_source.clone()), + ) + .await; + + match stream_result { + Ok(stream) => return Ok(stream), + Err(ApiError::Transport(TransportError::Http { status, .. })) + if status == StatusCode::UNAUTHORIZED => + { + handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?; + continue; + } + Err(err) => return Err(map_api_error(err)), + } + } + } + + /// Streams a turn via the OpenAI Responses API. + /// + /// Handles SSE fixtures, reasoning summaries, verbosity, and the + /// `text` controls used for output schemas. + async fn stream_responses_api(&self, prompt: &Prompt) -> Result { + if let Some(path) = &*CODEX_RS_SSE_FIXTURE { + warn!(path, "Streaming from fixture"); + let stream = codex_api::stream_from_fixture(path, self.provider.stream_idle_timeout()) + .map_err(map_api_error)?; + return Ok(map_response_stream(stream, self.otel_event_manager.clone())); + } + + let auth_manager = self.auth_manager.clone(); + let instructions = prompt + .get_full_instructions(&self.config.model_family) + .into_owned(); let tools_json: Vec = create_tools_json_for_responses_api(&prompt.tools)?; let reasoning = if self.config.model_family.supports_reasoning_summaries { @@ -232,8 +226,6 @@ impl ModelClient { vec![] }; - let input_with_instructions = prompt.get_formatted_input(); - let verbosity = if self.config.model_family.support_verbosity { self.config .model_verbosity @@ -248,241 +240,49 @@ impl ModelClient { None }; - // Only include `text.verbosity` for GPT-5 family models let text = create_text_param_for_request(verbosity, &prompt.output_schema); + let api_prompt = build_api_prompt(prompt, instructions.clone(), tools_json); + let conversation_id = self.conversation_id.to_string(); + let session_source = self.session_source.clone(); - // In general, we want to explicitly send `store: false` when using the Responses API, - // but in practice, the Azure Responses API rejects `store: false`: - // - // - If store = false and id is sent an error is thrown that ID is not found - // - If store = false and id is not sent an error is thrown that ID is required - // - // For Azure, we send `store: true` and preserve reasoning item IDs. - let azure_workaround = self.provider.is_azure_responses_endpoint(); + let mut refreshed = false; + loop { + let auth = auth_manager.as_ref().and_then(|m| m.auth()); + let api_provider = self + .provider + .to_api_provider(auth.as_ref().map(|a| a.mode))?; + let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?; + let transport = ReqwestTransport::new(build_reqwest_client()); + let (request_telemetry, sse_telemetry) = self.build_streaming_telemetry(); + let client = ApiResponsesClient::new(transport, api_provider, api_auth) + .with_telemetry(Some(request_telemetry), Some(sse_telemetry)); - let payload = ResponsesApiRequest { - model: &self.config.model, - instructions: &full_instructions, - input: &input_with_instructions, - tools: &tools_json, - tool_choice: "auto", - parallel_tool_calls: prompt.parallel_tool_calls, - reasoning, - store: azure_workaround, - stream: true, - include, - prompt_cache_key: Some(self.conversation_id.to_string()), - text, - }; - - let mut payload_json = serde_json::to_value(&payload)?; - if azure_workaround { - attach_item_ids(&mut payload_json, &input_with_instructions); - } - - let max_attempts = self.provider.request_max_retries(); - for attempt in 0..=max_attempts { - match self - .attempt_stream_responses(attempt, &payload_json, &auth_manager) - .await - { - Ok(stream) => { - return Ok(stream); - } - Err(StreamAttemptError::Fatal(e)) => { - return Err(e); - } - Err(retryable_attempt_error) => { - if attempt == max_attempts { - return Err(retryable_attempt_error.into_error()); - } - - tokio::time::sleep(retryable_attempt_error.delay(attempt)).await; - } - } - } - - unreachable!("stream_responses_attempt should always return"); - } - - /// Single attempt to start a streaming Responses API call. - async fn attempt_stream_responses( - &self, - attempt: u64, - payload_json: &Value, - auth_manager: &Option>, - ) -> std::result::Result { - // Always fetch the latest auth in case a prior attempt refreshed the token. - let auth = auth_manager.as_ref().and_then(|m| m.auth()); - - trace!( - "POST to {}: {}", - self.provider.get_full_url(&auth), - payload_json.to_string() - ); - - let mut req_builder = self - .provider - .create_request_builder(&self.client, &auth) - .await - .map_err(StreamAttemptError::Fatal)?; - - // Include subagent header only for subagent sessions. - if let SessionSource::SubAgent(sub) = &self.session_source { - let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub { - label.clone() - } else { - serde_json::to_value(sub) - .ok() - .and_then(|v| v.as_str().map(std::string::ToString::to_string)) - .unwrap_or_else(|| "other".to_string()) + let options = ApiResponsesOptions { + reasoning: reasoning.clone(), + include: include.clone(), + prompt_cache_key: Some(conversation_id.clone()), + text: text.clone(), + store_override: None, + conversation_id: Some(conversation_id.clone()), + session_source: Some(session_source.clone()), }; - req_builder = req_builder.header("x-openai-subagent", subagent); - } - req_builder = req_builder - // Send session_id for compatibility. - .header("conversation_id", self.conversation_id.to_string()) - .header("session_id", self.conversation_id.to_string()) - .header(reqwest::header::ACCEPT, "text/event-stream") - .json(payload_json); + let stream_result = client + .stream_prompt(&self.config.model, &api_prompt, options) + .await; - if let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - && let Some(account_id) = auth.get_account_id() - { - req_builder = req_builder.header("chatgpt-account-id", account_id); - } - - let res = self - .otel_event_manager - .log_request(attempt, || req_builder.send()) - .await; - - let mut request_id = None; - if let Ok(resp) = &res { - request_id = resp - .headers() - .get("cf-ray") - .map(|v| v.to_str().unwrap_or_default().to_string()); - } - - match res { - Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(1600); - - if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers()) - && tx_event - .send(Ok(ResponseEvent::RateLimits(snapshot))) - .await - .is_err() - { - debug!("receiver dropped rate limit snapshot event"); + match stream_result { + Ok(stream) => { + return Ok(map_response_stream(stream, self.otel_event_manager.clone())); } - - // spawn task to process SSE - let stream = resp.bytes_stream().map_err(move |e| { - CodexErr::ResponseStreamFailed(ResponseStreamFailed { - source: e, - request_id: request_id.clone(), - }) - }); - tokio::spawn(process_sse( - stream, - tx_event, - self.provider.stream_idle_timeout(), - self.otel_event_manager.clone(), - )); - - Ok(ResponseStream { rx_event }) + Err(ApiError::Transport(TransportError::Http { status, .. })) + if status == StatusCode::UNAUTHORIZED => + { + handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?; + continue; + } + Err(err) => return Err(map_api_error(err)), } - Ok(res) => { - let status = res.status(); - - // Pull out Retry‑After header if present. - let retry_after_secs = res - .headers() - .get(reqwest::header::RETRY_AFTER) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()); - let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000)); - - if status == StatusCode::UNAUTHORIZED - && let Some(manager) = auth_manager.as_ref() - && let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - && let Err(err) = manager.refresh_token().await - { - let stream_error = match err { - RefreshTokenError::Permanent(failed) => { - StreamAttemptError::Fatal(CodexErr::RefreshTokenFailed(failed)) - } - RefreshTokenError::Transient(other) => { - StreamAttemptError::RetryableTransportError(CodexErr::Io(other)) - } - }; - return Err(stream_error); - } - - // The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx - // errors. When we bubble early with only the HTTP status the caller sees an opaque - // "unexpected status 400 Bad Request" which makes debugging nearly impossible. - // Instead, read (and include) the response text so higher layers and users see the - // exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is - // small and this branch only runs on error paths so the extra allocation is - // negligible. - if !(status == StatusCode::TOO_MANY_REQUESTS - || status == StatusCode::UNAUTHORIZED - || status.is_server_error()) - { - // Surface the error body to callers. Use `unwrap_or_default` per Clippy. - let body = res.text().await.unwrap_or_default(); - return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus( - UnexpectedResponseError { - status, - body, - request_id: None, - }, - ))); - } - - if status == StatusCode::TOO_MANY_REQUESTS { - let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers()); - let body = res.json::().await.ok(); - if let Some(ErrorResponse { error }) = body { - if error.r#type.as_deref() == Some("usage_limit_reached") { - // Prefer the plan_type provided in the error message if present - // because it's more up to date than the one encoded in the auth - // token. - let plan_type = error - .plan_type - .or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type)); - let resets_at = error - .resets_at - .and_then(|seconds| DateTime::::from_timestamp(seconds, 0)); - let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError { - plan_type, - resets_at, - rate_limits: rate_limit_snapshot, - }); - return Err(StreamAttemptError::Fatal(codex_err)); - } else if error.r#type.as_deref() == Some("usage_not_included") { - return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded)); - } else if is_quota_exceeded_error(&error) { - return Err(StreamAttemptError::Fatal(CodexErr::QuotaExceeded)); - } - } - } - - Err(StreamAttemptError::RetryableHttpError { - status, - retry_after, - request_id, - }) - } - Err(e) => Err(StreamAttemptError::RetryableTransportError( - CodexErr::ConnectionFailed(ConnectionFailedError { source: e }), - )), } } @@ -522,16 +322,35 @@ impl ModelClient { self.auth_manager.clone() } + /// Compacts the current conversation history using the Compact endpoint. + /// + /// This is a unary call (no streaming) that returns a new list of + /// `ResponseItem`s representing the compacted transcript. pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result> { if prompt.input.is_empty() { return Ok(Vec::new()); } let auth_manager = self.auth_manager.clone(); let auth = auth_manager.as_ref().and_then(|m| m.auth()); - let mut req_builder = self + let api_provider = self .provider - .create_compact_request_builder(&self.client, &auth) - .await?; + .to_api_provider(auth.as_ref().map(|a| a.mode))?; + let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?; + let transport = ReqwestTransport::new(build_reqwest_client()); + let request_telemetry = self.build_request_telemetry(); + let client = ApiCompactClient::new(transport, api_provider, api_auth) + .with_telemetry(Some(request_telemetry)); + + let instructions = prompt + .get_full_instructions(&self.config.model_family) + .into_owned(); + let payload = ApiCompactionInput { + model: &self.config.model, + input: &prompt.input, + instructions: &instructions, + }; + + let mut extra_headers = ApiHeaderMap::new(); if let SessionSource::SubAgent(sub) = &self.session_source { let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub { label.clone() @@ -541,1109 +360,183 @@ impl ModelClient { .and_then(|v| v.as_str().map(std::string::ToString::to_string)) .unwrap_or_else(|| "other".to_string()) }; - req_builder = req_builder.header("x-openai-subagent", subagent); - } - if let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - && let Some(account_id) = auth.get_account_id() - { - req_builder = req_builder.header("chatgpt-account-id", account_id); - } - let payload = CompactHistoryRequest { - model: &self.config.model, - input: &prompt.input, - instructions: &prompt.get_full_instructions(&self.config.model_family), - }; - - if enabled!(tracing::Level::TRACE) { - trace!( - "POST to {}: {}", - self.provider - .get_compact_url(&auth) - .unwrap_or("".to_string()), - serde_json::to_value(&payload).unwrap_or_default() - ); + if let Ok(val) = HeaderValue::from_str(&subagent) { + extra_headers.insert("x-openai-subagent", val); + } } - let response = req_builder - .json(&payload) - .send() + client + .compact_input(&payload, extra_headers) .await - .map_err(|source| CodexErr::ConnectionFailed(ConnectionFailedError { source }))?; - let status = response.status(); - let body = response - .text() - .await - .map_err(|source| CodexErr::ConnectionFailed(ConnectionFailedError { source }))?; - if !status.is_success() { - return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError { - status, - body, - request_id: None, - })); - } - let CompactHistoryResponse { output } = serde_json::from_str(&body)?; - Ok(output) + .map_err(map_api_error) } } -enum StreamAttemptError { - RetryableHttpError { - status: StatusCode, - retry_after: Option, - request_id: Option, - }, - RetryableTransportError(CodexErr), - Fatal(CodexErr), -} - -impl StreamAttemptError { - /// attempt is 0-based. - fn delay(&self, attempt: u64) -> Duration { - // backoff() uses 1-based attempts. - let backoff_attempt = attempt + 1; - match self { - Self::RetryableHttpError { retry_after, .. } => { - retry_after.unwrap_or_else(|| backoff(backoff_attempt)) - } - Self::RetryableTransportError { .. } => backoff(backoff_attempt), - Self::Fatal(_) => { - // Should not be called on Fatal errors. - Duration::from_secs(0) - } - } +impl ModelClient { + /// Builds request and SSE telemetry for streaming API calls (Chat/Responses). + fn build_streaming_telemetry(&self) -> (Arc, Arc) { + let telemetry = Arc::new(ApiTelemetry::new(self.otel_event_manager.clone())); + let request_telemetry: Arc = telemetry.clone(); + let sse_telemetry: Arc = telemetry; + (request_telemetry, sse_telemetry) } - fn into_error(self) -> CodexErr { - match self { - Self::RetryableHttpError { - status, request_id, .. - } => { - if status == StatusCode::INTERNAL_SERVER_ERROR { - CodexErr::InternalServerError - } else { - CodexErr::RetryLimit(RetryLimitReachedError { status, request_id }) - } - } - Self::RetryableTransportError(error) => error, - Self::Fatal(error) => error, - } + /// Builds request telemetry for unary API calls (e.g., Compact endpoint). + fn build_request_telemetry(&self) -> Arc { + let telemetry = Arc::new(ApiTelemetry::new(self.otel_event_manager.clone())); + let request_telemetry: Arc = telemetry; + request_telemetry } } -#[derive(Debug, Deserialize, Serialize)] -struct SseEvent { - #[serde(rename = "type")] - kind: String, - response: Option, - item: Option, - delta: Option, - summary_index: Option, - content_index: Option, -} - -#[derive(Debug, Deserialize)] -struct ResponseCompleted { - id: String, - usage: Option, -} - -#[derive(Debug, Deserialize)] -struct ResponseCompletedUsage { - input_tokens: i64, - input_tokens_details: Option, - output_tokens: i64, - output_tokens_details: Option, - total_tokens: i64, -} - -impl From for TokenUsage { - fn from(val: ResponseCompletedUsage) -> Self { - TokenUsage { - input_tokens: val.input_tokens, - cached_input_tokens: val - .input_tokens_details - .map(|d| d.cached_tokens) - .unwrap_or(0), - output_tokens: val.output_tokens, - reasoning_output_tokens: val - .output_tokens_details - .map(|d| d.reasoning_tokens) - .unwrap_or(0), - total_tokens: val.total_tokens, - } +/// Adapts the core `Prompt` type into the `codex-api` payload shape. +fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec) -> ApiPrompt { + ApiPrompt { + instructions, + input: prompt.get_formatted_input(), + tools: tools_json, + parallel_tool_calls: prompt.parallel_tool_calls, + output_schema: prompt.output_schema.clone(), } } -#[derive(Debug, Deserialize)] -struct ResponseCompletedInputTokensDetails { - cached_tokens: i64, -} - -#[derive(Debug, Deserialize)] -struct ResponseCompletedOutputTokensDetails { - reasoning_tokens: i64, -} - -fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) { - let Some(input_value) = payload_json.get_mut("input") else { - return; - }; - let serde_json::Value::Array(items) = input_value else { - return; - }; - - for (value, item) in items.iter_mut().zip(original_items.iter()) { - if let ResponseItem::Reasoning { id, .. } - | ResponseItem::Message { id: Some(id), .. } - | ResponseItem::WebSearchCall { id: Some(id), .. } - | ResponseItem::FunctionCall { id: Some(id), .. } - | ResponseItem::LocalShellCall { id: Some(id), .. } - | ResponseItem::CustomToolCall { id: Some(id), .. } = item - { - if id.is_empty() { - continue; - } - - if let Some(obj) = value.as_object_mut() { - obj.insert("id".to_string(), Value::String(id.clone())); - } - } - } -} - -fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option { - let primary = parse_rate_limit_window( - headers, - "x-codex-primary-used-percent", - "x-codex-primary-window-minutes", - "x-codex-primary-reset-at", - ); - - let secondary = parse_rate_limit_window( - headers, - "x-codex-secondary-used-percent", - "x-codex-secondary-window-minutes", - "x-codex-secondary-reset-at", - ); - - let credits = parse_credits_snapshot(headers); - - Some(RateLimitSnapshot { - primary, - secondary, - credits, - }) -} - -fn parse_rate_limit_window( - headers: &HeaderMap, - used_percent_header: &str, - window_minutes_header: &str, - resets_at_header: &str, -) -> Option { - let used_percent: Option = parse_header_f64(headers, used_percent_header); - - used_percent.and_then(|used_percent| { - let window_minutes = parse_header_i64(headers, window_minutes_header); - let resets_at = parse_header_i64(headers, resets_at_header); - - let has_data = used_percent != 0.0 - || window_minutes.is_some_and(|minutes| minutes != 0) - || resets_at.is_some(); - - has_data.then_some(RateLimitWindow { - used_percent, - window_minutes, - resets_at, - }) - }) -} - -fn parse_credits_snapshot(headers: &HeaderMap) -> Option { - let has_credits = parse_header_bool(headers, "x-codex-credits-has-credits")?; - let unlimited = parse_header_bool(headers, "x-codex-credits-unlimited")?; - let balance = parse_header_str(headers, "x-codex-credits-balance") - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(std::string::ToString::to_string); - Some(CreditsSnapshot { - has_credits, - unlimited, - balance, - }) -} - -fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option { - parse_header_str(headers, name)? - .parse::() - .ok() - .filter(|v| v.is_finite()) -} - -fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option { - parse_header_str(headers, name)?.parse::().ok() -} - -fn parse_header_bool(headers: &HeaderMap, name: &str) -> Option { - let raw = parse_header_str(headers, name)?; - if raw.eq_ignore_ascii_case("true") || raw == "1" { - Some(true) - } else if raw.eq_ignore_ascii_case("false") || raw == "0" { - Some(false) - } else { - None - } -} - -fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { - headers.get(name)?.to_str().ok() -} - -async fn process_sse( - stream: S, - tx_event: mpsc::Sender>, - idle_timeout: Duration, - otel_event_manager: OtelEventManager, -) where - S: Stream> + Unpin, +fn map_response_stream(api_stream: S, otel_event_manager: OtelEventManager) -> ResponseStream +where + S: futures::Stream> + + Unpin + + Send + + 'static, { - let mut stream = stream.eventsource(); - - // If the stream stays completely silent for an extended period treat it as disconnected. - // The response id returned from the "complete" message. - let mut response_completed: Option = None; - let mut response_error: Option = None; - - loop { - let start = std::time::Instant::now(); - let response = timeout(idle_timeout, stream.next()).await; - let duration = start.elapsed(); - otel_event_manager.log_sse_event(&response, duration); - - let sse = match response { - Ok(Some(Ok(sse))) => sse, - Ok(Some(Err(e))) => { - debug!("SSE Error: {e:#}"); - let event = CodexErr::Stream(e.to_string(), None); - let _ = tx_event.send(Err(event)).await; - return; - } - Ok(None) => { - match response_completed { - Some(ResponseCompleted { - id: response_id, - usage, - }) => { - if let Some(token_usage) = &usage { - otel_event_manager.sse_event_completed( - token_usage.input_tokens, - token_usage.output_tokens, - token_usage - .input_tokens_details - .as_ref() - .map(|d| d.cached_tokens), - token_usage - .output_tokens_details - .as_ref() - .map(|d| d.reasoning_tokens), - token_usage.total_tokens, - ); - } - let event = ResponseEvent::Completed { - response_id, - token_usage: usage.map(Into::into), - }; - let _ = tx_event.send(Ok(event)).await; - } - None => { - let error = response_error.unwrap_or(CodexErr::Stream( - "stream closed before response.completed".into(), - None, - )); - otel_event_manager.see_event_completed_failed(&error); - - let _ = tx_event.send(Err(error)).await; - } - } - return; - } - Err(_) => { - let _ = tx_event - .send(Err(CodexErr::Stream( - "idle timeout waiting for SSE".into(), - None, - ))) - .await; - return; - } - }; - - let raw = sse.data.clone(); - trace!("SSE event: {}", raw); - - let event: SseEvent = match serde_json::from_str(&sse.data) { - Ok(event) => event, - Err(e) => { - debug!("Failed to parse SSE event: {e}, data: {}", &sse.data); - continue; - } - }; - - match event.kind.as_str() { - // Individual output item finalised. Forward immediately so the - // rest of the agent can stream assistant text/functions *live* - // instead of waiting for the final `response.completed` envelope. - // - // IMPORTANT: We used to ignore these events and forward the - // duplicated `output` array embedded in the `response.completed` - // payload. That produced two concrete issues: - // 1. No real‑time streaming – the user only saw output after the - // entire turn had finished, which broke the "typing" UX and - // made long‑running turns look stalled. - // 2. Duplicate `function_call_output` items – both the - // individual *and* the completed array were forwarded, which - // confused the backend and triggered 400 - // "previous_response_not_found" errors because the duplicated - // IDs did not match the incremental turn chain. - // - // The fix is to forward the incremental events *as they come* and - // drop the duplicated list inside `response.completed`. - "response.output_item.done" => { - let Some(item_val) = event.item else { continue }; - let Ok(item) = serde_json::from_value::(item_val) else { - debug!("failed to parse ResponseItem from output_item.done"); - continue; - }; - - let event = ResponseEvent::OutputItemDone(item); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - "response.output_text.delta" => { - if let Some(delta) = event.delta { - let event = ResponseEvent::OutputTextDelta(delta); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.reasoning_summary_text.delta" => { - if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) { - let event = ResponseEvent::ReasoningSummaryDelta { - delta, - summary_index, - }; - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.reasoning_text.delta" => { - if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) { - let event = ResponseEvent::ReasoningContentDelta { - delta, - content_index, - }; - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.created" => { - if event.response.is_some() { - let _ = tx_event.send(Ok(ResponseEvent::Created {})).await; - } - } - "response.failed" => { - if let Some(resp_val) = event.response { - response_error = Some(CodexErr::Stream( - "response.failed event received".to_string(), - None, - )); - - let error = resp_val.get("error"); - - if let Some(error) = error { - match serde_json::from_value::(error.clone()) { - Ok(error) => { - if is_context_window_error(&error) { - response_error = Some(CodexErr::ContextWindowExceeded); - } else if is_quota_exceeded_error(&error) { - response_error = Some(CodexErr::QuotaExceeded); - } else { - let delay = try_parse_retry_after(&error); - let message = error.message.clone().unwrap_or_default(); - response_error = Some(CodexErr::Stream(message, delay)); - } - } - Err(e) => { - let error = format!("failed to parse ErrorResponse: {e}"); - debug!(error); - response_error = Some(CodexErr::Stream(error, None)) - } - } - } - } - } - // Final response completed – includes array of output items & id - "response.completed" => { - if let Some(resp_val) = event.response { - match serde_json::from_value::(resp_val) { - Ok(r) => { - response_completed = Some(r); - } - Err(e) => { - let error = format!("failed to parse ResponseCompleted: {e}"); - debug!(error); - response_error = Some(CodexErr::Stream(error, None)); - continue; - } - }; - }; - } - "response.content_part.done" - | "response.function_call_arguments.delta" - | "response.custom_tool_call_input.delta" - | "response.custom_tool_call_input.done" // also emitted as response.output_item.done - | "response.in_progress" - | "response.output_text.done" => {} - "response.output_item.added" => { - let Some(item_val) = event.item else { continue }; - let Ok(item) = serde_json::from_value::(item_val) else { - debug!("failed to parse ResponseItem from output_item.done"); - continue; - }; - - let event = ResponseEvent::OutputItemAdded(item); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - "response.reasoning_summary_part.added" => { - if let Some(summary_index) = event.summary_index { - // Boundary between reasoning summary sections (e.g., titles). - let event = ResponseEvent::ReasoningSummaryPartAdded { summary_index }; - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.reasoning_summary_text.done" => {} - _ => {} - } - } -} - -/// used in tests to stream from a text SSE file -async fn stream_from_fixture( - path: impl AsRef, - provider: ModelProviderInfo, - otel_event_manager: OtelEventManager, -) -> Result { let (tx_event, rx_event) = mpsc::channel::>(1600); - let f = std::fs::File::open(path.as_ref())?; - let lines = std::io::BufReader::new(f).lines(); - - // insert \n\n after each line for proper SSE parsing - let mut content = String::new(); - for line in lines { - content.push_str(&line?); - content.push_str("\n\n"); - } - - let rdr = std::io::Cursor::new(content); - let stream = ReaderStream::new(rdr).map_err(CodexErr::Io); - tokio::spawn(process_sse( - stream, - tx_event, - provider.stream_idle_timeout(), - otel_event_manager, - )); - Ok(ResponseStream { rx_event }) -} - -fn rate_limit_regex() -> &'static Regex { - static RE: OnceLock = OnceLock::new(); - - // Match both OpenAI-style messages like "Please try again in 1.898s" - // and Azure OpenAI-style messages like "Try again in 35 seconds". - #[expect(clippy::unwrap_used)] - RE.get_or_init(|| Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()) -} - -fn try_parse_retry_after(err: &Error) -> Option { - if err.code != Some("rate_limit_exceeded".to_string()) { - return None; - } - - // parse retry hints like "try again in 1.898s" or - // "Try again in 35 seconds" using regex - let re = rate_limit_regex(); - if let Some(message) = &err.message - && let Some(captures) = re.captures(message) - { - let seconds = captures.get(1); - let unit = captures.get(2); - - if let (Some(value), Some(unit)) = (seconds, unit) { - let value = value.as_str().parse::().ok()?; - let unit = unit.as_str().to_ascii_lowercase(); - - if unit == "s" || unit.starts_with("second") { - return Some(Duration::from_secs_f64(value)); - } else if unit == "ms" { - return Some(Duration::from_millis(value as u64)); - } - } - } - None -} - -fn is_context_window_error(error: &Error) -> bool { - error.code.as_deref() == Some("context_length_exceeded") -} - -fn is_quota_exceeded_error(error: &Error) -> bool { - error.code.as_deref() == Some("insufficient_quota") -} - -#[cfg(test)] -mod tests { - use super::*; - use assert_matches::assert_matches; - use serde_json::json; - use tokio::sync::mpsc; - use tokio_test::io::Builder as IoBuilder; - use tokio_util::io::ReaderStream; - - // ──────────────────────────── - // Helpers - // ──────────────────────────── - - /// Runs the SSE parser on pre-chunked byte slices and returns every event - /// (including any final `Err` from a stream-closure check). - async fn collect_events( - chunks: &[&[u8]], - provider: ModelProviderInfo, - otel_event_manager: OtelEventManager, - ) -> Vec> { - let mut builder = IoBuilder::new(); - for chunk in chunks { - builder.read(chunk); - } - - let reader = builder.build(); - let stream = ReaderStream::new(reader).map_err(CodexErr::Io); - let (tx, mut rx) = mpsc::channel::>(16); - tokio::spawn(process_sse( - stream, - tx, - provider.stream_idle_timeout(), - otel_event_manager, - )); - - let mut events = Vec::new(); - while let Some(ev) = rx.recv().await { - events.push(ev); - } - events - } - - /// Builds an in-memory SSE stream from JSON fixtures and returns only the - /// successfully parsed events (panics on internal channel errors). - async fn run_sse( - events: Vec, - provider: ModelProviderInfo, - otel_event_manager: OtelEventManager, - ) -> Vec { - let mut body = String::new(); - for e in events { - let kind = e - .get("type") - .and_then(|v| v.as_str()) - .expect("fixture event missing type"); - if e.as_object().map(|o| o.len() == 1).unwrap_or(false) { - body.push_str(&format!("event: {kind}\n\n")); - } else { - body.push_str(&format!("event: {kind}\ndata: {e}\n\n")); - } - } - - let (tx, mut rx) = mpsc::channel::>(8); - let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io); - tokio::spawn(process_sse( - stream, - tx, - provider.stream_idle_timeout(), - otel_event_manager, - )); - - let mut out = Vec::new(); - while let Some(ev) = rx.recv().await { - out.push(ev.expect("channel closed")); - } - out - } - - fn otel_event_manager() -> OtelEventManager { - OtelEventManager::new( - ConversationId::new(), - "test", - "test", - None, - Some("test@test.com".to_string()), - Some(AuthMode::ChatGPT), - false, - "test".to_string(), - ) - } - - // ──────────────────────────── - // Tests from `implement-test-for-responses-api-sse-parser` - // ──────────────────────────── - - #[tokio::test] - async fn parses_items_and_completed() { - let item1 = json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "Hello"}] - } - }) - .to_string(); - - let item2 = json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "World"}] - } - }) - .to_string(); - - let completed = json!({ - "type": "response.completed", - "response": { "id": "resp1" } - }) - .to_string(); - - let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); - let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n"); - let sse3 = format!("event: response.completed\ndata: {completed}\n\n"); - - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events( - &[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()], - provider, - otel_event_manager, - ) - .await; - - assert_eq!(events.len(), 3); - - matches!( - &events[0], - Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) - if role == "assistant" - ); - - matches!( - &events[1], - Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) - if role == "assistant" - ); - - match &events[2] { - Ok(ResponseEvent::Completed { - response_id, - token_usage, - }) => { - assert_eq!(response_id, "resp1"); - assert!(token_usage.is_none()); - } - other => panic!("unexpected third event: {other:?}"), - } - } - - #[tokio::test] - async fn error_when_missing_completed() { - let item1 = json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "Hello"}] - } - }) - .to_string(); - - let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 2); - - matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))); - - match &events[1] { - Err(CodexErr::Stream(msg, _)) => { - assert_eq!(msg, "stream closed before response.completed") - } - other => panic!("unexpected second event: {other:?}"), - } - } - - #[tokio::test] - async fn error_when_error_event() { - let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(CodexErr::Stream(msg, delay)) => { - assert_eq!( - msg, - "Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more." - ); - assert_eq!(*delay, Some(Duration::from_secs_f64(11.054))); - } - other => panic!("unexpected second event: {other:?}"), - } - } - - #[tokio::test] - async fn context_window_error_is_fatal() { - let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(err @ CodexErr::ContextWindowExceeded) => { - assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string()); - } - other => panic!("unexpected context window event: {other:?}"), - } - } - - #[tokio::test] - async fn context_window_error_with_newline_is_fatal() { - let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(err @ CodexErr::ContextWindowExceeded) => { - assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string()); - } - other => panic!("unexpected context window event: {other:?}"), - } - } - - #[tokio::test] - async fn quota_exceeded_error_is_fatal() { - let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_quota","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"insufficient_quota","message":"You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors."},"incomplete_details":null}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(err @ CodexErr::QuotaExceeded) => { - assert_eq!(err.to_string(), CodexErr::QuotaExceeded.to_string()); - } - other => panic!("unexpected quota exceeded event: {other:?}"), - } - } - - // ──────────────────────────── - // Table-driven test from `main` - // ──────────────────────────── - - /// Verifies that the adapter produces the right `ResponseEvent` for a - /// variety of incoming `type` values. - #[tokio::test] - async fn table_driven_event_kinds() { - struct TestCase { - name: &'static str, - event: serde_json::Value, - expect_first: fn(&ResponseEvent) -> bool, - expected_len: usize, - } - - fn is_created(ev: &ResponseEvent) -> bool { - matches!(ev, ResponseEvent::Created) - } - fn is_output(ev: &ResponseEvent) -> bool { - matches!(ev, ResponseEvent::OutputItemDone(_)) - } - fn is_completed(ev: &ResponseEvent) -> bool { - matches!(ev, ResponseEvent::Completed { .. }) - } - - let completed = json!({ - "type": "response.completed", - "response": { - "id": "c", - "usage": { - "input_tokens": 0, - "input_tokens_details": null, - "output_tokens": 0, - "output_tokens_details": null, - "total_tokens": 0 - }, - "output": [] - } - }); - - let cases = vec![ - TestCase { - name: "created", - event: json!({"type": "response.created", "response": {}}), - expect_first: is_created, - expected_len: 2, - }, - TestCase { - name: "output_item.done", - event: json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [ - {"type": "output_text", "text": "hi"} - ] + let manager = otel_event_manager; + + tokio::spawn(async move { + let mut logged_error = false; + let mut api_stream = api_stream; + while let Some(event) = api_stream.next().await { + match event { + Ok(ResponseEvent::Completed { + response_id, + token_usage, + }) => { + if let Some(usage) = &token_usage { + manager.sse_event_completed( + usage.input_tokens, + usage.output_tokens, + Some(usage.cached_input_tokens), + Some(usage.reasoning_output_tokens), + usage.total_tokens, + ); } - }), - expect_first: is_output, - expected_len: 2, - }, - TestCase { - name: "unknown", - event: json!({"type": "response.new_tool_event"}), - expect_first: is_completed, - expected_len: 1, - }, - ]; - - for case in cases { - let mut evs = vec![case.event]; - evs.push(completed.clone()); - - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let out = run_sse(evs, provider, otel_event_manager).await; - assert_eq!(out.len(), case.expected_len, "case {}", case.name); - assert!( - (case.expect_first)(&out[0]), - "first event mismatch in case {}", - case.name - ); + if tx_event + .send(Ok(ResponseEvent::Completed { + response_id, + token_usage, + })) + .await + .is_err() + { + return; + } + } + Ok(event) => { + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + Err(err) => { + let mapped = map_api_error(err); + if !logged_error { + manager.see_event_completed_failed(&mapped); + logged_error = true; + } + if tx_event.send(Err(mapped)).await.is_err() { + return; + } + } + } } + }); + + ResponseStream { rx_event } +} + +/// Handles a 401 response by optionally refreshing ChatGPT tokens once. +/// +/// When refresh succeeds, the caller should retry the API call; otherwise +/// the mapped `CodexErr` is returned to the caller. +async fn handle_unauthorized( + status: StatusCode, + refreshed: &mut bool, + auth_manager: &Option>, + auth: &Option, +) -> Result<()> { + if *refreshed { + return Err(map_unauthorized_status(status)); } - #[test] - fn test_try_parse_retry_after() { - let err = Error { - r#type: None, - message: Some("Rate limit reached for gpt-5.1 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()), - code: Some("rate_limit_exceeded".to_string()), - plan_type: None, - resets_at: None - }; - - let delay = try_parse_retry_after(&err); - assert_eq!(delay, Some(Duration::from_millis(28))); - } - - #[test] - fn test_try_parse_retry_after_no_delay() { - let err = Error { - r#type: None, - message: Some("Rate limit reached for gpt-5.1 in organization on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()), - code: Some("rate_limit_exceeded".to_string()), - plan_type: None, - resets_at: None - }; - let delay = try_parse_retry_after(&err); - assert_eq!(delay, Some(Duration::from_secs_f64(1.898))); - } - - #[test] - fn test_try_parse_retry_after_azure() { - let err = Error { - r#type: None, - message: Some("Rate limit exceeded. Try again in 35 seconds.".to_string()), - code: Some("rate_limit_exceeded".to_string()), - plan_type: None, - resets_at: None, - }; - let delay = try_parse_retry_after(&err); - assert_eq!(delay, Some(Duration::from_secs(35))); - } - - #[test] - fn error_response_deserializes_schema_known_plan_type_and_serializes_back() { - use crate::token_data::KnownPlan; - use crate::token_data::PlanType; - - let json = - r#"{"error":{"type":"usage_limit_reached","plan_type":"pro","resets_at":1704067200}}"#; - let resp: ErrorResponse = serde_json::from_str(json).expect("should deserialize schema"); - - assert_matches!(resp.error.plan_type, Some(PlanType::Known(KnownPlan::Pro))); - - let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type"); - assert_eq!(plan_json, "\"pro\""); - } - - #[test] - fn error_response_deserializes_schema_unknown_plan_type_and_serializes_back() { - use crate::token_data::PlanType; - - let json = - r#"{"error":{"type":"usage_limit_reached","plan_type":"vip","resets_at":1704067260}}"#; - let resp: ErrorResponse = serde_json::from_str(json).expect("should deserialize schema"); - - assert_matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"); - - let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type"); - assert_eq!(plan_json, "\"vip\""); + if let Some(manager) = auth_manager.as_ref() + && let Some(auth) = auth.as_ref() + && auth.mode == AuthMode::ChatGPT + { + match manager.refresh_token().await { + Ok(_) => { + *refreshed = true; + Ok(()) + } + Err(RefreshTokenError::Permanent(failed)) => Err(CodexErr::RefreshTokenFailed(failed)), + Err(RefreshTokenError::Transient(other)) => Err(CodexErr::Io(other)), + } + } else { + Err(map_unauthorized_status(status)) + } +} + +fn map_unauthorized_status(status: StatusCode) -> CodexErr { + map_api_error(ApiError::Transport(TransportError::Http { + status, + headers: None, + body: None, + })) +} + +struct ApiTelemetry { + otel_event_manager: OtelEventManager, +} + +impl ApiTelemetry { + fn new(otel_event_manager: OtelEventManager) -> Self { + Self { otel_event_manager } + } +} + +impl RequestTelemetry for ApiTelemetry { + fn on_request( + &self, + attempt: u64, + status: Option, + error: Option<&TransportError>, + duration: Duration, + ) { + let error_message = error.map(std::string::ToString::to_string); + self.otel_event_manager.record_api_request( + attempt, + status.map(|s| s.as_u16()), + error_message.as_deref(), + duration, + ); + } +} + +impl SseTelemetry for ApiTelemetry { + fn on_sse_poll( + &self, + result: &std::result::Result< + Option>>, + tokio::time::error::Elapsed, + >, + duration: Duration, + ) { + self.otel_event_manager.log_sse_event(result, duration); } } diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 21f6fc657..a249ca6fc 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -1,16 +1,11 @@ use crate::client_common::tools::ToolSpec; use crate::error::Result; use crate::model_family::ModelFamily; -use crate::protocol::RateLimitSnapshot; -use crate::protocol::TokenUsage; +pub use codex_api::common::ResponseEvent; use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS; -use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; -use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; -use codex_protocol::config_types::Verbosity as VerbosityConfig; use codex_protocol::models::ResponseItem; use futures::Stream; use serde::Deserialize; -use serde::Serialize; use serde_json::Value; use std::borrow::Cow; use std::collections::HashSet; @@ -184,104 +179,6 @@ fn strip_total_output_header(output: &str) -> Option<(&str, u32)> { Some((remainder, total_lines)) } -#[derive(Debug)] -pub enum ResponseEvent { - Created, - OutputItemDone(ResponseItem), - OutputItemAdded(ResponseItem), - Completed { - response_id: String, - token_usage: Option, - }, - OutputTextDelta(String), - ReasoningSummaryDelta { - delta: String, - summary_index: i64, - }, - ReasoningContentDelta { - delta: String, - content_index: i64, - }, - ReasoningSummaryPartAdded { - summary_index: i64, - }, - RateLimits(RateLimitSnapshot), -} - -#[derive(Debug, Serialize)] -pub(crate) struct Reasoning { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) effort: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) summary: Option, -} - -#[derive(Debug, Serialize, Default, Clone)] -#[serde(rename_all = "snake_case")] -pub(crate) enum TextFormatType { - #[default] - JsonSchema, -} - -#[derive(Debug, Serialize, Default, Clone)] -pub(crate) struct TextFormat { - pub(crate) r#type: TextFormatType, - pub(crate) strict: bool, - pub(crate) schema: Value, - pub(crate) name: String, -} - -/// Controls under the `text` field in the Responses API for GPT-5. -#[derive(Debug, Serialize, Default, Clone)] -pub(crate) struct TextControls { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) verbosity: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) format: Option, -} - -#[derive(Debug, Serialize, Default, Clone)] -#[serde(rename_all = "lowercase")] -pub(crate) enum OpenAiVerbosity { - Low, - #[default] - Medium, - High, -} - -impl From for OpenAiVerbosity { - fn from(v: VerbosityConfig) -> Self { - match v { - VerbosityConfig::Low => OpenAiVerbosity::Low, - VerbosityConfig::Medium => OpenAiVerbosity::Medium, - VerbosityConfig::High => OpenAiVerbosity::High, - } - } -} - -/// Request object that is serialized as JSON and POST'ed when using the -/// Responses API. -#[derive(Debug, Serialize)] -pub(crate) struct ResponsesApiRequest<'a> { - pub(crate) model: &'a str, - pub(crate) instructions: &'a str, - // TODO(mbolin): ResponseItem::Other should not be serialized. Currently, - // we code defensively to avoid this case, but perhaps we should use a - // separate enum for serialization. - pub(crate) input: &'a Vec, - pub(crate) tools: &'a [serde_json::Value], - pub(crate) tool_choice: &'static str, - pub(crate) parallel_tool_calls: bool, - pub(crate) reasoning: Option, - pub(crate) store: bool, - pub(crate) stream: bool, - pub(crate) include: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) prompt_cache_key: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) text: Option, -} - pub(crate) mod tools { use crate::tools::spec::JsonSchema; use serde::Deserialize; @@ -341,25 +238,6 @@ pub(crate) mod tools { } } -pub(crate) fn create_text_param_for_request( - verbosity: Option, - output_schema: &Option, -) -> Option { - if verbosity.is_none() && output_schema.is_none() { - return None; - } - - Some(TextControls { - verbosity: verbosity.map(std::convert::Into::into), - format: output_schema.as_ref().map(|schema| TextFormat { - r#type: TextFormatType::JsonSchema, - strict: true, - schema: schema.clone(), - name: "codex_output_schema".to_string(), - }), - }) -} - pub struct ResponseStream { pub(crate) rx_event: mpsc::Receiver>, } @@ -375,6 +253,10 @@ impl Stream for ResponseStream { #[cfg(test)] mod tests { use crate::model_family::find_family_for_model; + use codex_api::ResponsesApiRequest; + use codex_api::common::OpenAiVerbosity; + use codex_api::common::TextControls; + use codex_api::create_text_param_for_request; use pretty_assertions::assert_eq; use super::*; diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index 8e4635460..29986c401 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -258,6 +258,11 @@ fn sanitize_user_agent(candidate: String, fallback: &str) -> String { /// Create an HTTP client with default `originator` and `User-Agent` headers set. pub fn create_client() -> CodexHttpClient { + let inner = build_reqwest_client(); + CodexHttpClient::new(inner) +} + +pub fn build_reqwest_client() -> reqwest::Client { use reqwest::header::HeaderMap; let mut headers = HeaderMap::new(); @@ -272,8 +277,7 @@ pub fn create_client() -> CodexHttpClient { builder = builder.no_proxy(); } - let inner = builder.build().unwrap_or_else(|_| reqwest::Client::new()); - CodexHttpClient::new(inner) + builder.build().unwrap_or_else(|_| reqwest::Client::new()) } fn is_sandboxed() -> bool { diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 805943a2e..7a9440eb2 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -5,10 +5,10 @@ // the TUI or the tracing stack). #![deny(clippy::print_stdout, clippy::print_stderr)] +pub mod api_bridge; mod apply_patch; pub mod auth; pub mod bash; -mod chat_completions; mod client; mod client_common; pub mod codex; diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 3ab341eea..4912a6469 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -5,11 +5,13 @@ //! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers` //! key. These override or extend the defaults at runtime. -use crate::CodexAuth; -use crate::default_client::CodexHttpClient; -use crate::default_client::CodexRequestBuilder; -use crate::error::CodexErr; +use codex_api::Provider as ApiProvider; +use codex_api::WireApi as ApiWireApi; +use codex_api::provider::RetryConfig as ApiRetryConfig; use codex_app_server_protocol::AuthMode; +use http::HeaderMap; +use http::header::HeaderName; +use http::header::HeaderValue; use serde::Deserialize; use serde::Serialize; use std::collections::HashMap; @@ -97,148 +99,14 @@ pub struct ModelProviderInfo { } impl ModelProviderInfo { - /// Construct a `POST` RequestBuilder for the given URL using the provided - /// [`CodexHttpClient`] applying: - /// • provider-specific headers (static + env based) - /// • Bearer auth header when an API key is available. - /// • Auth token for OAuth. - /// - /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the - /// one produced by [`ModelProviderInfo::api_key`]. - pub async fn create_request_builder<'a>( - &'a self, - client: &'a CodexHttpClient, - auth: &Option, - ) -> crate::error::Result { - let effective_auth = self.effective_auth(auth)?; - - let url = self.get_full_url(&effective_auth); - - let mut builder = client.post(url); - - if let Some(auth) = effective_auth.as_ref() { - builder = builder.bearer_auth(auth.get_token().await?); - } - - Ok(self.apply_http_headers(builder)) - } - - pub async fn create_compact_request_builder<'a>( - &'a self, - client: &'a CodexHttpClient, - auth: &Option, - ) -> crate::error::Result { - if self.wire_api != WireApi::Responses { - return Err(CodexErr::UnsupportedOperation( - "Compaction endpoint requires Responses API providers".to_string(), - )); - } - let effective_auth = self.effective_auth(auth)?; - - let url = self.get_compact_url(&effective_auth).ok_or_else(|| { - CodexErr::UnsupportedOperation( - "Compaction endpoint requires Responses API providers".to_string(), - ) - })?; - - let mut builder = client.post(url); - - if let Some(auth) = effective_auth.as_ref() { - builder = builder.bearer_auth(auth.get_token().await?); - } - - Ok(self.apply_http_headers(builder)) - } - - fn effective_auth(&self, auth: &Option) -> crate::error::Result> { - if let Some(secret_key) = &self.experimental_bearer_token { - return Ok(Some(CodexAuth::from_api_key(secret_key))); - } - - match self.api_key() { - Ok(Some(key)) => Ok(Some(CodexAuth::from_api_key(&key))), - Ok(None) => Ok(auth.clone()), - Err(err) => { - if auth.is_some() { - Ok(auth.clone()) - } else { - Err(err) - } - } - } - } - - fn get_query_string(&self) -> String { - self.query_params - .as_ref() - .map_or_else(String::new, |params| { - let full_params = params - .iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join("&"); - format!("?{full_params}") - }) - } - - pub(crate) fn get_full_url(&self, auth: &Option) -> String { - let default_base_url = if matches!( - auth, - Some(CodexAuth { - mode: AuthMode::ChatGPT, - .. - }) - ) { - "https://chatgpt.com/backend-api/codex" - } else { - "https://api.openai.com/v1" - }; - let query_string = self.get_query_string(); - let base_url = self - .base_url - .clone() - .unwrap_or(default_base_url.to_string()); - - match self.wire_api { - WireApi::Responses => format!("{base_url}/responses{query_string}"), - WireApi::Chat => format!("{base_url}/chat/completions{query_string}"), - } - } - - pub(crate) fn get_compact_url(&self, auth: &Option) -> Option { - if self.wire_api != WireApi::Responses { - return None; - } - let full = self.get_full_url(auth); - if let Some((path, query)) = full.split_once('?') { - Some(format!("{path}/compact?{query}")) - } else { - Some(format!("{full}/compact")) - } - } - - pub(crate) fn is_azure_responses_endpoint(&self) -> bool { - if self.wire_api != WireApi::Responses { - return false; - } - - if self.name.eq_ignore_ascii_case("azure") { - return true; - } - - self.base_url - .as_ref() - .map(|base| matches_azure_responses_base_url(base)) - .unwrap_or(false) - } - - /// Apply provider-specific HTTP headers (both static and environment-based) - /// onto an existing [`CodexRequestBuilder`] and return the updated - /// builder. - fn apply_http_headers(&self, mut builder: CodexRequestBuilder) -> CodexRequestBuilder { + #[allow(dead_code)] + fn build_header_map(&self) -> crate::error::Result { + let mut headers = HeaderMap::new(); if let Some(extra) = &self.http_headers { for (k, v) in extra { - builder = builder.header(k, v); + if let (Ok(name), Ok(value)) = (HeaderName::try_from(k), HeaderValue::try_from(v)) { + headers.insert(name, value); + } } } @@ -246,12 +114,52 @@ impl ModelProviderInfo { for (header, env_var) in env_headers { if let Ok(val) = std::env::var(env_var) && !val.trim().is_empty() + && let (Ok(name), Ok(value)) = + (HeaderName::try_from(header), HeaderValue::try_from(val)) { - builder = builder.header(header, val); + headers.insert(name, value); } } } - builder + + Ok(headers) + } + + pub(crate) fn to_api_provider( + &self, + auth_mode: Option, + ) -> crate::error::Result { + let default_base_url = if matches!(auth_mode, Some(AuthMode::ChatGPT)) { + "https://chatgpt.com/backend-api/codex" + } else { + "https://api.openai.com/v1" + }; + let base_url = self + .base_url + .clone() + .unwrap_or_else(|| default_base_url.to_string()); + + let headers = self.build_header_map()?; + let retry = ApiRetryConfig { + max_attempts: self.request_max_retries(), + base_delay: Duration::from_millis(200), + retry_429: false, + retry_5xx: true, + retry_transport: true, + }; + + Ok(ApiProvider { + name: self.name.clone(), + base_url, + query_params: self.query_params.clone(), + wire: match self.wire_api { + WireApi::Responses => ApiWireApi::Responses, + WireApi::Chat => ApiWireApi::Chat, + }, + headers, + retry, + stream_idle_timeout: self.stream_idle_timeout(), + }) } /// If `env_key` is Some, returns the API key for this provider if present @@ -409,18 +317,6 @@ pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> M } } -fn matches_azure_responses_base_url(base_url: &str) -> bool { - let base = base_url.to_ascii_lowercase(); - const AZURE_MARKERS: [&str; 5] = [ - "openai.azure.", - "cognitiveservices.azure.", - "aoai.azure.", - "azure-api.", - "azurefd.", - ]; - AZURE_MARKERS.iter().any(|marker| base.contains(marker)) -} - #[cfg(test)] mod tests { use super::*; @@ -517,8 +413,16 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } #[test] fn detects_azure_responses_base_urls() { - fn provider_for(base_url: &str) -> ModelProviderInfo { - ModelProviderInfo { + let positive_cases = [ + "https://foo.openai.azure.com/openai", + "https://foo.openai.azure.us/openai/deployments/bar", + "https://foo.cognitiveservices.azure.cn/openai", + "https://foo.aoai.azure.com/openai", + "https://foo.openai.azure-api.net/openai", + "https://foo.z01.azurefd.net/", + ]; + for base_url in positive_cases { + let provider = ModelProviderInfo { name: "test".into(), base_url: Some(base_url.into()), env_key: None, @@ -532,21 +436,10 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } stream_max_retries: None, stream_idle_timeout_ms: None, requires_openai_auth: false, - } - } - - let positive_cases = [ - "https://foo.openai.azure.com/openai", - "https://foo.openai.azure.us/openai/deployments/bar", - "https://foo.cognitiveservices.azure.cn/openai", - "https://foo.aoai.azure.com/openai", - "https://foo.openai.azure-api.net/openai", - "https://foo.z01.azurefd.net/", - ]; - for base_url in positive_cases { - let provider = provider_for(base_url); + }; + let api = provider.to_api_provider(None).expect("api provider"); assert!( - provider.is_azure_responses_endpoint(), + api.is_azure_responses_endpoint(), "expected {base_url} to be detected as Azure" ); } @@ -566,7 +459,8 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } stream_idle_timeout_ms: None, requires_openai_auth: false, }; - assert!(named_provider.is_azure_responses_endpoint()); + let named_api = named_provider.to_api_provider(None).expect("api provider"); + assert!(named_api.is_azure_responses_endpoint()); let negative_cases = [ "https://api.openai.com/v1", @@ -574,9 +468,24 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } "https://myproxy.azurewebsites.net/openai", ]; for base_url in negative_cases { - let provider = provider_for(base_url); + let provider = ModelProviderInfo { + name: "test".into(), + base_url: Some(base_url.into()), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_openai_auth: false, + }; + let api = provider.to_api_provider(None).expect("api provider"); assert!( - !provider.is_azure_responses_endpoint(), + !api.is_azure_responses_endpoint(), "expected {base_url} not to be detected as Azure" ); } diff --git a/codex-rs/otel/src/otel_event_manager.rs b/codex-rs/otel/src/otel_event_manager.rs index fde351cd6..b6bc07e79 100644 --- a/codex-rs/otel/src/otel_event_manager.rs +++ b/codex-rs/otel/src/otel_event_manager.rs @@ -131,7 +131,18 @@ impl OtelEventManager { Ok(response) => (Some(response.status().as_u16()), None), Err(error) => (error.status().map(|s| s.as_u16()), Some(error.to_string())), }; + self.record_api_request(attempt, status, error.as_deref(), duration); + response + } + + pub fn record_api_request( + &self, + attempt: u64, + status: Option, + error: Option<&str>, + duration: Duration, + ) { tracing::event!( tracing::Level::INFO, event.name = "codex.api_request", @@ -149,8 +160,6 @@ impl OtelEventManager { error.message = error, attempt = attempt, ); - - response } pub fn log_sse_event(