core-agent-ide/codex-rs/rmcp-client/src/utils.rs
Casey Chow b3765a07e8
[rmcp-client] Recover from streamable HTTP 404 sessions (#13514)
## Summary
- add one-time session recovery in `RmcpClient` for streamable HTTP MCP
`404` session expiry
- rebuild the transport and retry the failed operation once after
reinitializing the client state
- extend the test server and integration coverage for `404`, `401`,
single-retry, and non-session failure scenarios

## Testing
- just fmt
- cargo test -p codex-rmcp-client (the post-rebase run lost its final
summary in the terminal; the suite had passed earlier before the rebase)
- just fix -p codex-rmcp-client
2026-03-06 10:02:42 -05:00

194 lines
5.1 KiB
Rust

use anyhow::Result;
use reqwest::ClientBuilder;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use std::collections::HashMap;
use std::env;
/// Builds the environment-variable map for an MCP server.
///
/// The variable names come from two lists:
/// - every name in `DEFAULT_ENV_VARS`;
/// - every additional name in `env_vars`.
///
/// For each name, the value is copied from the current process environment.
/// Names that are unset, or whose value is not valid Unicode, are skipped.
///
/// Entries from `extra_env` are applied last. They therefore override any
/// inherited value with the same name.
pub(crate) fn create_env_for_mcp_server(
    extra_env: Option<HashMap<String, String>>,
    env_vars: &[String],
) -> HashMap<String, String> {
    let mut vars = HashMap::new();

    let inherited_names = DEFAULT_ENV_VARS
        .iter()
        .copied()
        .chain(env_vars.iter().map(String::as_str));
    for name in inherited_names {
        if let Ok(value) = env::var(name) {
            vars.insert(name.to_string(), value);
        }
    }

    // Explicit overrides go in last so they take precedence over inherited values.
    if let Some(overrides) = extra_env {
        vars.extend(overrides);
    }

    vars
}
/// Builds the default header map sent with every HTTP request to an MCP
/// server.
///
/// Headers are gathered from two optional sources, in this order:
///
/// 1. `http_headers`: literal `header name -> header value` pairs.
/// 2. `env_http_headers`: `header name -> environment variable name` pairs.
///    The header value is read from that variable when this function runs.
///    Variables that are unset, not valid Unicode, or contain only whitespace
///    are skipped silently.
///
/// Environment-sourced headers are inserted after the static ones, so they
/// replace a static header with the same name.
///
/// Entries whose name or value is not a valid HTTP header are logged via
/// `tracing::warn!` and skipped, rather than failing the whole map.
///
/// # Errors
///
/// The current implementation never returns `Err`; invalid entries are only
/// logged.
pub(crate) fn build_default_headers(
    http_headers: Option<HashMap<String, String>>,
    env_http_headers: Option<HashMap<String, String>>,
) -> Result<HeaderMap> {
    let mut headers = HeaderMap::new();

    for (name, value) in http_headers.into_iter().flatten() {
        // Both parses are side-effect free. When the name is invalid, only the
        // name error is reported.
        match (
            HeaderName::from_bytes(name.as_bytes()),
            HeaderValue::from_str(&value),
        ) {
            (Ok(header_name), Ok(header_value)) => {
                headers.insert(header_name, header_value);
            }
            (Err(err), _) => {
                tracing::warn!("invalid HTTP header name `{name}`: {err}");
            }
            (Ok(_), Err(err)) => {
                tracing::warn!("invalid HTTP header value for `{name}`: {err}");
            }
        }
    }

    for (name, env_var) in env_http_headers.into_iter().flatten() {
        // Missing variables and blank values contribute nothing and are not logged.
        let Ok(value) = env::var(&env_var) else {
            continue;
        };
        if value.trim().is_empty() {
            continue;
        }
        match (
            HeaderName::from_bytes(name.as_bytes()),
            HeaderValue::from_str(&value),
        ) {
            (Ok(header_name), Ok(header_value)) => {
                headers.insert(header_name, header_value);
            }
            (Err(err), _) => {
                tracing::warn!("invalid HTTP header name `{name}`: {err}");
            }
            (Ok(_), Err(err)) => {
                tracing::warn!(
                    "invalid HTTP header value read from {env_var} for `{name}`: {err}"
                );
            }
        }
    }

    Ok(headers)
}
/// Installs `default_headers` on `builder` as the client's default headers.
///
/// When the map is empty, the builder is returned untouched. This also skips
/// cloning the map.
pub(crate) fn apply_default_headers(
    builder: ClientBuilder,
    default_headers: &HeaderMap,
) -> ClientBuilder {
    if !default_headers.is_empty() {
        return builder.default_headers(default_headers.clone());
    }
    builder
}
/// Environment variables that `create_env_for_mcp_server` copies from the
/// current process (when set) into the MCP server environment on Unix targets.
///
/// NOTE(review): `__CF_USER_TEXT_ENCODING` appears to be macOS-specific. It is
/// presumably included so macOS tooling sees the user's text encoding; confirm.
#[cfg(unix)]
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
    "HOME",
    "LOGNAME",
    "PATH",
    "SHELL",
    "USER",
    "__CF_USER_TEXT_ENCODING",
    "LANG",
    "LC_ALL",
    "TERM",
    "TMPDIR",
    "TZ",
];
/// Environment variables that `create_env_for_mcp_server` copies from the
/// current process (when set) into the MCP server environment on Windows
/// targets.
///
/// The list is grouped by purpose; see the inline comments.
#[cfg(windows)]
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
    // Core path resolution
    "PATH",
    "PATHEXT",
    // Shell and system roots
    "COMSPEC",
    "SYSTEMROOT",
    "SYSTEMDRIVE",
    // User context and profiles
    "USERNAME",
    "USERDOMAIN",
    "USERPROFILE",
    "HOMEDRIVE",
    "HOMEPATH",
    // Program locations
    "PROGRAMFILES",
    "PROGRAMFILES(X86)",
    "PROGRAMW6432",
    "PROGRAMDATA",
    // App data and caches
    "LOCALAPPDATA",
    "APPDATA",
    // Temp locations
    "TEMP",
    "TMP",
    // Common shells/pwsh hints
    "POWERSHELL",
    "PWSH",
];
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serial_test::serial;
    use std::ffi::OsString;

    /// Sets an environment variable for the lifetime of the guard.
    ///
    /// On drop, the previous value is restored, or the variable is removed if
    /// it was unset before.
    struct EnvVarGuard {
        key: String,
        original: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &str, value: &str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: mutating the process environment is unsafe while other
            // threads may read it. Tests that use this guard are annotated with
            // `#[serial(...)]` so they do not overlap each other.
            // NOTE(review): unserialized tests can still run in parallel; confirm
            // none of them read the environment through non-Rust code.
            unsafe {
                std::env::set_var(key, value);
            }
            Self {
                key: key.to_string(),
                original,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as documented in `EnvVarGuard::set`.
            unsafe {
                match &self.original {
                    Some(value) => std::env::set_var(&self.key, value),
                    None => std::env::remove_var(&self.key),
                }
            }
        }
    }

    // Synchronous test: the body never awaits, so no async runtime is needed.
    #[test]
    fn create_env_honors_overrides() {
        let value = "custom".to_string();
        let env =
            create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])), &[]);
        assert_eq!(env.get("TZ"), Some(&value));
    }

    #[test]
    #[serial(extra_rmcp_env)]
    fn create_env_includes_additional_whitelisted_variables() {
        let custom_var = "EXTRA_RMCP_ENV";
        let value = "from-env";
        let _guard = EnvVarGuard::set(custom_var, value);
        let env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
        assert_eq!(env.get(custom_var), Some(&value.to_string()));
    }
}