Add models endpoint (#7603)
- Use the codex-api crate to introduce models endpoint. - Add `models` to codex core tests helpers - Add `ModelsInfo` for the endpoint return type
This commit is contained in:
parent
6e6338aa87
commit
903b7774bc
8 changed files with 457 additions and 2 deletions
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
|
|
@ -858,6 +858,7 @@ dependencies = [
|
|||
"http",
|
||||
"pretty_assertions",
|
||||
"regex-lite",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
|
|
@ -865,6 +866,7 @@ dependencies = [
|
|||
"tokio-test",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ anyhow = { workspace = true }
|
|||
assert_matches = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tokio-test = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
pub mod chat;
|
||||
pub mod compact;
|
||||
pub mod models;
|
||||
pub mod responses;
|
||||
mod streaming;
|
||||
|
|
|
|||
216
codex-rs/codex-api/src/endpoint/models.rs
Normal file
216
codex-rs/codex-api/src/endpoint/models.rs
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
use crate::auth::AuthProvider;
|
||||
use crate::auth::add_auth_headers;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::telemetry::run_with_request_telemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ModelsClient<T: HttpTransport, A: AuthProvider> {
|
||||
transport: T,
|
||||
provider: Provider,
|
||||
auth: A,
|
||||
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
provider,
|
||||
auth,
|
||||
request_telemetry: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
|
||||
self.request_telemetry = request;
|
||||
self
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
"models"
|
||||
}
|
||||
|
||||
pub async fn list_models(
|
||||
&self,
|
||||
client_version: &str,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ModelsResponse, ApiError> {
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::GET, self.path());
|
||||
req.headers.extend(extra_headers.clone());
|
||||
|
||||
let separator = if req.url.contains('?') { '&' } else { '?' };
|
||||
req.url = format!("{}{}client_version={client_version}", req.url, separator);
|
||||
|
||||
add_auth_headers(&self.auth, req)
|
||||
};
|
||||
|
||||
let resp = run_with_request_telemetry(
|
||||
self.provider.retry.to_policy(),
|
||||
self.request_telemetry.clone(),
|
||||
builder,
|
||||
|req| self.transport.execute(req),
|
||||
)
|
||||
.await?;
|
||||
|
||||
serde_json::from_slice::<ModelsResponse>(&resp.body).map_err(|e| {
|
||||
ApiError::Stream(format!(
|
||||
"failed to decode models response: {e}; body: {}",
|
||||
String::from_utf8_lossy(&resp.body)
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::RetryConfig;
|
||||
use crate::provider::WireApi;
|
||||
use async_trait::async_trait;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use http::HeaderMap;
|
||||
use http::StatusCode;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct CapturingTransport {
|
||||
last_request: Arc<Mutex<Option<Request>>>,
|
||||
body: Arc<ModelsResponse>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for CapturingTransport {
|
||||
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
|
||||
*self.last_request.lock().unwrap() = Some(req);
|
||||
let body = serde_json::to_vec(&*self.body).unwrap();
|
||||
Ok(Response {
|
||||
status: StatusCode::OK,
|
||||
headers: HeaderMap::new(),
|
||||
body: body.into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
|
||||
Err(TransportError::Build("stream should not run".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct DummyAuth;
|
||||
|
||||
impl AuthProvider for DummyAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(base_url: &str) -> Provider {
|
||||
Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: base_url.to_string(),
|
||||
query_params: None,
|
||||
wire: WireApi::Responses,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn appends_client_version_query() {
|
||||
let response = ModelsResponse { models: Vec::new() };
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
transport.clone(),
|
||||
provider("https://example.com/api/codex"),
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
.list_models("0.99.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 0);
|
||||
|
||||
let url = transport
|
||||
.last_request
|
||||
.lock()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.url
|
||||
.clone();
|
||||
assert_eq!(
|
||||
url,
|
||||
"https://example.com/api/codex/models?client_version=0.99.0"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parses_models_response() {
|
||||
let response = ModelsResponse {
|
||||
models: vec![
|
||||
serde_json::from_value(json!({
|
||||
"slug": "gpt-test",
|
||||
"display_name": "gpt-test",
|
||||
"description": "desc",
|
||||
"default_reasoning_level": "medium",
|
||||
"supported_reasoning_levels": ["low", "medium", "high"],
|
||||
"shell_type": "shell_command",
|
||||
"visibility": "list",
|
||||
"minimal_client_version": [0, 99, 0],
|
||||
"supported_in_api": true,
|
||||
"priority": 1
|
||||
}))
|
||||
.unwrap(),
|
||||
],
|
||||
};
|
||||
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
transport,
|
||||
provider("https://example.com/api/codex"),
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let result = client
|
||||
.list_models("0.99.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 1);
|
||||
assert_eq!(result.models[0].slug, "gpt-test");
|
||||
assert_eq!(result.models[0].supported_in_api, true);
|
||||
assert_eq!(result.models[0].priority, 1);
|
||||
}
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ pub use crate::common::create_text_param_for_request;
|
|||
pub use crate::endpoint::chat::AggregateStreamExt;
|
||||
pub use crate::endpoint::chat::ChatClient;
|
||||
pub use crate::endpoint::compact::CompactClient;
|
||||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
pub use crate::endpoint::responses::ResponsesOptions;
|
||||
pub use crate::error::ApiError;
|
||||
|
|
|
|||
100
codex-rs/codex-api/tests/models_integration.rs
Normal file
100
codex-rs/codex-api/tests/models_integration.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
use codex_api::AuthProvider;
|
||||
use codex_api::ModelsClient;
|
||||
use codex_api::provider::Provider;
|
||||
use codex_api::provider::RetryConfig;
|
||||
use codex_api::provider::WireApi;
|
||||
use codex_client::ReqwestTransport;
|
||||
use codex_protocol::openai_models::ClientVersion;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::openai_models::ModelVisibility;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use codex_protocol::openai_models::ReasoningLevel;
|
||||
use codex_protocol::openai_models::ShellType;
|
||||
use http::HeaderMap;
|
||||
use http::Method;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct DummyAuth;
|
||||
|
||||
impl AuthProvider for DummyAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(base_url: &str) -> Provider {
|
||||
Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: base_url.to_string(),
|
||||
query_params: None,
|
||||
wire: WireApi::Responses,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: std::time::Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: std::time::Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn models_client_hits_models_endpoint() {
|
||||
let server = MockServer::start().await;
|
||||
let base_url = format!("{}/api/codex", server.uri());
|
||||
|
||||
let response = ModelsResponse {
|
||||
models: vec![ModelInfo {
|
||||
slug: "gpt-test".to_string(),
|
||||
display_name: "gpt-test".to_string(),
|
||||
description: Some("desc".to_string()),
|
||||
default_reasoning_level: ReasoningLevel::Medium,
|
||||
supported_reasoning_levels: vec![
|
||||
ReasoningLevel::Low,
|
||||
ReasoningLevel::Medium,
|
||||
ReasoningLevel::High,
|
||||
],
|
||||
shell_type: ShellType::ShellCommand,
|
||||
visibility: ModelVisibility::List,
|
||||
minimal_client_version: ClientVersion(0, 1, 0),
|
||||
supported_in_api: true,
|
||||
priority: 1,
|
||||
}],
|
||||
};
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/api/codex/models"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "application/json")
|
||||
.set_body_json(&response),
|
||||
)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let transport = ReqwestTransport::new(reqwest::Client::new());
|
||||
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);
|
||||
|
||||
let result = client
|
||||
.list_models("0.1.0", HeaderMap::new())
|
||||
.await
|
||||
.expect("models request should succeed");
|
||||
|
||||
assert_eq!(result.models.len(), 1);
|
||||
assert_eq!(result.models[0].slug, "gpt-test");
|
||||
|
||||
let received = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("should capture requests");
|
||||
assert_eq!(received.len(), 1);
|
||||
assert_eq!(received[0].method, Method::GET.as_str());
|
||||
assert_eq!(received[0].url.path(), "/api/codex/models");
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ use std::sync::Mutex;
|
|||
|
||||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use serde_json::Value;
|
||||
use wiremock::BodyPrintLimit;
|
||||
use wiremock::Match;
|
||||
|
|
@ -193,6 +194,38 @@ impl ResponsesRequest {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelsMock {
|
||||
requests: Arc<Mutex<Vec<wiremock::Request>>>,
|
||||
}
|
||||
|
||||
impl ModelsMock {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
requests: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn requests(&self) -> Vec<wiremock::Request> {
|
||||
self.requests.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn single_request_path(&self) -> String {
|
||||
let requests = self.requests.lock().unwrap();
|
||||
if requests.len() != 1 {
|
||||
panic!("expected 1 request, got {}", requests.len());
|
||||
}
|
||||
requests.first().unwrap().url.path().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Match for ModelsMock {
|
||||
fn matches(&self, request: &wiremock::Request) -> bool {
|
||||
self.requests.lock().unwrap().push(request.clone());
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Match for ResponseMock {
|
||||
fn matches(&self, request: &wiremock::Request) -> bool {
|
||||
self.requests
|
||||
|
|
@ -560,6 +593,14 @@ fn compact_mock() -> (MockBuilder, ResponseMock) {
|
|||
(mock, response_mock)
|
||||
}
|
||||
|
||||
fn models_mock() -> (MockBuilder, ModelsMock) {
|
||||
let models_mock = ModelsMock::new();
|
||||
let mock = Mock::given(method("GET"))
|
||||
.and(path_regex(".*/models$"))
|
||||
.and(models_mock.clone());
|
||||
(mock, models_mock)
|
||||
}
|
||||
|
||||
pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String) -> ResponseMock
|
||||
where
|
||||
M: wiremock::Match + Send + Sync + 'static,
|
||||
|
|
@ -616,11 +657,29 @@ pub async fn mount_compact_json_once(server: &MockServer, body: serde_json::Valu
|
|||
response_mock
|
||||
}
|
||||
|
||||
pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> ModelsMock {
|
||||
let (mock, models_mock) = models_mock();
|
||||
mock.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "application/json")
|
||||
.set_body_json(body.clone()),
|
||||
)
|
||||
.up_to_n_times(1)
|
||||
.mount(server)
|
||||
.await;
|
||||
models_mock
|
||||
}
|
||||
|
||||
pub async fn start_mock_server() -> MockServer {
|
||||
MockServer::builder()
|
||||
let server = MockServer::builder()
|
||||
.body_print_limit(BodyPrintLimit::Limited(80_000))
|
||||
.start()
|
||||
.await
|
||||
.await;
|
||||
|
||||
// Provide a default `/models` response so tests remain hermetic when the client queries it.
|
||||
let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
|
|||
|
|
@ -73,3 +73,77 @@ pub struct ModelPreset {
|
|||
/// Whether this preset should appear in the picker UI.
|
||||
pub show_in_picker: bool,
|
||||
}
|
||||
|
||||
/// Visibility of a model in the picker or APIs.
|
||||
#[derive(
|
||||
Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
|
||||
)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
pub enum ModelVisibility {
|
||||
List,
|
||||
Hide,
|
||||
None,
|
||||
}
|
||||
|
||||
/// Reasoning support level reported by the backend.
|
||||
#[derive(
|
||||
Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
|
||||
)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
pub enum ReasoningLevel {
|
||||
None,
|
||||
Minimal,
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
XHigh,
|
||||
}
|
||||
|
||||
/// Shell execution capability for a model.
|
||||
#[derive(
|
||||
Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema, EnumIter, Display,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum ShellType {
|
||||
Default,
|
||||
Local,
|
||||
UnifiedExec,
|
||||
Disabled,
|
||||
ShellCommand,
|
||||
}
|
||||
|
||||
/// Semantic version triple encoded as an array in JSON (e.g. [0, 62, 0]).
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, TS, JsonSchema)]
|
||||
pub struct ClientVersion(pub i32, pub i32, pub i32);
|
||||
|
||||
/// Model metadata returned by the Codex backend `/models` endpoint.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema)]
|
||||
pub struct ModelInfo {
|
||||
pub slug: String,
|
||||
pub display_name: String,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub default_reasoning_level: ReasoningLevel,
|
||||
pub supported_reasoning_levels: Vec<ReasoningLevel>,
|
||||
pub shell_type: ShellType,
|
||||
#[serde(default = "default_visibility")]
|
||||
pub visibility: ModelVisibility,
|
||||
pub minimal_client_version: ClientVersion,
|
||||
#[serde(default)]
|
||||
pub supported_in_api: bool,
|
||||
#[serde(default)]
|
||||
pub priority: i32,
|
||||
}
|
||||
|
||||
/// Response wrapper for `/models`.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema, Default)]
|
||||
pub struct ModelsResponse {
|
||||
pub models: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
fn default_visibility() -> ModelVisibility {
|
||||
ModelVisibility::None
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue