Add models endpoint (#7603)

- Use the codex-api crate to introduce the models endpoint.
- Add `models` to the codex core test helpers.
- Add `ModelInfo` and `ModelsResponse` as the endpoint's return types.
This commit is contained in:
Ahmed Ibrahim 2025-12-04 12:57:54 -08:00 committed by GitHub
parent 6e6338aa87
commit 903b7774bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 457 additions and 2 deletions

2
codex-rs/Cargo.lock generated
View file

@ -858,6 +858,7 @@ dependencies = [
"http",
"pretty_assertions",
"regex-lite",
"reqwest",
"serde",
"serde_json",
"thiserror 2.0.17",
@ -865,6 +866,7 @@ dependencies = [
"tokio-test",
"tokio-util",
"tracing",
"wiremock",
]
[[package]]

View file

@ -25,6 +25,8 @@ anyhow = { workspace = true }
assert_matches = { workspace = true }
pretty_assertions = { workspace = true }
tokio-test = { workspace = true }
wiremock = { workspace = true }
reqwest = { workspace = true }
[lints]
workspace = true

View file

@ -1,4 +1,5 @@
pub mod chat;
pub mod compact;
pub mod models;
pub mod responses;
mod streaming;

View file

@ -0,0 +1,216 @@
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
use std::sync::Arc;
/// HTTP client for the provider's `models` endpoint.
///
/// Generic over the transport that executes requests (`T`) and the auth
/// provider whose headers are attached to each request (`A`).
pub struct ModelsClient<T: HttpTransport, A: AuthProvider> {
/// Transport used to execute the built request.
transport: T,
/// Builds the base request and supplies the retry configuration.
provider: Provider,
/// Passed to `add_auth_headers` for every outgoing request.
auth: A,
/// Optional telemetry hook forwarded to `run_with_request_telemetry`;
/// `None` until set through `with_telemetry`.
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
}
impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
    /// Creates a client from a transport, provider, and auth source.
    ///
    /// Request telemetry is disabled until [`Self::with_telemetry`] is called.
    pub fn new(transport: T, provider: Provider, auth: A) -> Self {
        Self {
            transport,
            provider,
            auth,
            request_telemetry: None,
        }
    }

    /// Replaces the request telemetry hook (pass `None` to disable it) and
    /// returns the updated client.
    pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
        self.request_telemetry = request;
        self
    }

    /// Path segment of the models endpoint, handed to the provider when the
    /// request is built.
    fn path(&self) -> &'static str {
        "models"
    }

    /// Issues a `GET` against the models endpoint and decodes the JSON body.
    ///
    /// `client_version` is sent as the `client_version` query parameter
    /// (percent-encoded), appended with `&` when the provider URL already has
    /// a query string and with `?` otherwise. `extra_headers` are merged into
    /// the request headers before the auth headers are added.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the telemetry/transport wrapper, and
    /// returns `ApiError::Stream` (including the raw body, lossily decoded as
    /// UTF-8) when the response body is not a valid `ModelsResponse`.
    pub async fn list_models(
        &self,
        client_version: &str,
        extra_headers: HeaderMap,
    ) -> Result<ModelsResponse, ApiError> {
        // Encode once up front: the builder closure may be invoked more than
        // once (it clones `extra_headers` for that reason), and the encoded
        // value never changes between invocations.
        let encoded_version = encode_query_value(client_version);
        let builder = || {
            let mut req = self.provider.build_request(Method::GET, self.path());
            req.headers.extend(extra_headers.clone());
            let separator = if req.url.contains('?') { '&' } else { '?' };
            req.url = format!("{}{separator}client_version={encoded_version}", req.url);
            add_auth_headers(&self.auth, req)
        };

        let resp = run_with_request_telemetry(
            self.provider.retry.to_policy(),
            self.request_telemetry.clone(),
            builder,
            |req| self.transport.execute(req),
        )
        .await?;

        serde_json::from_slice::<ModelsResponse>(&resp.body).map_err(|e| {
            ApiError::Stream(format!(
                "failed to decode models response: {e}; body: {}",
                String::from_utf8_lossy(&resp.body)
            ))
        })
    }
}

/// Percent-encodes `value` for use as a URL query parameter value.
///
/// RFC 3986 "unreserved" characters (ASCII letters, digits, `-`, `.`, `_`,
/// `~`) pass through unchanged, so ordinary versions such as `0.99.0` are not
/// altered. Every other byte is emitted as `%XX`, which keeps characters like
/// `+`, `&`, `#`, or spaces from changing the meaning of the query string.
fn encode_query_value(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                encoded.push(char::from(byte));
            }
            other => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(other >> 4)]));
                encoded.push(char::from(HEX[usize::from(other & 0x0f)]));
            }
        }
    }
    encoded
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::RetryConfig;
    use crate::provider::WireApi;
    use async_trait::async_trait;
    use codex_client::Request;
    use codex_client::Response;
    use codex_client::StreamResponse;
    use codex_client::TransportError;
    use http::HeaderMap;
    use http::StatusCode;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::time::Duration;

    /// Transport double that records the most recent request and answers
    /// every `execute` call with `body` serialized as JSON and status 200.
    #[derive(Clone, Default)]
    struct CapturingTransport {
        last_request: Arc<Mutex<Option<Request>>>,
        body: Arc<ModelsResponse>,
    }

    #[async_trait]
    impl HttpTransport for CapturingTransport {
        async fn execute(&self, req: Request) -> Result<Response, TransportError> {
            *self.last_request.lock().unwrap() = Some(req);
            let body = serde_json::to_vec(&*self.body).unwrap();
            Ok(Response {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: body.into(),
            })
        }

        // The models client only uses `execute`; reaching this is a test failure.
        async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
            Err(TransportError::Build("stream should not run".to_string()))
        }
    }

    /// Auth provider that supplies no bearer token.
    #[derive(Clone, Default)]
    struct DummyAuth;

    impl AuthProvider for DummyAuth {
        fn bearer_token(&self) -> Option<String> {
            None
        }
    }

    /// Builds a minimal provider rooted at `base_url` with a single attempt
    /// and no extra headers or query parameters.
    fn provider(base_url: &str) -> Provider {
        Provider {
            name: "test".to_string(),
            base_url: base_url.to_string(),
            query_params: None,
            wire: WireApi::Responses,
            headers: HeaderMap::new(),
            retry: RetryConfig {
                max_attempts: 1,
                base_delay: Duration::from_millis(1),
                retry_429: false,
                retry_5xx: true,
                retry_transport: true,
            },
            stream_idle_timeout: Duration::from_secs(1),
        }
    }

    /// The request URL must be the provider's models URL with the
    /// `client_version` query parameter appended.
    #[tokio::test]
    async fn appends_client_version_query() {
        let response = ModelsResponse { models: Vec::new() };

        let transport = CapturingTransport {
            last_request: Arc::new(Mutex::new(None)),
            body: Arc::new(response),
        };

        let client = ModelsClient::new(
            transport.clone(),
            provider("https://example.com/api/codex"),
            DummyAuth,
        );

        let result = client
            .list_models("0.99.0", HeaderMap::new())
            .await
            .expect("request should succeed");

        assert_eq!(result.models.len(), 0);

        let url = transport
            .last_request
            .lock()
            .unwrap()
            .as_ref()
            .unwrap()
            .url
            .clone();
        assert_eq!(
            url,
            "https://example.com/api/codex/models?client_version=0.99.0"
        );
    }

    /// A populated response body must round-trip through the client into
    /// typed `ModelsResponse` data.
    #[tokio::test]
    async fn parses_models_response() {
        let response = ModelsResponse {
            models: vec![
                serde_json::from_value(json!({
                    "slug": "gpt-test",
                    "display_name": "gpt-test",
                    "description": "desc",
                    "default_reasoning_level": "medium",
                    "supported_reasoning_levels": ["low", "medium", "high"],
                    "shell_type": "shell_command",
                    "visibility": "list",
                    "minimal_client_version": [0, 99, 0],
                    "supported_in_api": true,
                    "priority": 1
                }))
                .unwrap(),
            ],
        };

        let transport = CapturingTransport {
            last_request: Arc::new(Mutex::new(None)),
            body: Arc::new(response),
        };

        let client = ModelsClient::new(
            transport,
            provider("https://example.com/api/codex"),
            DummyAuth,
        );

        let result = client
            .list_models("0.99.0", HeaderMap::new())
            .await
            .expect("request should succeed");

        assert_eq!(result.models.len(), 1);
        assert_eq!(result.models[0].slug, "gpt-test");
        // Assert the bool directly instead of comparing it against `true`.
        assert!(result.models[0].supported_in_api);
        assert_eq!(result.models[0].priority, 1);
    }
}

View file

@ -22,6 +22,7 @@ pub use crate::common::create_text_param_for_request;
pub use crate::endpoint::chat::AggregateStreamExt;
pub use crate::endpoint::chat::ChatClient;
pub use crate::endpoint::compact::CompactClient;
pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;
pub use crate::error::ApiError;

View file

@ -0,0 +1,100 @@
use codex_api::AuthProvider;
use codex_api::ModelsClient;
use codex_api::provider::Provider;
use codex_api::provider::RetryConfig;
use codex_api::provider::WireApi;
use codex_client::ReqwestTransport;
use codex_protocol::openai_models::ClientVersion;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelVisibility;
use codex_protocol::openai_models::ModelsResponse;
use codex_protocol::openai_models::ReasoningLevel;
use codex_protocol::openai_models::ShellType;
use http::HeaderMap;
use http::Method;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
/// Auth provider stub for this test: it never supplies a bearer token.
#[derive(Clone, Default)]
struct DummyAuth;
impl AuthProvider for DummyAuth {
// Always `None`; this stub carries no credentials.
fn bearer_token(&self) -> Option<String> {
None
}
}
/// Builds a bare-bones `Provider` pointing at `base_url`.
///
/// The provider has no extra headers or query parameters and is limited to a
/// single attempt with a 1 ms base delay.
fn provider(base_url: &str) -> Provider {
    let retry = RetryConfig {
        max_attempts: 1,
        base_delay: std::time::Duration::from_millis(1),
        retry_429: false,
        retry_5xx: true,
        retry_transport: true,
    };

    Provider {
        name: String::from("test"),
        base_url: base_url.to_owned(),
        query_params: None,
        wire: WireApi::Responses,
        headers: HeaderMap::new(),
        retry,
        stream_idle_timeout: std::time::Duration::from_secs(1),
    }
}
/// Drives `ModelsClient` over a real HTTP transport against a wiremock
/// server and checks both the decoded payload and the request the server saw.
#[tokio::test]
async fn models_client_hits_models_endpoint() {
    let server = MockServer::start().await;

    // Canned payload the mock server returns for the models request.
    let model = ModelInfo {
        slug: "gpt-test".to_string(),
        display_name: "gpt-test".to_string(),
        description: Some("desc".to_string()),
        default_reasoning_level: ReasoningLevel::Medium,
        supported_reasoning_levels: vec![
            ReasoningLevel::Low,
            ReasoningLevel::Medium,
            ReasoningLevel::High,
        ],
        shell_type: ShellType::ShellCommand,
        visibility: ModelVisibility::List,
        minimal_client_version: ClientVersion(0, 1, 0),
        supported_in_api: true,
        priority: 1,
    };
    let payload = ModelsResponse {
        models: vec![model],
    };

    let template = ResponseTemplate::new(200)
        .insert_header("content-type", "application/json")
        .set_body_json(&payload);
    Mock::given(method("GET"))
        .and(path("/api/codex/models"))
        .respond_with(template)
        .mount(&server)
        .await;

    let base_url = format!("{}/api/codex", server.uri());
    let client = ModelsClient::new(
        ReqwestTransport::new(reqwest::Client::new()),
        provider(&base_url),
        DummyAuth,
    );

    let listed = client
        .list_models("0.1.0", HeaderMap::new())
        .await
        .expect("models request should succeed");
    assert_eq!(listed.models.len(), 1);
    assert_eq!(listed.models[0].slug, "gpt-test");

    // Exactly one GET should have reached the models path.
    let seen = server
        .received_requests()
        .await
        .expect("should capture requests");
    assert_eq!(seen.len(), 1);
    assert_eq!(seen[0].method, Method::GET.as_str());
    assert_eq!(seen[0].url.path(), "/api/codex/models");
}

View file

@ -3,6 +3,7 @@ use std::sync::Mutex;
use anyhow::Result;
use base64::Engine;
use codex_protocol::openai_models::ModelsResponse;
use serde_json::Value;
use wiremock::BodyPrintLimit;
use wiremock::Match;
@ -193,6 +194,38 @@ impl ResponsesRequest {
}
}
/// Recorder for requests seen by the models mock.
///
/// Clones share one underlying request log, so the copy registered as a
/// matcher and the copy handed back to the test observe the same requests.
#[derive(Debug, Clone)]
pub struct ModelsMock {
    requests: Arc<Mutex<Vec<wiremock::Request>>>,
}

impl ModelsMock {
    /// Creates a recorder with an empty request log.
    fn new() -> Self {
        let requests = Arc::new(Mutex::new(Vec::new()));
        Self { requests }
    }

    /// Returns a snapshot (clone) of every request recorded so far.
    pub fn requests(&self) -> Vec<wiremock::Request> {
        let recorded = self.requests.lock().unwrap();
        recorded.clone()
    }

    /// Returns the URL path of the only recorded request.
    ///
    /// # Panics
    ///
    /// Panics unless exactly one request has been recorded.
    pub fn single_request_path(&self) -> String {
        let recorded = self.requests.lock().unwrap();
        match recorded.as_slice() {
            [only] => only.url.path().to_string(),
            other => panic!("expected 1 request, got {}", other.len()),
        }
    }
}
impl Match for ModelsMock {
    /// Records a clone of `request` and always reports a match, so this
    /// matcher only observes traffic and never filters it.
    fn matches(&self, request: &wiremock::Request) -> bool {
        let mut recorded = self.requests.lock().unwrap();
        recorded.push(request.clone());
        true
    }
}
impl Match for ResponseMock {
fn matches(&self, request: &wiremock::Request) -> bool {
self.requests
@ -560,6 +593,14 @@ fn compact_mock() -> (MockBuilder, ResponseMock) {
(mock, response_mock)
}
/// Builds a mock for `GET` requests whose path ends in `/models`, paired
/// with the recorder that captures every request the mock is matched against.
fn models_mock() -> (MockBuilder, ModelsMock) {
    let recorder = ModelsMock::new();
    let builder = Mock::given(method("GET"))
        .and(path_regex(".*/models$"))
        .and(recorder.clone());
    (builder, recorder)
}
pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String) -> ResponseMock
where
M: wiremock::Match + Send + Sync + 'static,
@ -616,11 +657,29 @@ pub async fn mount_compact_json_once(server: &MockServer, body: serde_json::Valu
response_mock
}
/// Mounts a models mock on `server` that answers with `body` as JSON.
///
/// The mock replies with status 200 and a JSON content type, and is limited
/// to one response (`up_to_n_times(1)`). The returned [`ModelsMock`] records
/// the requests the mock was matched against.
pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> ModelsMock {
    let (mock, models_mock) = models_mock();
    mock.respond_with(
        ResponseTemplate::new(200)
            .insert_header("content-type", "application/json")
            // `body` is owned and only needs to be serialized here, so borrow
            // it rather than cloning.
            .set_body_json(&body),
    )
    .up_to_n_times(1)
    .mount(server)
    .await;
    models_mock
}
/// Starts a wiremock `MockServer` configured with
/// `BodyPrintLimit::Limited(80_000)` and a default, empty models response
/// already mounted, then returns the server.
// NOTE(review): the bare `MockServer::builder()` and `.await` lines below
// look like the removed side of the diff rendered inline; confirm against
// the real file before relying on this text as compilable source.
pub async fn start_mock_server() -> MockServer {
MockServer::builder()
let server = MockServer::builder()
.body_print_limit(BodyPrintLimit::Limited(80_000))
.start()
.await
.await;
// Provide a default `/models` response so tests remain hermetic when the client queries it.
let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
server
}
#[derive(Clone)]

View file

@ -73,3 +73,77 @@ pub struct ModelPreset {
/// Whether this preset should appear in the picker UI.
pub show_in_picker: bool,
}
/// Visibility of a model in the picker or APIs.
// Wire/display form is the lowercase variant name (`list`, `hide`, `none`),
// per the serde and strum `lowercase` attributes below.
#[derive(
Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum ModelVisibility {
// NOTE(review): presumably "show in listings" — confirm with the backend contract.
List,
// NOTE(review): presumably "known but hidden from listings" — confirm.
Hide,
// Value used by `default_visibility` when the field is absent from the payload.
None,
}
/// Reasoning support level reported by the backend.
// Wire/display form is the lowercase variant name (e.g. `medium`), per the
// serde and strum `lowercase` attributes below.
#[derive(
Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum ReasoningLevel {
// NOTE(review): variants appear to be ordered from least to most reasoning
// effort — confirm before relying on declaration order.
None,
Minimal,
Low,
Medium,
High,
XHigh,
}
/// Shell execution capability for a model.
// Wire/display form is the snake_case variant name (e.g. `shell_command`,
// `unified_exec`), per the serde and strum `snake_case` attributes below.
#[derive(
Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum ShellType {
// NOTE(review): the meaning of each variant is defined by its consumers,
// which are not visible in this file — confirm there.
Default,
Local,
UnifiedExec,
Disabled,
ShellCommand,
}
/// Semantic version triple encoded as an array in JSON (e.g. [0, 62, 0]).
// NOTE(review): the three positions are presumably (major, minor, patch) — confirm.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema)]
pub struct ClientVersion(pub i32, pub i32, pub i32);
/// Model metadata returned by the Codex backend `/models` endpoint.
// Fields without a serde default are required when deserializing.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema)]
pub struct ModelInfo {
// Model identifier string (e.g. "gpt-test" in the tests).
pub slug: String,
// Name for display purposes.
pub display_name: String,
// Optional description; `None` when the field is absent.
#[serde(default)]
pub description: Option<String>,
// Reasoning level used when none is chosen explicitly — TODO confirm semantics.
pub default_reasoning_level: ReasoningLevel,
// Reasoning levels the model accepts.
pub supported_reasoning_levels: Vec<ReasoningLevel>,
pub shell_type: ShellType,
// Falls back to `default_visibility()` (`ModelVisibility::None`) when absent.
#[serde(default = "default_visibility")]
pub visibility: ModelVisibility,
// NOTE(review): presumably the oldest client version allowed to use this model — confirm.
pub minimal_client_version: ClientVersion,
// `false` when the field is absent.
#[serde(default)]
pub supported_in_api: bool,
// `0` when the field is absent. NOTE(review): ordering direction (higher vs
// lower first) is not established in this file — confirm with consumers.
#[serde(default)]
pub priority: i32,
}
/// Response wrapper for `/models`.
// `Default` yields an empty model list.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema, Default)]
pub struct ModelsResponse {
// Models returned by the endpoint, in the order they were received.
pub models: Vec<ModelInfo>,
}
/// Serde default for `ModelInfo::visibility`: used when the field is missing
/// from the payload.
fn default_visibility() -> ModelVisibility {
ModelVisibility::None
}