This makes stdio mcp servers more flexible by allowing users to specify the cwd to run the server command from and adding additional environment variables to be passed through to the server. Example config using the test server in this repo: ```toml [mcp_servers.test_stdio] cwd = "/Users/<user>/code/codex/codex-rs" command = "cargo" args = ["run", "--bin", "test_stdio_server"] env_vars = ["MCP_TEST_VALUE"] ``` @bolinfest I know you hate these env var tests but let's roll with this for now. I may take a stab at the env guard + serial macro at some point.
282 lines
8.3 KiB
Rust
282 lines
8.3 KiB
Rust
use std::collections::HashMap;
|
|
use std::env;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::Context;
|
|
use anyhow::Result;
|
|
use anyhow::anyhow;
|
|
use mcp_types::CallToolResult;
|
|
use reqwest::ClientBuilder;
|
|
use reqwest::header::HeaderMap;
|
|
use reqwest::header::HeaderName;
|
|
use reqwest::header::HeaderValue;
|
|
use rmcp::model::CallToolResult as RmcpCallToolResult;
|
|
use rmcp::service::ServiceError;
|
|
use serde_json::Value;
|
|
use tokio::time;
|
|
|
|
pub(crate) async fn run_with_timeout<F, T>(
|
|
fut: F,
|
|
timeout: Option<Duration>,
|
|
label: &str,
|
|
) -> Result<T>
|
|
where
|
|
F: std::future::Future<Output = Result<T, ServiceError>>,
|
|
{
|
|
if let Some(duration) = timeout {
|
|
let result = time::timeout(duration, fut)
|
|
.await
|
|
.with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?;
|
|
result.map_err(|err| anyhow!("{label} failed: {err}"))
|
|
} else {
|
|
fut.await.map_err(|err| anyhow!("{label} failed: {err}"))
|
|
}
|
|
}
|
|
|
|
pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result<CallToolResult> {
|
|
let mut value = serde_json::to_value(result)?;
|
|
if let Some(obj) = value.as_object_mut()
|
|
&& (obj.get("content").is_none()
|
|
|| obj.get("content").is_some_and(serde_json::Value::is_null))
|
|
{
|
|
obj.insert("content".to_string(), Value::Array(Vec::new()));
|
|
}
|
|
serde_json::from_value(value).context("failed to convert call tool result")
|
|
}
|
|
|
|
/// Convert from mcp-types to Rust SDK types.
|
|
///
|
|
/// The Rust SDK types are the same as our mcp-types crate because they are both
|
|
/// derived from the same MCP specification.
|
|
/// As a result, it should be safe to convert directly from one to the other.
|
|
pub(crate) fn convert_to_rmcp<T, U>(value: T) -> Result<U>
|
|
where
|
|
T: serde::Serialize,
|
|
U: serde::de::DeserializeOwned,
|
|
{
|
|
let json = serde_json::to_value(value)?;
|
|
serde_json::from_value(json).map_err(|err| anyhow!(err))
|
|
}
|
|
|
|
/// Convert from Rust SDK types to mcp-types.
|
|
///
|
|
/// The Rust SDK types are the same as our mcp-types crate because they are both
|
|
/// derived from the same MCP specification.
|
|
/// As a result, it should be safe to convert directly from one to the other.
|
|
pub(crate) fn convert_to_mcp<T, U>(value: T) -> Result<U>
|
|
where
|
|
T: serde::Serialize,
|
|
U: serde::de::DeserializeOwned,
|
|
{
|
|
let json = serde_json::to_value(value)?;
|
|
serde_json::from_value(json).map_err(|err| anyhow!(err))
|
|
}
|
|
|
|
pub(crate) fn create_env_for_mcp_server(
|
|
extra_env: Option<HashMap<String, String>>,
|
|
env_vars: &[String],
|
|
) -> HashMap<String, String> {
|
|
DEFAULT_ENV_VARS
|
|
.iter()
|
|
.copied()
|
|
.chain(env_vars.iter().map(String::as_str))
|
|
.filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value)))
|
|
.chain(extra_env.unwrap_or_default())
|
|
.collect()
|
|
}
|
|
|
|
pub(crate) fn build_default_headers(
|
|
http_headers: Option<HashMap<String, String>>,
|
|
env_http_headers: Option<HashMap<String, String>>,
|
|
) -> Result<HeaderMap> {
|
|
let mut headers = HeaderMap::new();
|
|
|
|
if let Some(static_headers) = http_headers {
|
|
for (name, value) in static_headers {
|
|
let header_name = match HeaderName::from_bytes(name.as_bytes()) {
|
|
Ok(name) => name,
|
|
Err(err) => {
|
|
tracing::warn!("invalid HTTP header name `{name}`: {err}");
|
|
continue;
|
|
}
|
|
};
|
|
let header_value = match HeaderValue::from_str(value.as_str()) {
|
|
Ok(value) => value,
|
|
Err(err) => {
|
|
tracing::warn!("invalid HTTP header value for `{name}`: {err}");
|
|
continue;
|
|
}
|
|
};
|
|
headers.insert(header_name, header_value);
|
|
}
|
|
}
|
|
|
|
if let Some(env_headers) = env_http_headers {
|
|
for (name, env_var) in env_headers {
|
|
if let Ok(value) = env::var(&env_var) {
|
|
if value.trim().is_empty() {
|
|
continue;
|
|
}
|
|
|
|
let header_name = match HeaderName::from_bytes(name.as_bytes()) {
|
|
Ok(name) => name,
|
|
Err(err) => {
|
|
tracing::warn!("invalid HTTP header name `{name}`: {err}");
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let header_value = match HeaderValue::from_str(value.as_str()) {
|
|
Ok(value) => value,
|
|
Err(err) => {
|
|
tracing::warn!(
|
|
"invalid HTTP header value read from {env_var} for `{name}`: {err}"
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
headers.insert(header_name, header_value);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(headers)
|
|
}
|
|
|
|
pub(crate) fn apply_default_headers(
|
|
builder: ClientBuilder,
|
|
default_headers: &HeaderMap,
|
|
) -> ClientBuilder {
|
|
if default_headers.is_empty() {
|
|
builder
|
|
} else {
|
|
builder.default_headers(default_headers.clone())
|
|
}
|
|
}
|
|
|
|
#[cfg(unix)]
|
|
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
|
"HOME",
|
|
"LOGNAME",
|
|
"PATH",
|
|
"SHELL",
|
|
"USER",
|
|
"__CF_USER_TEXT_ENCODING",
|
|
"LANG",
|
|
"LC_ALL",
|
|
"TERM",
|
|
"TMPDIR",
|
|
"TZ",
|
|
];
|
|
|
|
#[cfg(windows)]
|
|
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
|
"PATH",
|
|
"PATHEXT",
|
|
"USERNAME",
|
|
"USERDOMAIN",
|
|
"USERPROFILE",
|
|
"TEMP",
|
|
"TMP",
|
|
];
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use mcp_types::ContentBlock;
|
|
use pretty_assertions::assert_eq;
|
|
use rmcp::model::CallToolResult as RmcpCallToolResult;
|
|
use serde_json::json;
|
|
|
|
use serial_test::serial;
|
|
use std::ffi::OsString;
|
|
|
|
struct EnvVarGuard {
|
|
key: String,
|
|
original: Option<OsString>,
|
|
}
|
|
|
|
impl EnvVarGuard {
|
|
fn set(key: &str, value: &str) -> Self {
|
|
let original = std::env::var_os(key);
|
|
unsafe {
|
|
std::env::set_var(key, value);
|
|
}
|
|
Self {
|
|
key: key.to_string(),
|
|
original,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for EnvVarGuard {
|
|
fn drop(&mut self) {
|
|
if let Some(value) = &self.original {
|
|
unsafe {
|
|
std::env::set_var(&self.key, value);
|
|
}
|
|
} else {
|
|
unsafe {
|
|
std::env::remove_var(&self.key);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn create_env_honors_overrides() {
|
|
let value = "custom".to_string();
|
|
let env =
|
|
create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])), &[]);
|
|
assert_eq!(env.get("TZ"), Some(&value));
|
|
}
|
|
|
|
#[test]
|
|
#[serial(extra_rmcp_env)]
|
|
fn create_env_includes_additional_whitelisted_variables() {
|
|
let custom_var = "EXTRA_RMCP_ENV";
|
|
let value = "from-env";
|
|
let _guard = EnvVarGuard::set(custom_var, value);
|
|
let env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
|
|
assert_eq!(env.get(custom_var), Some(&value.to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn convert_call_tool_result_defaults_missing_content() -> Result<()> {
|
|
let structured_content = json!({ "key": "value" });
|
|
let rmcp_result = RmcpCallToolResult {
|
|
content: vec![],
|
|
structured_content: Some(structured_content.clone()),
|
|
is_error: Some(true),
|
|
meta: None,
|
|
};
|
|
|
|
let result = convert_call_tool_result(rmcp_result)?;
|
|
|
|
assert!(result.content.is_empty());
|
|
assert_eq!(result.structured_content, Some(structured_content));
|
|
assert_eq!(result.is_error, Some(true));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn convert_call_tool_result_preserves_existing_content() -> Result<()> {
|
|
let rmcp_result = RmcpCallToolResult::success(vec![rmcp::model::Content::text("hello")]);
|
|
|
|
let result = convert_call_tool_result(rmcp_result)?;
|
|
|
|
assert_eq!(result.content.len(), 1);
|
|
match &result.content[0] {
|
|
ContentBlock::TextContent(text_content) => {
|
|
assert_eq!(text_content.text, "hello");
|
|
assert_eq!(text_content.r#type, "text");
|
|
}
|
|
other => panic!("expected text content got {other:?}"),
|
|
}
|
|
assert_eq!(result.structured_content, None);
|
|
assert_eq!(result.is_error, Some(false));
|
|
|
|
Ok(())
|
|
}
|
|
}
|