From 903b7774bc0e04c72ee8e61c6e12f7bdbbe7d267 Mon Sep 17 00:00:00 2001
From: Ahmed Ibrahim
Date: Thu, 4 Dec 2025 12:57:54 -0800
Subject: [PATCH] Add models endpoint (#7603)

- Use the codex-api crate to introduce the models endpoint.
- Add `models` support to the codex core test helpers.
- Add `ModelsResponse` and `ModelInfo` for the endpoint return type.
---
 codex-rs/Cargo.lock                        |   2 +
 codex-rs/codex-api/Cargo.toml              |   2 +
 codex-rs/codex-api/src/endpoint/mod.rs     |   1 +
 codex-rs/codex-api/src/endpoint/models.rs  | 216 ++++++++++++++++++
 codex-rs/codex-api/src/lib.rs              |   1 +
 .../codex-api/tests/models_integration.rs  | 100 ++++++++
 codex-rs/core/tests/common/responses.rs    |  63 ++++-
 codex-rs/protocol/src/openai_models.rs     |  74 ++++++
 8 files changed, 457 insertions(+), 2 deletions(-)
 create mode 100644 codex-rs/codex-api/src/endpoint/models.rs
 create mode 100644 codex-rs/codex-api/tests/models_integration.rs

diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 4429858c9..c3b6c27be 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -858,6 +858,7 @@ dependencies = [
  "http",
  "pretty_assertions",
  "regex-lite",
+ "reqwest",
  "serde",
  "serde_json",
  "thiserror 2.0.17",
@@ -865,6 +866,7 @@ dependencies = [
  "tokio-test",
  "tokio-util",
  "tracing",
+ "wiremock",
 ]
 
 [[package]]
diff --git a/codex-rs/codex-api/Cargo.toml b/codex-rs/codex-api/Cargo.toml
index f79416c96..e9fc78878 100644
--- a/codex-rs/codex-api/Cargo.toml
+++ b/codex-rs/codex-api/Cargo.toml
@@ -25,6 +25,8 @@ anyhow = { workspace = true }
 assert_matches = { workspace = true }
 pretty_assertions = { workspace = true }
 tokio-test = { workspace = true }
+wiremock = { workspace = true }
+reqwest = { workspace = true }
 
 [lints]
 workspace = true
diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs
index 104b4c264..cb0eeb9f2 100644
--- a/codex-rs/codex-api/src/endpoint/mod.rs
+++ b/codex-rs/codex-api/src/endpoint/mod.rs
@@ -1,4 +1,5 @@
 pub mod chat;
 pub mod compact;
+pub mod models;
 pub mod responses;
 mod streaming;
diff --git a/codex-rs/codex-api/src/endpoint/models.rs b/codex-rs/codex-api/src/endpoint/models.rs
new file mode 100644
index 000000000..fec8d7f29
--- /dev/null
+++ b/codex-rs/codex-api/src/endpoint/models.rs
@@ -0,0 +1,216 @@
+use crate::auth::AuthProvider;
+use crate::auth::add_auth_headers;
+use crate::error::ApiError;
+use crate::provider::Provider;
+use crate::telemetry::run_with_request_telemetry;
+use codex_client::HttpTransport;
+use codex_client::RequestTelemetry;
+use codex_protocol::openai_models::ModelsResponse;
+use http::HeaderMap;
+use http::Method;
+use std::sync::Arc;
+
+pub struct ModelsClient<T, A> {
+    transport: T,
+    provider: Provider,
+    auth: A,
+    request_telemetry: Option<Arc<dyn RequestTelemetry>>,
+}
+
+impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
+    pub fn new(transport: T, provider: Provider, auth: A) -> Self {
+        Self {
+            transport,
+            provider,
+            auth,
+            request_telemetry: None,
+        }
+    }
+
+    pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
+        self.request_telemetry = request;
+        self
+    }
+
+    fn path(&self) -> &'static str {
+        "models"
+    }
+
+    pub async fn list_models(
+        &self,
+        client_version: &str,
+        extra_headers: HeaderMap,
+    ) -> Result<ModelsResponse, ApiError> {
+        let builder = || {
+            let mut req = self.provider.build_request(Method::GET, self.path());
+            req.headers.extend(extra_headers.clone());
+
+            let separator = if req.url.contains('?') { '&' } else { '?' };
+            req.url = format!("{}{}client_version={client_version}", req.url, separator);
+
+            add_auth_headers(&self.auth, req)
+        };
+
+        let resp = run_with_request_telemetry(
+            self.provider.retry.to_policy(),
+            self.request_telemetry.clone(),
+            builder,
+            |req| self.transport.execute(req),
+        )
+        .await?;
+
+        serde_json::from_slice::<ModelsResponse>(&resp.body).map_err(|e| {
+            ApiError::Stream(format!(
+                "failed to decode models response: {e}; body: {}",
+                String::from_utf8_lossy(&resp.body)
+            ))
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::provider::RetryConfig;
+    use crate::provider::WireApi;
+    use async_trait::async_trait;
+    use codex_client::Request;
+    use codex_client::Response;
+    use codex_client::StreamResponse;
+    use codex_client::TransportError;
+    use http::HeaderMap;
+    use http::StatusCode;
+    use pretty_assertions::assert_eq;
+    use serde_json::json;
+    use std::sync::Arc;
+    use std::sync::Mutex;
+    use std::time::Duration;
+
+    #[derive(Clone, Default)]
+    struct CapturingTransport {
+        last_request: Arc<Mutex<Option<Request>>>,
+        body: Arc<ModelsResponse>,
+    }
+
+    #[async_trait]
+    impl HttpTransport for CapturingTransport {
+        async fn execute(&self, req: Request) -> Result<Response, TransportError> {
+            *self.last_request.lock().unwrap() = Some(req);
+            let body = serde_json::to_vec(&*self.body).unwrap();
+            Ok(Response {
+                status: StatusCode::OK,
+                headers: HeaderMap::new(),
+                body: body.into(),
+            })
+        }
+
+        async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
+            Err(TransportError::Build("stream should not run".to_string()))
+        }
+    }
+
+    #[derive(Clone, Default)]
+    struct DummyAuth;
+
+    impl AuthProvider for DummyAuth {
+        fn bearer_token(&self) -> Option<String> {
+            None
+        }
+    }
+
+    fn provider(base_url: &str) -> Provider {
+        Provider {
+            name: "test".to_string(),
+            base_url: base_url.to_string(),
+            query_params: None,
+            wire: WireApi::Responses,
+            headers: HeaderMap::new(),
+            retry: RetryConfig {
+                max_attempts: 1,
+                base_delay: Duration::from_millis(1),
+                retry_429: false,
+                retry_5xx: true,
+                retry_transport: true,
+            },
+            stream_idle_timeout: Duration::from_secs(1),
+        }
+    }
+
+    #[tokio::test]
+    async fn appends_client_version_query() {
+        let response = ModelsResponse { models: Vec::new() };
+
+        let transport = CapturingTransport {
+            last_request: Arc::new(Mutex::new(None)),
+            body: Arc::new(response),
+        };
+
+        let client = ModelsClient::new(
+            transport.clone(),
+            provider("https://example.com/api/codex"),
+            DummyAuth,
+        );
+
+        let result = client
+            .list_models("0.99.0", HeaderMap::new())
+            .await
+            .expect("request should succeed");
+
+        assert_eq!(result.models.len(), 0);
+
+        let url = transport
+            .last_request
+            .lock()
+            .unwrap()
+            .as_ref()
+            .unwrap()
+            .url
+            .clone();
+        assert_eq!(
+            url,
+            "https://example.com/api/codex/models?client_version=0.99.0"
+        );
+    }
+
+    #[tokio::test]
+    async fn parses_models_response() {
+        let response = ModelsResponse {
+            models: vec![
+                serde_json::from_value(json!({
+                    "slug": "gpt-test",
+                    "display_name": "gpt-test",
+                    "description": "desc",
+                    "default_reasoning_level": "medium",
+                    "supported_reasoning_levels": ["low", "medium", "high"],
+                    "shell_type": "shell_command",
+                    "visibility": "list",
+                    "minimal_client_version": [0, 99, 0],
+                    "supported_in_api": true,
+                    "priority": 1
+                }))
+                .unwrap(),
+            ],
+        };
+
+        let transport = CapturingTransport {
+            last_request: Arc::new(Mutex::new(None)),
+            body: Arc::new(response),
+        };
+
+        let client = ModelsClient::new(
+            transport,
+            provider("https://example.com/api/codex"),
+            DummyAuth,
+        );
+
+        let result = client
+            .list_models("0.99.0", HeaderMap::new())
+            .await
+            .expect("request should succeed");
+
+        assert_eq!(result.models.len(), 1);
+        assert_eq!(result.models[0].slug, "gpt-test");
+        assert_eq!(result.models[0].supported_in_api, true);
+        assert_eq!(result.models[0].priority, 1);
+    }
+}
diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs
index acde4b458..d0c382ac8 100644
--- a/codex-rs/codex-api/src/lib.rs
+++ b/codex-rs/codex-api/src/lib.rs
@@ -22,6 +22,7 @@ pub use crate::common::create_text_param_for_request;
 pub use crate::endpoint::chat::AggregateStreamExt;
 pub use crate::endpoint::chat::ChatClient;
 pub use crate::endpoint::compact::CompactClient;
+pub use crate::endpoint::models::ModelsClient;
 pub use crate::endpoint::responses::ResponsesClient;
 pub use crate::endpoint::responses::ResponsesOptions;
 pub use crate::error::ApiError;
diff --git a/codex-rs/codex-api/tests/models_integration.rs b/codex-rs/codex-api/tests/models_integration.rs
new file mode 100644
index 000000000..9994fe1d4
--- /dev/null
+++ b/codex-rs/codex-api/tests/models_integration.rs
@@ -0,0 +1,100 @@
+use codex_api::AuthProvider;
+use codex_api::ModelsClient;
+use codex_api::provider::Provider;
+use codex_api::provider::RetryConfig;
+use codex_api::provider::WireApi;
+use codex_client::ReqwestTransport;
+use codex_protocol::openai_models::ClientVersion;
+use codex_protocol::openai_models::ModelInfo;
+use codex_protocol::openai_models::ModelVisibility;
+use codex_protocol::openai_models::ModelsResponse;
+use codex_protocol::openai_models::ReasoningLevel;
+use codex_protocol::openai_models::ShellType;
+use http::HeaderMap;
+use http::Method;
+use wiremock::Mock;
+use wiremock::MockServer;
+use wiremock::ResponseTemplate;
+use wiremock::matchers::method;
+use wiremock::matchers::path;
+
+#[derive(Clone, Default)]
+struct DummyAuth;
+
+impl AuthProvider for DummyAuth {
+    fn bearer_token(&self) -> Option<String> {
+        None
+    }
+}
+
+fn provider(base_url: &str) -> Provider {
+    Provider {
+        name: "test".to_string(),
+        base_url: base_url.to_string(),
+        query_params: None,
+        wire: WireApi::Responses,
+        headers: HeaderMap::new(),
+        retry: RetryConfig {
+            max_attempts: 1,
+            base_delay: std::time::Duration::from_millis(1),
+            retry_429: false,
+            retry_5xx: true,
+            retry_transport: true,
+        },
+        stream_idle_timeout: std::time::Duration::from_secs(1),
+    }
+}
+
+#[tokio::test]
+async fn models_client_hits_models_endpoint() {
+    let server = MockServer::start().await;
+    let base_url = format!("{}/api/codex", server.uri());
+
+    let response = ModelsResponse {
+        models: vec![ModelInfo {
+            slug: "gpt-test".to_string(),
+            display_name: "gpt-test".to_string(),
+            description: Some("desc".to_string()),
+            default_reasoning_level: ReasoningLevel::Medium,
+            supported_reasoning_levels: vec![
+                ReasoningLevel::Low,
+                ReasoningLevel::Medium,
+                ReasoningLevel::High,
+            ],
+            shell_type: ShellType::ShellCommand,
+            visibility: ModelVisibility::List,
+            minimal_client_version: ClientVersion(0, 1, 0),
+            supported_in_api: true,
+            priority: 1,
+        }],
+    };
+
+    Mock::given(method("GET"))
+        .and(path("/api/codex/models"))
+        .respond_with(
+            ResponseTemplate::new(200)
+                .insert_header("content-type", "application/json")
+                .set_body_json(&response),
+        )
+        .mount(&server)
+        .await;
+
+    let transport = ReqwestTransport::new(reqwest::Client::new());
+    let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);
+
+    let result = client
+        .list_models("0.1.0", HeaderMap::new())
+        .await
+        .expect("models request should succeed");
+
+    assert_eq!(result.models.len(), 1);
+    assert_eq!(result.models[0].slug, "gpt-test");
+
+    let received = server
+        .received_requests()
+        .await
+        .expect("should capture requests");
+    assert_eq!(received.len(), 1);
+    assert_eq!(received[0].method, Method::GET.as_str());
+    assert_eq!(received[0].url.path(), "/api/codex/models");
+}
diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs
index a8209b513..e42b4ac94 100644
--- a/codex-rs/core/tests/common/responses.rs
+++ b/codex-rs/core/tests/common/responses.rs
@@ -3,6 +3,7 @@ use std::sync::Mutex;
 
 use anyhow::Result;
 use base64::Engine;
+use codex_protocol::openai_models::ModelsResponse;
 use serde_json::Value;
 use wiremock::BodyPrintLimit;
 use wiremock::Match;
@@ -193,6 +194,38 @@ impl ResponsesRequest {
     }
 }
 
+#[derive(Debug, Clone)]
+pub struct ModelsMock {
+    requests: Arc<Mutex<Vec<wiremock::Request>>>,
+}
+
+impl ModelsMock {
+    fn new() -> Self {
+        Self {
+            requests: Arc::new(Mutex::new(Vec::new())),
+        }
+    }
+
+    pub fn requests(&self) -> Vec<wiremock::Request> {
+        self.requests.lock().unwrap().clone()
+    }
+
+    pub fn single_request_path(&self) -> String {
+        let requests = self.requests.lock().unwrap();
+        if requests.len() != 1 {
+            panic!("expected 1 request, got {}", requests.len());
+        }
+        requests.first().unwrap().url.path().to_string()
+    }
+}
+
+impl Match for ModelsMock {
+    fn matches(&self, request: &wiremock::Request) -> bool {
+        self.requests.lock().unwrap().push(request.clone());
+        true
+    }
+}
+
 impl Match for ResponseMock {
     fn matches(&self, request: &wiremock::Request) -> bool {
         self.requests
@@ -560,6 +593,14 @@ fn compact_mock() -> (MockBuilder, ResponseMock) {
     (mock, response_mock)
 }
 
+fn models_mock() -> (MockBuilder, ModelsMock) {
+    let models_mock = ModelsMock::new();
+    let mock = Mock::given(method("GET"))
+        .and(path_regex(".*/models$"))
+        .and(models_mock.clone());
+    (mock, models_mock)
+}
+
 pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String) -> ResponseMock
 where
     M: wiremock::Match + Send + Sync + 'static,
@@ -616,11 +657,29 @@ pub async fn mount_compact_json_once(server: &MockServer, body: serde_json::Value
     response_mock
 }
 
+pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> ModelsMock {
+    let (mock, models_mock) = models_mock();
+    mock.respond_with(
+        ResponseTemplate::new(200)
+            .insert_header("content-type", "application/json")
+            .set_body_json(body.clone()),
+    )
+    .up_to_n_times(1)
+    .mount(server)
+    .await;
+    models_mock
+}
+
 pub async fn start_mock_server() -> MockServer {
-    MockServer::builder()
+    let server = MockServer::builder()
         .body_print_limit(BodyPrintLimit::Limited(80_000))
         .start()
-        .await
+        .await;
+
+    // Provide a default `/models` response so tests remain hermetic when the client queries it.
+    let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
+
+    server
 }
 
 #[derive(Clone)]
diff --git a/codex-rs/protocol/src/openai_models.rs b/codex-rs/protocol/src/openai_models.rs
index b99c3bbde..92dcf3e4c 100644
--- a/codex-rs/protocol/src/openai_models.rs
+++ b/codex-rs/protocol/src/openai_models.rs
@@ -73,3 +73,77 @@ pub struct ModelPreset {
     /// Whether this preset should appear in the picker UI.
     pub show_in_picker: bool,
 }
+
+/// Visibility of a model in the picker or APIs.
+#[derive(
+    Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
+)]
+#[serde(rename_all = "lowercase")]
+#[strum(serialize_all = "lowercase")]
+pub enum ModelVisibility {
+    List,
+    Hide,
+    None,
+}
+
+/// Reasoning support level reported by the backend.
+#[derive(
+    Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
+)]
+#[serde(rename_all = "lowercase")]
+#[strum(serialize_all = "lowercase")]
+pub enum ReasoningLevel {
+    None,
+    Minimal,
+    Low,
+    Medium,
+    High,
+    XHigh,
+}
+
+/// Shell execution capability for a model.
+#[derive(
+    Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
+)]
+#[serde(rename_all = "snake_case")]
+#[strum(serialize_all = "snake_case")]
+pub enum ShellType {
+    Default,
+    Local,
+    UnifiedExec,
+    Disabled,
+    ShellCommand,
+}
+
+/// Semantic version triple encoded as an array in JSON (e.g. [0, 62, 0]).
+#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema)]
+pub struct ClientVersion(pub i32, pub i32, pub i32);
+
+/// Model metadata returned by the Codex backend `/models` endpoint.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema)]
+pub struct ModelInfo {
+    pub slug: String,
+    pub display_name: String,
+    #[serde(default)]
+    pub description: Option<String>,
+    pub default_reasoning_level: ReasoningLevel,
+    pub supported_reasoning_levels: Vec<ReasoningLevel>,
+    pub shell_type: ShellType,
+    #[serde(default = "default_visibility")]
+    pub visibility: ModelVisibility,
+    pub minimal_client_version: ClientVersion,
+    #[serde(default)]
+    pub supported_in_api: bool,
+    #[serde(default)]
+    pub priority: i32,
+}
+
+/// Response wrapper for `/models`.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema, Default)]
+pub struct ModelsResponse {
+    pub models: Vec<ModelInfo>,
+}
+
+fn default_visibility() -> ModelVisibility {
+    ModelVisibility::None
+}
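
Usage sketch (not part of the patch): a minimal example of how a caller could wire the new endpoint together, using only items added or exercised above (`ModelsClient`, `Provider`, `AuthProvider`, `ReqwestTransport`, `list_models`). The wrapper function and its name are illustrative; a real caller would supply its own configured `Provider` and auth implementation.

use codex_api::ApiError;
use codex_api::AuthProvider;
use codex_api::ModelsClient;
use codex_api::provider::Provider;
use codex_client::ReqwestTransport;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;

// Illustrative helper (not part of this patch): fetch the model list with an
// already-configured provider and auth implementation. `list_models` issues
// GET {base_url}/models?client_version=<client_version> and decodes the JSON
// body into `ModelsResponse`.
async fn fetch_models<A: AuthProvider>(
    provider: Provider,
    auth: A,
    client_version: &str,
) -> Result<ModelsResponse, ApiError> {
    let transport = ReqwestTransport::new(reqwest::Client::new());
    let client = ModelsClient::new(transport, provider, auth);
    client.list_models(client_version, HeaderMap::new()).await
}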