## Overview

Adds LM Studio OSS support. Closes #1883

### Changes

This PR enhances the behavior of the `--oss` flag to support LM Studio as a provider. Additionally, it introduces a new flag, `--local-provider`, which accepts `lmstudio` or `ollama` as values for users who want to explicitly choose which one to use. If no provider is specified, `codex --oss` auto-selects the provider based on whichever is running.

#### Additional enhancements

The default can be set using `oss_provider` in the config, like:

```
oss_provider = "lmstudio"
```

In non-interactive mode, users need to either pass the provider as an argument or set it in their `config.toml`.

### Notes

For best performance, [set the default context length](https://lmstudio.ai/docs/app/advanced/per-model) for gpt-oss to the maximum your machine can support.

---------

Co-authored-by: Matt Clayton <matt@lmstudio.ai>
Co-authored-by: Eric Traut <etraut@openai.com>
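A minimal sketch of the resulting invocations, assuming only the flag and config names described above (the exact CLI surface may differ):

```
# Explicitly choose LM Studio as the local provider
codex --oss --local-provider lmstudio

# Or let codex auto-select whichever local server (LM Studio or Ollama) is running
codex --oss
```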
use codex_core::LMSTUDIO_OSS_PROVIDER_ID;
use codex_core::config::Config;
use std::io;
use std::path::Path;

#[derive(Clone)]
pub struct LMStudioClient {
    client: reqwest::Client,
    base_url: String,
}

const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'.";

impl LMStudioClient {
    /// Build a client from the built-in LM Studio provider in `config` and
    /// verify that the local server is reachable.
    pub async fn try_from_provider(config: &Config) -> std::io::Result<Self> {
        let provider = config
            .model_providers
            .get(LMSTUDIO_OSS_PROVIDER_ID)
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("Built-in provider {LMSTUDIO_OSS_PROVIDER_ID} not found"),
                )
            })?;
        let base_url = provider.base_url.as_ref().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "oss provider must have a base_url",
            )
        })?;

        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());

        let client = LMStudioClient {
            client,
            base_url: base_url.to_string(),
        };
        client.check_server().await?;

        Ok(client)
    }

    /// Confirm the LM Studio server responds on its `/models` endpoint.
    async fn check_server(&self) -> io::Result<()> {
        let url = format!("{}/models", self.base_url.trim_end_matches('/'));
        let response = self.client.get(&url).send().await;

        if let Ok(resp) = response {
            if resp.status().is_success() {
                Ok(())
            } else {
                Err(io::Error::other(format!(
                    "Server returned error: {}. {LMSTUDIO_CONNECTION_ERROR}",
                    resp.status()
                )))
            }
        } else {
            Err(io::Error::other(LMSTUDIO_CONNECTION_ERROR))
        }
    }

    /// Load a model by sending an empty request with `max_output_tokens` set to 1.
    pub async fn load_model(&self, model: &str) -> io::Result<()> {
        let url = format!("{}/responses", self.base_url.trim_end_matches('/'));

        let request_body = serde_json::json!({
            "model": model,
            "input": "",
            "max_output_tokens": 1
        });

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;

        if response.status().is_success() {
            tracing::info!("Successfully loaded model '{model}'");
            Ok(())
        } else {
            Err(io::Error::other(format!(
                "Failed to load model: {}",
                response.status()
            )))
        }
    }

    /// Return the list of models available on the LM Studio server.
    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
        let url = format!("{}/models", self.base_url.trim_end_matches('/'));
        let response = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;

        if response.status().is_success() {
            let json: serde_json::Value = response.json().await.map_err(|e| {
                io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}"))
            })?;
            let models = json["data"]
                .as_array()
                .ok_or_else(|| {
                    io::Error::new(io::ErrorKind::InvalidData, "No 'data' array in response")
                })?
                .iter()
                .filter_map(|model| model["id"].as_str())
                .map(std::string::ToString::to_string)
                .collect();
            Ok(models)
        } else {
            Err(io::Error::other(format!(
                "Failed to fetch models: {}",
                response.status()
            )))
        }
    }

    // Find lms, checking fallback paths if not in PATH
    fn find_lms() -> std::io::Result<String> {
        Self::find_lms_with_home_dir(None)
    }

    fn find_lms_with_home_dir(home_dir: Option<&str>) -> std::io::Result<String> {
        // First try 'lms' in PATH
        if which::which("lms").is_ok() {
            return Ok("lms".to_string());
        }

        // Platform-specific fallback paths
        let home = match home_dir {
            Some(dir) => dir.to_string(),
            None => {
                #[cfg(unix)]
                {
                    std::env::var("HOME").unwrap_or_default()
                }
                #[cfg(windows)]
                {
                    std::env::var("USERPROFILE").unwrap_or_default()
                }
            }
        };

        #[cfg(unix)]
        let fallback_path = format!("{home}/.lmstudio/bin/lms");

        #[cfg(windows)]
        let fallback_path = format!("{home}/.lmstudio/bin/lms.exe");

        if Path::new(&fallback_path).exists() {
            Ok(fallback_path)
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "LM Studio not found. Please install LM Studio from https://lmstudio.ai/",
            ))
        }
    }

    /// Download a model using the `lms` CLI (`lms get --yes <model>`).
    pub async fn download_model(&self, model: &str) -> std::io::Result<()> {
        let lms = Self::find_lms()?;
        eprintln!("Downloading model: {model}");

        let status = std::process::Command::new(&lms)
            .args(["get", "--yes", model])
            .stdout(std::process::Stdio::inherit())
            .stderr(std::process::Stdio::null())
            .status()
            .map_err(|e| {
                std::io::Error::other(format!("Failed to execute '{lms} get --yes {model}': {e}"))
            })?;

        if !status.success() {
            return Err(std::io::Error::other(format!(
                "Model download failed with exit code: {}",
                status.code().unwrap_or(-1)
            )));
        }

        tracing::info!("Successfully downloaded model '{model}'");
        Ok(())
    }

    /// Low-level constructor given a raw host root, e.g. "http://localhost:1234".
    #[cfg(test)]
    fn from_host_root(host_root: impl Into<String>) -> Self {
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            client,
            base_url: host_root.into(),
        }
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;

    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_happy_path",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_raw(
                    serde_json::json!({
                        "data": [
                            {"id": "openai/gpt-oss-20b"},
                        ]
                    })
                    .to_string(),
                    "application/json",
                ),
            )
            .mount(&server)
            .await;

        let client = LMStudioClient::from_host_root(server.uri());
        let models = client.fetch_models().await.expect("fetch models");
        assert!(models.contains(&"openai/gpt-oss-20b".to_string()));
    }

    #[tokio::test]
    async fn test_fetch_models_no_data_array() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_no_data_array",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(
                wiremock::ResponseTemplate::new(200)
                    .set_body_raw(serde_json::json!({}).to_string(), "application/json"),
            )
            .mount(&server)
            .await;

        let client = LMStudioClient::from_host_root(server.uri());
        let result = client.fetch_models().await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No 'data' array in response")
        );
    }

    #[tokio::test]
    async fn test_fetch_models_server_error() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_server_error",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(wiremock::ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let client = LMStudioClient::from_host_root(server.uri());
        let result = client.fetch_models().await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Failed to fetch models: 500")
        );
    }

    #[tokio::test]
    async fn test_check_server_happy_path() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_check_server_happy_path",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let client = LMStudioClient::from_host_root(server.uri());
        client
            .check_server()
            .await
            .expect("server check should pass");
    }

    #[tokio::test]
    async fn test_check_server_error() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_check_server_error",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(wiremock::ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let client = LMStudioClient::from_host_root(server.uri());
        let result = client.check_server().await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Server returned error: 404")
        );
    }

    #[test]
    fn test_find_lms() {
        let result = LMStudioClient::find_lms();

        match result {
            Ok(_) => {
                // lms was found in PATH - that's fine
            }
            Err(e) => {
                // Expected error when LM Studio is not installed
                assert!(e.to_string().contains("LM Studio not found"));
            }
        }
    }

    #[test]
    fn test_find_lms_with_mock_home() {
        // Test fallback path construction without touching env vars
        #[cfg(unix)]
        {
            let result = LMStudioClient::find_lms_with_home_dir(Some("/test/home"));
            if let Err(e) = result {
                assert!(e.to_string().contains("LM Studio not found"));
            }
        }

        #[cfg(windows)]
        {
            let result = LMStudioClient::find_lms_with_home_dir(Some("C:\\test\\home"));
            if let Err(e) = result {
                assert!(e.to_string().contains("LM Studio not found"));
            }
        }
    }

    #[test]
    fn test_from_host_root() {
        let client = LMStudioClient::from_host_root("http://localhost:1234");
        assert_eq!(client.base_url, "http://localhost:1234");

        let client = LMStudioClient::from_host_root("https://example.com:8080/api");
        assert_eq!(client.base_url, "https://example.com:8080/api");
    }
}