core-agent-ide/codex-rs/lmstudio/src/client.rs
rugvedS07 837bc98a1d
LM Studio OSS Support (#2312)
## Overview

Adds LM Studio OSS support. Closes #1883


### Changes
This PR extends the `--oss` flag to support LM Studio as a
provider. It also introduces a new flag, `--local-provider`, which
accepts `lmstudio` or `ollama` for users who want to
explicitly choose which provider to use.

If no provider is specified `codex --oss` will auto-select the provider
based on whichever is running.

#### Additional enhancements 
The default can be set with the `oss_provider` key in `config.toml`, for example:

```
oss_provider = "lmstudio"
```

Non-interactive users must either pass the provider as an argument
or set it in their `config.toml`.

### Notes
For best performance, [set the default context
length](https://lmstudio.ai/docs/app/advanced/per-model) for gpt-oss to
the maximum your machine can support.

---------

Co-authored-by: Matt Clayton <matt@lmstudio.ai>
Co-authored-by: Eric Traut <etraut@openai.com>
2025-11-17 11:49:09 -08:00

397 lines
13 KiB
Rust

use codex_core::LMSTUDIO_OSS_PROVIDER_ID;
use codex_core::config::Config;
use std::io;
use std::path::Path;
/// HTTP client for a local LM Studio server.
///
/// `base_url` is the provider's API root; request URLs are formed by
/// appending endpoint paths such as `/models` and `/responses` to it.
#[derive(Clone, Debug)]
pub struct LMStudioClient {
    /// Underlying HTTP client used for every request to the server.
    client: reqwest::Client,
    /// API root of the LM Studio server, with or without a trailing `/`.
    base_url: String,
}

/// User-facing hint returned when the LM Studio server cannot be reached.
const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'.";
impl LMStudioClient {
    /// Builds a client for the built-in LM Studio provider configured in
    /// `config` and verifies that the server answers on its `/models` endpoint.
    ///
    /// # Errors
    ///
    /// * [`io::ErrorKind::NotFound`] if `config.model_providers` has no entry
    ///   for [`LMSTUDIO_OSS_PROVIDER_ID`].
    /// * [`io::ErrorKind::InvalidData`] if that provider has no `base_url`.
    /// * Any error produced by the server reachability check.
    pub async fn try_from_provider(config: &Config) -> io::Result<Self> {
        let provider = config
            .model_providers
            .get(LMSTUDIO_OSS_PROVIDER_ID)
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("Built-in provider {LMSTUDIO_OSS_PROVIDER_ID} not found"),
                )
            })?;
        let base_url = provider.base_url.as_ref().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "oss provider must have a base_url",
            )
        })?;

        let client = LMStudioClient {
            client: Self::build_http_client(),
            base_url: base_url.to_string(),
        };
        client.check_server().await?;
        Ok(client)
    }

    /// Creates the HTTP client used by every constructor.
    ///
    /// Connection attempts are bounded to 5 seconds. If the builder fails,
    /// this falls back to a default client instead of returning an error.
    fn build_http_client() -> reqwest::Client {
        reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new())
    }

    /// Joins `path` onto the configured base URL, tolerating a trailing `/`
    /// on the base so the result never contains `//`.
    fn endpoint(&self, path: &str) -> String {
        format!("{}/{path}", self.base_url.trim_end_matches('/'))
    }

    /// Succeeds when `GET {base_url}/models` returns a 2xx status.
    ///
    /// # Errors
    ///
    /// A non-2xx status yields an error naming the status. A transport
    /// failure yields [`LMSTUDIO_CONNECTION_ERROR`]; the underlying reqwest
    /// error is logged at debug level so it is not lost entirely.
    async fn check_server(&self) -> io::Result<()> {
        let url = self.endpoint("models");
        match self.client.get(&url).send().await {
            Ok(resp) if resp.status().is_success() => Ok(()),
            Ok(resp) => Err(io::Error::other(format!(
                "Server returned error: {} {LMSTUDIO_CONNECTION_ERROR}",
                resp.status()
            ))),
            Err(err) => {
                tracing::debug!("LM Studio server check at {url} failed: {err}");
                Err(io::Error::other(LMSTUDIO_CONNECTION_ERROR))
            }
        }
    }

    /// Requests `model` via a minimal `POST {base_url}/responses` call.
    ///
    /// The call uses empty input and `max_output_tokens: 1`; presumably this
    /// makes LM Studio load the model on demand (TODO confirm against the
    /// LM Studio API docs).
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the server answers
    /// with a non-2xx status.
    pub async fn load_model(&self, model: &str) -> io::Result<()> {
        let url = self.endpoint("responses");
        let request_body = serde_json::json!({
            "model": model,
            "input": "",
            "max_output_tokens": 1
        });
        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
        if response.status().is_success() {
            tracing::info!("Successfully loaded model '{model}'");
            Ok(())
        } else {
            Err(io::Error::other(format!(
                "Failed to load model: {}",
                response.status()
            )))
        }
    }

    /// Returns the `id` of every entry in the `data` array of
    /// `GET {base_url}/models`. Entries without a string `id` are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails, the status is not 2xx, the
    /// body is not JSON, or the body has no `data` array.
    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
        let url = self.endpoint("models");
        let response = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
        if !response.status().is_success() {
            return Err(io::Error::other(format!(
                "Failed to fetch models: {}",
                response.status()
            )));
        }
        let json: serde_json::Value = response.json().await.map_err(|e| {
            io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}"))
        })?;
        let models = json["data"]
            .as_array()
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidData, "No 'data' array in response")
            })?
            .iter()
            .filter_map(|model| model["id"].as_str())
            .map(std::string::ToString::to_string)
            .collect();
        Ok(models)
    }

    /// Locates the `lms` CLI, checking `PATH` first and then the per-user
    /// LM Studio install directory under the current user's home.
    fn find_lms() -> io::Result<String> {
        Self::find_lms_with_home_dir(None)
    }

    /// Locates the `lms` CLI.
    ///
    /// Returns `"lms"` when it is on `PATH`. Otherwise it checks
    /// `<home>/.lmstudio/bin/lms` (`lms.exe` on Windows). `home` is
    /// `home_dir` when given, else `USERPROFILE` on Windows or `HOME`
    /// elsewhere.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::NotFound`] if neither location has the binary,
    /// or if no home directory is known.
    fn find_lms_with_home_dir(home_dir: Option<&str>) -> io::Result<String> {
        if which::which("lms").is_ok() {
            return Ok("lms".to_string());
        }

        let home = match home_dir {
            Some(dir) => dir.to_string(),
            None => {
                let home_var = if cfg!(windows) { "USERPROFILE" } else { "HOME" };
                std::env::var(home_var).unwrap_or_default()
            }
        };

        // With an empty home the fallback would degrade to a root- or
        // cwd-relative path, so only probe it when a home directory is known.
        if !home.is_empty() {
            let exe_name = if cfg!(windows) { "lms.exe" } else { "lms" };
            let fallback = Path::new(&home)
                .join(".lmstudio")
                .join("bin")
                .join(exe_name);
            if fallback.exists() {
                return Ok(fallback.to_string_lossy().into_owned());
            }
        }

        Err(io::Error::new(
            io::ErrorKind::NotFound,
            "LM Studio not found. Please install LM Studio from https://lmstudio.ai/",
        ))
    }

    /// Downloads `model` by running `lms get --yes <model>`. The CLI's stdout
    /// is inherited and its stderr is discarded.
    ///
    /// NOTE(review): `std::process::Command::status` blocks the calling thread
    /// until the download finishes, even though this fn is `async`. Consider
    /// `spawn_blocking` or an async process API if the caller shares the
    /// runtime thread with other tasks.
    ///
    /// # Errors
    ///
    /// Returns an error if `lms` cannot be found or spawned, or exits unsuccessfully.
    pub async fn download_model(&self, model: &str) -> io::Result<()> {
        let lms = Self::find_lms()?;
        eprintln!("Downloading model: {model}");
        let status = std::process::Command::new(&lms)
            .args(["get", "--yes", model])
            .stdout(std::process::Stdio::inherit())
            .stderr(std::process::Stdio::null())
            .status()
            .map_err(|e| {
                io::Error::other(format!("Failed to execute '{lms} get --yes {model}': {e}"))
            })?;
        if !status.success() {
            return Err(io::Error::other(format!(
                "Model download failed with exit code: {}",
                status.code().unwrap_or(-1)
            )));
        }
        tracing::info!("Successfully downloaded model '{model}'");
        Ok(())
    }

    /// Low-level constructor given a raw host root, e.g. "http://localhost:1234".
    /// Performs no server check.
    #[cfg(test)]
    fn from_host_root(host_root: impl Into<String>) -> Self {
        Self {
            client: Self::build_http_client(),
            base_url: host_root.into(),
        }
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;

    /// Returns `true` (and logs why) when the sandbox disables networking.
    /// The wiremock-based tests skip themselves in that case.
    fn network_disabled(test_name: &str) -> bool {
        let disabled =
            std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok();
        if disabled {
            tracing::info!(
                "{} is set; skipping {test_name}",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
        }
        disabled
    }

    /// Serves the given JSON body for `GET /models` on a fresh mock server.
    async fn models_server(body: serde_json::Value) -> wiremock::MockServer {
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(
                wiremock::ResponseTemplate::new(200)
                    .set_body_raw(body.to_string(), "application/json"),
            )
            .mount(&server)
            .await;
        server
    }

    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        if network_disabled("test_fetch_models_happy_path") {
            return;
        }
        let server = models_server(serde_json::json!({
            "data": [
                {"id": "openai/gpt-oss-20b"},
            ]
        }))
        .await;
        let client = LMStudioClient::from_host_root(server.uri());
        let models = client.fetch_models().await.expect("fetch models");
        assert!(models.contains(&"openai/gpt-oss-20b".to_string()));
    }

    #[tokio::test]
    async fn test_fetch_models_trailing_slash_base_url() {
        if network_disabled("test_fetch_models_trailing_slash_base_url") {
            return;
        }
        let server = models_server(serde_json::json!({
            "data": [
                {"id": "openai/gpt-oss-20b"},
            ]
        }))
        .await;
        // A trailing `/` on the base URL must not produce a `//models` request.
        let client = LMStudioClient::from_host_root(format!("{}/", server.uri()));
        let models = client.fetch_models().await.expect("fetch models");
        assert_eq!(models, vec!["openai/gpt-oss-20b".to_string()]);
    }

    #[tokio::test]
    async fn test_fetch_models_no_data_array() {
        if network_disabled("test_fetch_models_no_data_array") {
            return;
        }
        let server = models_server(serde_json::json!({})).await;
        let client = LMStudioClient::from_host_root(server.uri());
        let result = client.fetch_models().await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No 'data' array in response")
        );
    }

    #[tokio::test]
    async fn test_fetch_models_server_error() {
        if network_disabled("test_fetch_models_server_error") {
            return;
        }
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(wiremock::ResponseTemplate::new(500))
            .mount(&server)
            .await;
        let client = LMStudioClient::from_host_root(server.uri());
        let result = client.fetch_models().await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Failed to fetch models: 500")
        );
    }

    #[tokio::test]
    async fn test_load_model_happy_path() {
        if network_disabled("test_load_model_happy_path") {
            return;
        }
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/responses"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let client = LMStudioClient::from_host_root(server.uri());
        client
            .load_model("openai/gpt-oss-20b")
            .await
            .expect("load model should succeed");
    }

    #[tokio::test]
    async fn test_load_model_server_error() {
        if network_disabled("test_load_model_server_error") {
            return;
        }
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/responses"))
            .respond_with(wiremock::ResponseTemplate::new(500))
            .mount(&server)
            .await;
        let client = LMStudioClient::from_host_root(server.uri());
        let err = client
            .load_model("openai/gpt-oss-20b")
            .await
            .expect_err("load model should fail on 500");
        assert!(err.to_string().contains("Failed to load model: 500"));
    }

    #[tokio::test]
    async fn test_check_server_happy_path() {
        if network_disabled("test_check_server_happy_path") {
            return;
        }
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let client = LMStudioClient::from_host_root(server.uri());
        client
            .check_server()
            .await
            .expect("server check should pass");
    }

    #[tokio::test]
    async fn test_check_server_error() {
        if network_disabled("test_check_server_error") {
            return;
        }
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/models"))
            .respond_with(wiremock::ResponseTemplate::new(404))
            .mount(&server)
            .await;
        let client = LMStudioClient::from_host_root(server.uri());
        let result = client.check_server().await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Server returned error: 404")
        );
    }

    #[test]
    fn test_find_lms() {
        // Either `lms` is installed (fine) or we get the "not found" error.
        if let Err(e) = LMStudioClient::find_lms() {
            assert!(e.to_string().contains("LM Studio not found"));
        }
    }

    #[test]
    fn test_find_lms_with_mock_home() {
        // Exercise fallback-path construction without touching env vars.
        let home = if cfg!(windows) {
            "C:\\test\\home"
        } else {
            "/test/home"
        };
        if let Err(e) = LMStudioClient::find_lms_with_home_dir(Some(home)) {
            assert!(e.to_string().contains("LM Studio not found"));
        }
    }

    #[test]
    fn test_from_host_root() {
        let client = LMStudioClient::from_host_root("http://localhost:1234");
        assert_eq!(client.base_url, "http://localhost:1234");
        let client = LMStudioClient::from_host_root("https://example.com:8080/api");
        assert_eq!(client.base_url, "https://example.com:8080/api");
    }
}