core-agent-ide/codex-rs/network-proxy/src/mitm.rs
viyatb-oai 8d3d58f992
feat(network-proxy): add MITM support and gate limited-mode CONNECT (#9859)
## Description
- Adds MITM support (CA load/issue, TLS termination, optional body
inspection).
- Adds `codex-network-proxy init` to create
`CODEX_HOME/network_proxy/mitm`.
- Enforces limited-mode HTTPS correctly: `CONNECT` requires MITM,
otherwise blocked with `mitm_required`.
- Keeps `origin/main` layering/reload semantics (managed layers included
in reload checks).
- Centralizes block reasons (`REASON_MITM_REQUIRED`) and removes
`println!`.
- Scope is MITM-only (no SOCKS changes).

gated by `mitm=false` (default)
2026-02-24 18:15:15 +00:00

482 lines
14 KiB
Rust

use crate::certs::ManagedMitmCa;
use crate::config::NetworkMode;
use crate::policy::normalize_host;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::responses::blocked_text_response;
use crate::responses::text_response;
use crate::runtime::HostBlockDecision;
use crate::runtime::HostBlockReason;
use crate::state::BlockedRequest;
use crate::state::BlockedRequestArgs;
use crate::state::NetworkProxyState;
use crate::upstream::UpstreamClient;
use anyhow::Context as _;
use anyhow::Result;
use anyhow::anyhow;
use rama_core::Layer;
use rama_core::Service;
use rama_core::bytes::Bytes;
use rama_core::error::BoxError;
use rama_core::extensions::ExtensionsRef;
use rama_core::futures::stream::Stream;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::BodyDataStream;
use rama_http::HeaderValue;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http::Uri;
use rama_http::header::HOST;
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
use rama_http_backend::server::HttpServer;
use rama_http_backend::server::layer::upgrade::Upgraded;
use rama_net::proxy::ProxyTarget;
use rama_net::stream::SocketInfo;
use rama_tls_rustls::server::TlsAcceptorData;
use rama_tls_rustls::server::TlsAcceptorLayer;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context as TaskContext;
use std::task::Poll;
use tracing::info;
use tracing::warn;
/// State needed to terminate a CONNECT tunnel and enforce policy on inner HTTPS requests.
///
/// Constructed by [`MitmState::new`]; the `Debug` impl in this file intentionally prints only
/// the inspection settings.
pub struct MitmState {
/// Managed CA that yields the TLS acceptor data for a given target host.
ca: ManagedMitmCa,
/// Client through which rewritten inner requests are served upstream.
upstream: UpstreamClient,
/// When true, request and response bodies are wrapped so their size gets logged.
inspect: bool,
/// Byte threshold; a body longer than this is logged with `truncated=true`.
max_body_bytes: usize,
}
/// Per-tunnel policy inputs consulted for every request received inside the MITM'd connection.
#[derive(Clone)]
struct MitmPolicyContext {
/// Normalized host of the CONNECT target; inner requests must name this host.
target_host: String,
/// Port of the CONNECT target.
target_port: u16,
/// Network mode used for the per-request method check.
mode: NetworkMode,
/// Shared proxy state used to re-check host blocking and to record blocked requests.
app_state: Arc<NetworkProxyState>,
}
/// Everything the per-request handler needs: the tunnel's policy inputs plus the shared MITM state.
#[derive(Clone)]
struct MitmRequestContext {
/// Policy inputs fixed when the tunnel was established.
policy: MitmPolicyContext,
/// Shared MITM state (upstream client and inspection settings).
mitm: Arc<MitmState>,
}
/// Value copied into `MitmState::inspect` by `MitmState::new`; body inspection is off.
const MITM_INSPECT_BODIES: bool = false;
/// Value copied into `MitmState::max_body_bytes` by `MitmState::new`; bodies longer than this
/// are reported with `truncated=true` in the inspection log line.
const MITM_MAX_BODY_BYTES: usize = 4096;
impl std::fmt::Debug for MitmState {
    /// Prints only the inspection settings and marks the output as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately skip `ca` and `upstream`: CA material and connector internals must not
        // end up in logs.
        let Self {
            inspect,
            max_body_bytes,
            ..
        } = self;
        let mut out = f.debug_struct("MitmState");
        out.field("inspect", inspect);
        out.field("max_body_bytes", max_body_bytes);
        out.finish_non_exhaustive()
    }
}
impl MitmState {
    /// Creates the MITM state: loads (or creates) the managed CA and picks the upstream client.
    ///
    /// `allow_upstream_proxy` selects between an upstream client configured from the
    /// environment proxy settings and a direct one.
    ///
    /// # Errors
    /// Fails if the managed CA cannot be loaded or created.
    pub(crate) fn new(allow_upstream_proxy: bool) -> Result<Self> {
        // Why MITM at all: after CONNECT succeeds, a plain tunnel hides the inner HTTP request,
        // so limited-mode rules could not be enforced. A local CA that issues per-host leaf
        // certificates lets the proxy terminate TLS and apply policy to each request.
        let ca = ManagedMitmCa::load_or_create()?;
        Ok(Self {
            ca,
            upstream: if allow_upstream_proxy {
                UpstreamClient::from_env_proxy()
            } else {
                UpstreamClient::direct()
            },
            inspect: MITM_INSPECT_BODIES,
            max_body_bytes: MITM_MAX_BODY_BYTES,
        })
    }

    /// Asks the managed CA for the TLS acceptor data to present for `host`.
    fn tls_acceptor_data_for_host(&self, host: &str) -> Result<TlsAcceptorData> {
        let Self { ca, .. } = self;
        ca.tls_acceptor_data_for_host(host)
    }

    /// Whether request/response bodies should be wrapped for size logging.
    pub(crate) fn inspect_enabled(&self) -> bool {
        let Self { inspect, .. } = self;
        *inspect
    }

    /// Byte threshold above which an inspected body is logged as truncated.
    pub(crate) fn max_body_bytes(&self) -> usize {
        let Self { max_body_bytes, .. } = self;
        *max_body_bytes
    }
}
/// Terminate the upgraded CONNECT stream with a generated leaf cert and proxy inner HTTPS traffic.
pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
let mitm = upgraded
.extensions()
.get::<Arc<MitmState>>()
.cloned()
.context("missing MITM state")?;
let app_state = upgraded
.extensions()
.get::<Arc<NetworkProxyState>>()
.cloned()
.context("missing app state")?;
let target = upgraded
.extensions()
.get::<ProxyTarget>()
.context("missing proxy target")?
.0
.clone();
let target_host = normalize_host(&target.host.to_string());
let target_port = target.port;
let acceptor_data = mitm.tls_acceptor_data_for_host(&target_host)?;
let mode = upgraded
.extensions()
.get::<NetworkMode>()
.copied()
.unwrap_or(NetworkMode::Full);
let request_ctx = Arc::new(MitmRequestContext {
policy: MitmPolicyContext {
target_host,
target_port,
mode,
app_state,
},
mitm,
});
let executor = upgraded
.extensions()
.get::<Executor>()
.cloned()
.unwrap_or_default();
let http_service = HttpServer::auto(executor).service(
(
RemoveResponseHeaderLayer::hop_by_hop(),
RemoveRequestHeaderLayer::hop_by_hop(),
)
.into_layer(service_fn({
let request_ctx = request_ctx.clone();
move |req| {
let request_ctx = request_ctx.clone();
async move { handle_mitm_request(req, request_ctx).await }
}
})),
);
let https_service = TlsAcceptorLayer::new(acceptor_data)
.with_store_client_hello(true)
.into_layer(http_service);
https_service
.serve(upgraded)
.await
.map_err(|err| anyhow!("MITM serve error: {err}"))?;
Ok(())
}
/// Infallible wrapper around `forward_request` for use as the inner HTTP service.
///
/// Any forwarding error is logged as a warning and turned into a `502 Bad Gateway` text
/// response, so this function always returns `Ok`.
async fn handle_mitm_request(
    req: Request,
    request_ctx: Arc<MitmRequestContext>,
) -> Result<Response, std::convert::Infallible> {
    let response = forward_request(req, &request_ctx)
        .await
        .unwrap_or_else(|err| {
            warn!("MITM request handling failed: {err}");
            text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
        });
    Ok(response)
}
/// Applies the MITM policy checks to one inner request and, if allowed, forwards it upstream.
///
/// The request URI and `Host` header are rewritten to the CONNECT target's authority before
/// forwarding. When inspection is enabled, request and response bodies are wrapped so their
/// sizes are logged.
///
/// # Errors
/// Propagates failures from the policy check, URI/header construction and the upstream client.
async fn forward_request(req: Request, request_ctx: &MitmRequestContext) -> Result<Response> {
    let policy = &request_ctx.policy;
    if let Some(blocked) = mitm_blocking_response(&req, policy).await? {
        return Ok(blocked);
    }

    let mitm = &request_ctx.mitm;
    let inspect = mitm.inspect_enabled();
    let max_body_bytes = mitm.max_body_bytes();

    // Capture what is needed from the request before it is taken apart.
    let method = req.method().as_str().to_string();
    let path = path_and_query(req.uri());
    let log_path = path_for_log(req.uri());
    let authority = authority_header_value(&policy.target_host, policy.target_port);

    // Always address the upstream request to the CONNECT target, whatever the client sent.
    let (mut parts, original_body) = req.into_parts();
    parts.uri = build_https_uri(&authority, &path)?;
    parts
        .headers
        .insert(HOST, HeaderValue::from_str(&authority)?);

    let body = if inspect {
        let log_ctx = RequestLogContext {
            host: authority.clone(),
            method: method.clone(),
            path: log_path.clone(),
        };
        inspect_body(original_body, max_body_bytes, log_ctx)
    } else {
        original_body
    };

    let upstream_resp = mitm
        .upstream
        .serve(Request::from_parts(parts, body))
        .await?;
    respond_with_inspection(
        upstream_resp,
        inspect,
        max_body_bytes,
        &method,
        &log_path,
        &authority,
    )
}
/// Decides whether an inner request must be refused before it is forwarded.
///
/// Returns `Ok(Some(response))` with the refusal to send back, or `Ok(None)` when the request
/// may proceed. Checks, in order: nested `CONNECT`, a request host that differs from the tunnel
/// target, a target that is now blocked as local/private, and a method the mode does not allow.
/// The last two are also recorded via the app state.
///
/// # Errors
/// Propagates a failure of the host-blocked lookup.
async fn mitm_blocking_response(
    req: &Request,
    policy: &MitmPolicyContext,
) -> Result<Option<Response>> {
    let raw_method = req.method().as_str();
    if raw_method == "CONNECT" {
        return Ok(Some(text_response(
            StatusCode::METHOD_NOT_ALLOWED,
            "CONNECT not supported inside MITM",
        )));
    }
    let method = raw_method.to_string();
    let log_path = path_for_log(req.uri());
    let client = req
        .extensions()
        .get::<SocketInfo>()
        .map(|info| info.peer_addr().to_string());

    // A request naming a host other than the tunnel's target is rejected outright.
    let mismatched_host = extract_request_host(req)
        .map(|request_host| normalize_host(&request_host))
        .filter(|normalized| !normalized.is_empty() && *normalized != policy.target_host);
    if let Some(normalized) = mismatched_host {
        warn!(
            "MITM host mismatch (target={}, request_host={normalized})",
            policy.target_host
        );
        return Ok(Some(text_response(
            StatusCode::BAD_REQUEST,
            "host mismatch",
        )));
    }

    // Allowlist/denylist and decider policy were applied at CONNECT time. The local/private
    // resolution check is repeated here to guard against DNS rebinding between the CONNECT
    // and the inner HTTPS requests.
    let decision = policy
        .app_state
        .host_blocked(&policy.target_host, policy.target_port)
        .await?;
    if matches!(
        decision,
        HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
    ) {
        let reason = HostBlockReason::NotAllowedLocal.as_str();
        let blocked = BlockedRequest::new(BlockedRequestArgs {
            host: policy.target_host.clone(),
            reason: reason.to_string(),
            client,
            method: Some(method.clone()),
            mode: Some(policy.mode),
            protocol: "https".to_string(),
            decision: None,
            source: None,
            port: Some(policy.target_port),
        });
        // Recording is best-effort; its outcome does not change the response.
        let _ = policy.app_state.record_blocked(blocked).await;
        warn!(
            "MITM blocked local/private target after CONNECT (host={}, port={}, method={method}, path={log_path})",
            policy.target_host, policy.target_port
        );
        return Ok(Some(blocked_text_response(reason)));
    }

    if policy.mode.allows_method(&method) {
        return Ok(None);
    }

    let blocked = BlockedRequest::new(BlockedRequestArgs {
        host: policy.target_host.clone(),
        reason: REASON_METHOD_NOT_ALLOWED.to_string(),
        client,
        method: Some(method.clone()),
        mode: Some(policy.mode),
        protocol: "https".to_string(),
        decision: None,
        source: None,
        port: Some(policy.target_port),
    });
    // Recording is best-effort; its outcome does not change the response.
    let _ = policy.app_state.record_blocked(blocked).await;
    warn!(
        "MITM blocked by method policy (host={}, method={method}, path={log_path}, mode={:?}, allowed_methods=GET, HEAD, OPTIONS)",
        policy.target_host, policy.mode
    );
    Ok(Some(blocked_text_response(REASON_METHOD_NOT_ALLOWED)))
}
/// Returns the upstream response, wrapping its body for size logging when `inspect` is set.
///
/// With `inspect == false` the response is returned untouched. Otherwise the body is replaced
/// by an inspecting stream that logs `authority`, `method`, `log_path` and the response status
/// once the body has been fully read.
fn respond_with_inspection(
    resp: Response,
    inspect: bool,
    max_body_bytes: usize,
    method: &str,
    log_path: &str,
    authority: &str,
) -> Result<Response> {
    if inspect {
        let (parts, original_body) = resp.into_parts();
        let log_ctx = ResponseLogContext {
            host: authority.to_string(),
            method: method.to_string(),
            path: log_path.to_string(),
            status: parts.status,
        };
        let wrapped = inspect_body(original_body, max_body_bytes, log_ctx);
        Ok(Response::from_parts(parts, wrapped))
    } else {
        Ok(resp)
    }
}
/// Wraps `body` in an `InspectStream` that forwards every chunk and logs via `ctx` at the end.
///
/// `max_body_bytes` is only used to compute the `truncated` flag handed to `ctx.log`; the
/// data itself is passed through in full.
fn inspect_body<T: BodyLoggable + Send + 'static>(
    body: Body,
    max_body_bytes: usize,
    ctx: T,
) -> Body {
    let counting_stream = InspectStream {
        inner: Box::pin(body.into_data_stream()),
        ctx: Some(Box::new(ctx)),
        len: 0,
        max_body_bytes,
    };
    Body::from_stream(counting_stream)
}
/// Pass-through body stream that counts bytes and logs once when the inner stream ends.
struct InspectStream<T> {
/// Wrapped body data stream; its items are forwarded unchanged.
inner: Pin<Box<BodyDataStream>>,
/// Log context, taken and consumed the first time the inner stream reports its end.
ctx: Option<Box<T>>,
/// Running total of body bytes seen so far (saturating add).
len: usize,
/// Threshold compared against `len` to compute the `truncated` flag passed to `log`.
max_body_bytes: usize,
}
impl<T: BodyLoggable> Stream for InspectStream<T> {
    type Item = Result<Bytes, BoxError>;

    /// Polls the wrapped stream and returns its result unchanged.
    ///
    /// Data chunks add to the running byte count. When the inner stream ends, the stored
    /// context (if still present) is consumed to emit a single log line. Errors and `Pending`
    /// are passed through without logging.
    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();
        let polled = stream.inner.as_mut().poll_next(cx);
        match &polled {
            Poll::Ready(Some(Ok(chunk))) => {
                stream.len = stream.len.saturating_add(chunk.len());
            }
            Poll::Ready(None) => {
                // `take` ensures the log line is emitted at most once.
                if let Some(ctx) = stream.ctx.take() {
                    ctx.log(stream.len, stream.len > stream.max_body_bytes);
                }
            }
            Poll::Ready(Some(Err(_))) | Poll::Pending => {}
        }
        polled
    }
}
/// Fields included in the log line emitted after a request body has been fully read.
struct RequestLogContext {
/// Authority the request was forwarded to.
host: String,
/// HTTP method of the request.
method: String,
/// Request path as produced by `path_for_log`.
path: String,
}
/// Fields included in the log line emitted after a response body has been fully read.
struct ResponseLogContext {
/// Authority the request was forwarded to.
host: String,
/// HTTP method of the originating request.
method: String,
/// Request path as produced by `path_for_log`.
path: String,
/// Status code of the upstream response.
status: StatusCode,
}
/// A log context that can report the outcome of streaming one body.
trait BodyLoggable {
/// Consumes the context and logs the total body length in bytes; `truncated` is true when
/// that length exceeded the configured maximum.
fn log(self, len: usize, truncated: bool);
}
impl BodyLoggable for RequestLogContext {
    /// Emits one info-level line describing the inspected request body.
    fn log(self, len: usize, truncated: bool) {
        let Self { host, method, path } = self;
        info!(
            "MITM inspected request body (host={host}, method={method}, path={path}, body_len={len}, truncated={truncated})"
        );
    }
}
impl BodyLoggable for ResponseLogContext {
    /// Emits one info-level line describing the inspected response body.
    fn log(self, len: usize, truncated: bool) {
        let Self {
            host,
            method,
            path,
            status,
        } = self;
        info!(
            "MITM inspected response body (host={host}, method={method}, path={path}, status={status}, body_len={len}, truncated={truncated})"
        );
    }
}
/// Returns the host the request claims to address, if any.
///
/// A `Host` header that converts to a string wins; otherwise the URI authority is used.
/// Returns `None` when neither is available.
fn extract_request_host(req: &Request) -> Option<String> {
    let header_host = req
        .headers()
        .get(HOST)
        .and_then(|value| value.to_str().ok());
    if let Some(host) = header_host {
        return Some(host.to_string());
    }
    req.uri()
        .authority()
        .map(|authority| authority.as_str().to_string())
}
/// Formats `host` and `port` as a `Host` header value / URI authority.
///
/// - A host containing `:` is treated as an IPv6 literal and wrapped in `[...]`, unless it is
///   already bracketed.
/// - The port is omitted when it is 443 (the HTTPS default) and appended as `:port` otherwise.
fn authority_header_value(host: &str, port: u16) -> String {
    // An already-bracketed literal such as `[::1]` also contains `:`; wrapping it again would
    // yield the invalid `[[::1]]`.
    let already_bracketed = host.starts_with('[') && host.ends_with(']');
    let host_part = if host.contains(':') && !already_bracketed {
        format!("[{host}]")
    } else {
        host.to_string()
    };
    if port == 443 {
        host_part
    } else {
        format!("{host_part}:{port}")
    }
}
fn build_https_uri(authority: &str, path: &str) -> Result<Uri> {
let target = format!("https://{authority}{path}");
Ok(target.parse()?)
}
fn path_and_query(uri: &Uri) -> String {
uri.path_and_query()
.map(rama_http::uri::PathAndQuery::as_str)
.unwrap_or("/")
.to_string()
}
/// Returns the value logged as the request path: the result of `Uri::path`, without the
/// rest of the URI.
fn path_for_log(uri: &Uri) -> String {
    let path: &str = uri.path();
    path.to_owned()
}
// Unit tests are kept in the sibling file `mitm_tests.rs` and compiled only for test builds.
#[cfg(test)]
#[path = "mitm_tests.rs"]
mod tests;