1. Added SessionSource::Custom(String) and --session-source. 2. Enforced plugin and skill products by session_source. 3. Applied the same filtering to curated background refresh.
885 lines
35 KiB
Rust
885 lines
35 KiB
Rust
#![deny(clippy::print_stdout, clippy::print_stderr)]
|
|
|
|
use codex_arg0::Arg0DispatchPaths;
|
|
use codex_cloud_requirements::cloud_requirements_loader;
|
|
use codex_core::AuthManager;
|
|
use codex_core::config::Config;
|
|
use codex_core::config::ConfigBuilder;
|
|
use codex_core::config_loader::CloudRequirementsLoader;
|
|
use codex_core::config_loader::ConfigLayerStackOrdering;
|
|
use codex_core::config_loader::LoaderOverrides;
|
|
use codex_utils_cli::CliConfigOverrides;
|
|
use std::collections::HashMap;
|
|
use std::collections::HashSet;
|
|
use std::io::ErrorKind;
|
|
use std::io::Result as IoResult;
|
|
use std::sync::Arc;
|
|
use std::sync::RwLock;
|
|
use std::sync::atomic::AtomicBool;
|
|
|
|
use crate::message_processor::MessageProcessor;
|
|
use crate::message_processor::MessageProcessorArgs;
|
|
use crate::outgoing_message::ConnectionId;
|
|
use crate::outgoing_message::OutgoingEnvelope;
|
|
use crate::outgoing_message::OutgoingMessageSender;
|
|
use crate::transport::CHANNEL_CAPACITY;
|
|
use crate::transport::ConnectionState;
|
|
use crate::transport::OutboundConnectionState;
|
|
use crate::transport::TransportEvent;
|
|
use crate::transport::route_outgoing_envelope;
|
|
use crate::transport::start_stdio_connection;
|
|
use crate::transport::start_websocket_acceptor;
|
|
use codex_app_server_protocol::ConfigLayerSource;
|
|
use codex_app_server_protocol::ConfigWarningNotification;
|
|
use codex_app_server_protocol::JSONRPCMessage;
|
|
use codex_app_server_protocol::TextPosition as AppTextPosition;
|
|
use codex_app_server_protocol::TextRange as AppTextRange;
|
|
use codex_core::ExecPolicyError;
|
|
use codex_core::check_execpolicy_for_warnings;
|
|
use codex_core::config_loader::ConfigLoadError;
|
|
use codex_core::config_loader::TextRange as CoreTextRange;
|
|
use codex_feedback::CodexFeedback;
|
|
use codex_protocol::protocol::SessionSource;
|
|
use codex_state::log_db;
|
|
use tokio::sync::mpsc;
|
|
use tokio::task::JoinHandle;
|
|
use tokio_util::sync::CancellationToken;
|
|
use toml::Value as TomlValue;
|
|
use tracing::Level;
|
|
use tracing::error;
|
|
use tracing::info;
|
|
use tracing::warn;
|
|
use tracing_subscriber::EnvFilter;
|
|
use tracing_subscriber::Layer;
|
|
use tracing_subscriber::filter::Targets;
|
|
use tracing_subscriber::layer::SubscriberExt;
|
|
use tracing_subscriber::registry::Registry;
|
|
use tracing_subscriber::util::SubscriberInitExt;
|
|
|
|
mod app_server_tracing;
|
|
mod bespoke_event_handling;
|
|
mod codex_message_processor;
|
|
mod command_exec;
|
|
mod config_api;
|
|
mod dynamic_tools;
|
|
mod error_code;
|
|
mod external_agent_config_api;
|
|
mod filters;
|
|
mod fs_api;
|
|
mod fuzzy_file_search;
|
|
pub mod in_process;
|
|
mod message_processor;
|
|
mod models;
|
|
mod outgoing_message;
|
|
mod server_request_error;
|
|
mod thread_state;
|
|
mod thread_status;
|
|
mod transport;
|
|
|
|
pub use crate::error_code::INPUT_TOO_LARGE_ERROR_CODE;
|
|
pub use crate::error_code::INVALID_PARAMS_ERROR_CODE;
|
|
pub use crate::transport::AppServerTransport;
|
|
|
|
/// Name of the environment variable read by `log_format_from_env` to choose the
/// stderr log format (`json` selects JSON output; anything else keeps the default).
const LOG_FORMAT_ENV_VAR: &str = "LOG_FORMAT";
|
|
|
|
/// Output format of the stderr tracing layer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LogFormat {
    /// The standard `tracing_subscriber` formatter output.
    Default,
    /// JSON-formatted log output.
    Json,
}
|
|
|
|
/// Boxed tracing layer that writes to stderr. Boxing lets the JSON and default
/// formatter layers, which have different concrete types, share one binding.
type StderrLogLayer = Box<dyn Layer<Registry> + Send + Sync + 'static>;
|
|
|
|
/// Control-plane messages from the processor/transport side to the outbound router task.
///
/// `run_main_with_transport` now uses two loops/tasks:
/// - processor loop: handles incoming JSON-RPC and request dispatch
/// - outbound loop: performs potentially slow writes to per-connection writers
///
/// `OutboundControlEvent` keeps those loops coordinated without sharing mutable
/// connection state directly. In particular, the outbound loop needs to know
/// when a connection opens/closes so it can route messages correctly.
enum OutboundControlEvent {
    /// Register a new writer for an opened connection.
    Opened {
        /// Key under which the outbound loop stores this connection's state.
        connection_id: ConnectionId,
        /// Sender used to deliver outgoing messages to this connection.
        writer: mpsc::Sender<crate::outgoing_message::OutgoingMessage>,
        // Allow codex/event/* notifications to be emitted.
        allow_legacy_notifications: bool,
        /// Optional cancellation token forwarded to the outbound connection state.
        /// NOTE(review): presumably cancelled to force a disconnect — confirm in `transport`.
        disconnect_sender: Option<CancellationToken>,
        /// Shared flag the processor loop sets to `true` once the connection's
        /// session has become initialized.
        initialized: Arc<AtomicBool>,
        /// Shared flag the processor loop keeps in sync with the session's
        /// `experimental_api_enabled` value after each request.
        experimental_api_enabled: Arc<AtomicBool>,
        /// Shared set the processor loop overwrites with the session's
        /// opted-out notification methods after each request.
        opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
    },
    /// Remove state for a closed/disconnected connection.
    Closed { connection_id: ConnectionId },
    /// Disconnect all connection-oriented clients during graceful restart.
    DisconnectAll,
}
|
|
|
|
/// Tracks progress of a signal-triggered graceful restart in the processor loop.
#[derive(Default)]
struct ShutdownState {
    /// Set when the first shutdown signal has been received.
    requested: bool,
    /// Set when another shutdown signal arrives after shutdown was already requested.
    forced: bool,
    /// Running-turn count that was last written to the log while draining, so the
    /// "waiting for ..." message is only emitted when the count changes.
    last_logged_running_turn_count: Option<usize>,
}
|
|
|
|
/// Decision returned by `ShutdownState::update` on each processor-loop iteration.
enum ShutdownAction {
    /// Nothing to do; keep processing events.
    Noop,
    /// Shutdown should complete now: the processor loop stops the acceptor,
    /// disconnects all connections, and exits.
    Finish,
}
|
|
|
|
async fn shutdown_signal() -> IoResult<()> {
|
|
#[cfg(unix)]
|
|
{
|
|
use tokio::signal::unix::SignalKind;
|
|
use tokio::signal::unix::signal;
|
|
|
|
let mut term = signal(SignalKind::terminate())?;
|
|
tokio::select! {
|
|
ctrl_c_result = tokio::signal::ctrl_c() => ctrl_c_result,
|
|
_ = term.recv() => Ok(()),
|
|
}
|
|
}
|
|
|
|
#[cfg(not(unix))]
|
|
{
|
|
tokio::signal::ctrl_c().await
|
|
}
|
|
}
|
|
|
|
impl ShutdownState {
|
|
fn requested(&self) -> bool {
|
|
self.requested
|
|
}
|
|
|
|
fn forced(&self) -> bool {
|
|
self.forced
|
|
}
|
|
|
|
fn on_signal(&mut self, connection_count: usize, running_turn_count: usize) {
|
|
if self.requested {
|
|
self.forced = true;
|
|
return;
|
|
}
|
|
|
|
self.requested = true;
|
|
self.last_logged_running_turn_count = None;
|
|
info!(
|
|
"received shutdown signal; entering graceful restart drain (connections={}, runningAssistantTurns={}, requests still accepted until no assistant turns are running)",
|
|
connection_count, running_turn_count,
|
|
);
|
|
}
|
|
|
|
fn update(&mut self, running_turn_count: usize, connection_count: usize) -> ShutdownAction {
|
|
if !self.requested {
|
|
return ShutdownAction::Noop;
|
|
}
|
|
|
|
if self.forced || running_turn_count == 0 {
|
|
if self.forced {
|
|
info!(
|
|
"received second shutdown signal; forcing restart with {running_turn_count} running assistant turn(s) and {connection_count} connection(s)"
|
|
);
|
|
} else {
|
|
info!(
|
|
"shutdown signal restart: no assistant turns running; stopping acceptor and disconnecting {connection_count} connection(s)"
|
|
);
|
|
}
|
|
return ShutdownAction::Finish;
|
|
}
|
|
|
|
if self.last_logged_running_turn_count != Some(running_turn_count) {
|
|
info!(
|
|
"shutdown signal restart: waiting for {running_turn_count} running assistant turn(s) to finish"
|
|
);
|
|
self.last_logged_running_turn_count = Some(running_turn_count);
|
|
}
|
|
|
|
ShutdownAction::Noop
|
|
}
|
|
}
|
|
|
|
fn config_warning_from_error(
|
|
summary: impl Into<String>,
|
|
err: &std::io::Error,
|
|
) -> ConfigWarningNotification {
|
|
let (path, range) = match config_error_location(err) {
|
|
Some((path, range)) => (Some(path), Some(range)),
|
|
None => (None, None),
|
|
};
|
|
ConfigWarningNotification {
|
|
summary: summary.into(),
|
|
details: Some(err.to_string()),
|
|
path,
|
|
range,
|
|
}
|
|
}
|
|
|
|
fn config_error_location(err: &std::io::Error) -> Option<(String, AppTextRange)> {
|
|
err.get_ref()
|
|
.and_then(|err| err.downcast_ref::<ConfigLoadError>())
|
|
.map(|err| {
|
|
let config_error = err.config_error();
|
|
(
|
|
config_error.path.to_string_lossy().to_string(),
|
|
app_text_range(&config_error.range),
|
|
)
|
|
})
|
|
}
|
|
|
|
fn exec_policy_warning_location(err: &ExecPolicyError) -> (Option<String>, Option<AppTextRange>) {
|
|
match err {
|
|
ExecPolicyError::ParsePolicy { path, source } => {
|
|
if let Some(location) = source.location() {
|
|
let range = AppTextRange {
|
|
start: AppTextPosition {
|
|
line: location.range.start.line,
|
|
column: location.range.start.column,
|
|
},
|
|
end: AppTextPosition {
|
|
line: location.range.end.line,
|
|
column: location.range.end.column,
|
|
},
|
|
};
|
|
return (Some(location.path), Some(range));
|
|
}
|
|
(Some(path.clone()), None)
|
|
}
|
|
_ => (None, None),
|
|
}
|
|
}
|
|
|
|
fn app_text_range(range: &CoreTextRange) -> AppTextRange {
|
|
AppTextRange {
|
|
start: AppTextPosition {
|
|
line: range.start.line,
|
|
column: range.start.column,
|
|
},
|
|
end: AppTextPosition {
|
|
line: range.end.line,
|
|
column: range.end.column,
|
|
},
|
|
}
|
|
}
|
|
|
|
fn project_config_warning(config: &Config) -> Option<ConfigWarningNotification> {
|
|
let mut disabled_folders = Vec::new();
|
|
|
|
for layer in config.config_layer_stack.get_layers(
|
|
ConfigLayerStackOrdering::LowestPrecedenceFirst,
|
|
/*include_disabled*/ true,
|
|
) {
|
|
if !matches!(layer.name, ConfigLayerSource::Project { .. })
|
|
|| layer.disabled_reason.is_none()
|
|
{
|
|
continue;
|
|
}
|
|
if let ConfigLayerSource::Project { dot_codex_folder } = &layer.name {
|
|
disabled_folders.push((
|
|
dot_codex_folder.as_path().display().to_string(),
|
|
layer
|
|
.disabled_reason
|
|
.as_ref()
|
|
.map(ToString::to_string)
|
|
.unwrap_or_else(|| "config.toml is disabled.".to_string()),
|
|
));
|
|
}
|
|
}
|
|
|
|
if disabled_folders.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let mut message = concat!(
|
|
"Project config.toml files are disabled in the following folders. ",
|
|
"Settings in those files are ignored, but skills and exec policies still load.\n",
|
|
)
|
|
.to_string();
|
|
for (index, (folder, reason)) in disabled_folders.iter().enumerate() {
|
|
let display_index = index + 1;
|
|
message.push_str(&format!(" {display_index}. {folder}\n"));
|
|
message.push_str(&format!(" {reason}\n"));
|
|
}
|
|
|
|
Some(ConfigWarningNotification {
|
|
summary: message,
|
|
details: None,
|
|
path: None,
|
|
range: None,
|
|
})
|
|
}
|
|
|
|
impl LogFormat {
|
|
fn from_env_value(value: Option<&str>) -> Self {
|
|
match value.map(str::trim).map(str::to_ascii_lowercase) {
|
|
Some(value) if value == "json" => Self::Json,
|
|
_ => Self::Default,
|
|
}
|
|
}
|
|
}
|
|
|
|
fn log_format_from_env() -> LogFormat {
|
|
let value = std::env::var(LOG_FORMAT_ENV_VAR).ok();
|
|
LogFormat::from_env_value(value.as_deref())
|
|
}
|
|
|
|
pub async fn run_main(
|
|
arg0_paths: Arg0DispatchPaths,
|
|
cli_config_overrides: CliConfigOverrides,
|
|
loader_overrides: LoaderOverrides,
|
|
default_analytics_enabled: bool,
|
|
) -> IoResult<()> {
|
|
run_main_with_transport(
|
|
arg0_paths,
|
|
cli_config_overrides,
|
|
loader_overrides,
|
|
default_analytics_enabled,
|
|
AppServerTransport::Stdio,
|
|
SessionSource::VSCode,
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub async fn run_main_with_transport(
|
|
arg0_paths: Arg0DispatchPaths,
|
|
cli_config_overrides: CliConfigOverrides,
|
|
loader_overrides: LoaderOverrides,
|
|
default_analytics_enabled: bool,
|
|
transport: AppServerTransport,
|
|
session_source: SessionSource,
|
|
) -> IoResult<()> {
|
|
let (transport_event_tx, mut transport_event_rx) =
|
|
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
|
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingEnvelope>(CHANNEL_CAPACITY);
|
|
let (outbound_control_tx, mut outbound_control_rx) =
|
|
mpsc::channel::<OutboundControlEvent>(CHANNEL_CAPACITY);
|
|
|
|
enum TransportRuntime {
|
|
Stdio,
|
|
WebSocket {
|
|
accept_handle: JoinHandle<()>,
|
|
shutdown_token: CancellationToken,
|
|
},
|
|
}
|
|
|
|
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
|
|
let transport_runtime = match transport {
|
|
AppServerTransport::Stdio => {
|
|
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
|
|
TransportRuntime::Stdio
|
|
}
|
|
AppServerTransport::WebSocket { bind_address } => {
|
|
let shutdown_token = CancellationToken::new();
|
|
let accept_handle = start_websocket_acceptor(
|
|
bind_address,
|
|
transport_event_tx.clone(),
|
|
shutdown_token.clone(),
|
|
)
|
|
.await?;
|
|
TransportRuntime::WebSocket {
|
|
accept_handle,
|
|
shutdown_token,
|
|
}
|
|
}
|
|
};
|
|
let single_client_mode = matches!(&transport_runtime, TransportRuntime::Stdio);
|
|
let shutdown_when_no_connections = single_client_mode;
|
|
let graceful_signal_restart_enabled = !single_client_mode;
|
|
// Parse CLI overrides once and derive the base Config eagerly so later
|
|
// components do not need to work with raw TOML values.
|
|
let cli_kv_overrides = cli_config_overrides.parse_overrides().map_err(|e| {
|
|
std::io::Error::new(
|
|
ErrorKind::InvalidInput,
|
|
format!("error parsing -c overrides: {e}"),
|
|
)
|
|
})?;
|
|
let cloud_requirements = match ConfigBuilder::default()
|
|
.cli_overrides(cli_kv_overrides.clone())
|
|
.loader_overrides(loader_overrides.clone())
|
|
.build()
|
|
.await
|
|
{
|
|
Ok(config) => {
|
|
let effective_toml = config.config_layer_stack.effective_config();
|
|
match effective_toml.try_into() {
|
|
Ok(config_toml) => {
|
|
if let Err(err) = codex_core::personality_migration::maybe_migrate_personality(
|
|
&config.codex_home,
|
|
&config_toml,
|
|
)
|
|
.await
|
|
{
|
|
warn!(error = %err, "Failed to run personality migration");
|
|
}
|
|
}
|
|
Err(err) => {
|
|
warn!(error = %err, "Failed to deserialize config for personality migration");
|
|
}
|
|
}
|
|
|
|
let auth_manager = AuthManager::shared(
|
|
config.codex_home.clone(),
|
|
/*enable_codex_api_key_env*/ false,
|
|
config.cli_auth_credentials_store_mode,
|
|
);
|
|
cloud_requirements_loader(
|
|
auth_manager,
|
|
config.chatgpt_base_url,
|
|
config.codex_home.clone(),
|
|
)
|
|
}
|
|
Err(err) => {
|
|
warn!(error = %err, "Failed to preload config for cloud requirements");
|
|
// TODO(gt): Make cloud requirements preload failures blocking once we can fail-closed.
|
|
CloudRequirementsLoader::default()
|
|
}
|
|
};
|
|
let loader_overrides_for_config_api = loader_overrides.clone();
|
|
let mut config_warnings = Vec::new();
|
|
let config = match ConfigBuilder::default()
|
|
.cli_overrides(cli_kv_overrides.clone())
|
|
.loader_overrides(loader_overrides)
|
|
.cloud_requirements(cloud_requirements.clone())
|
|
.build()
|
|
.await
|
|
{
|
|
Ok(config) => config,
|
|
Err(err) => {
|
|
let message = config_warning_from_error("Invalid configuration; using defaults.", &err);
|
|
config_warnings.push(message);
|
|
Config::load_default_with_cli_overrides(cli_kv_overrides.clone()).map_err(|e| {
|
|
std::io::Error::new(
|
|
ErrorKind::InvalidData,
|
|
format!("error loading default config after config error: {e}"),
|
|
)
|
|
})?
|
|
}
|
|
};
|
|
|
|
if let Ok(Some(err)) = check_execpolicy_for_warnings(&config.config_layer_stack).await {
|
|
let (path, range) = exec_policy_warning_location(&err);
|
|
let message = ConfigWarningNotification {
|
|
summary: "Error parsing rules; custom rules not applied.".to_string(),
|
|
details: Some(err.to_string()),
|
|
path,
|
|
range,
|
|
};
|
|
config_warnings.push(message);
|
|
}
|
|
|
|
if let Some(warning) = project_config_warning(&config) {
|
|
config_warnings.push(warning);
|
|
}
|
|
for warning in &config.startup_warnings {
|
|
config_warnings.push(ConfigWarningNotification {
|
|
summary: warning.clone(),
|
|
details: None,
|
|
path: None,
|
|
range: None,
|
|
});
|
|
}
|
|
if let Some(warning) = codex_core::config::missing_system_bwrap_warning() {
|
|
config_warnings.push(ConfigWarningNotification {
|
|
summary: warning,
|
|
details: None,
|
|
path: None,
|
|
range: None,
|
|
});
|
|
}
|
|
|
|
let feedback = CodexFeedback::new();
|
|
|
|
let otel = codex_core::otel_init::build_provider(
|
|
&config,
|
|
env!("CARGO_PKG_VERSION"),
|
|
Some("codex-app-server"),
|
|
default_analytics_enabled,
|
|
)
|
|
.map_err(|e| {
|
|
std::io::Error::new(
|
|
ErrorKind::InvalidData,
|
|
format!("error loading otel config: {e}"),
|
|
)
|
|
})?;
|
|
|
|
// Install a simple subscriber so `tracing` output is visible. Users can
|
|
// control the log level with `RUST_LOG` and switch to JSON logs with
|
|
// `LOG_FORMAT=json`.
|
|
let stderr_fmt: StderrLogLayer = match log_format_from_env() {
|
|
LogFormat::Json => tracing_subscriber::fmt::layer()
|
|
.json()
|
|
.with_writer(std::io::stderr)
|
|
.with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL)
|
|
.with_filter(EnvFilter::from_default_env())
|
|
.boxed(),
|
|
LogFormat::Default => tracing_subscriber::fmt::layer()
|
|
.with_writer(std::io::stderr)
|
|
.with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL)
|
|
.with_filter(EnvFilter::from_default_env())
|
|
.boxed(),
|
|
};
|
|
|
|
let feedback_layer = feedback.logger_layer();
|
|
let feedback_metadata_layer = feedback.metadata_layer();
|
|
let log_db = codex_state::StateRuntime::init(
|
|
config.sqlite_home.clone(),
|
|
config.model_provider_id.clone(),
|
|
)
|
|
.await
|
|
.ok()
|
|
.map(log_db::start);
|
|
let log_db_layer = log_db
|
|
.clone()
|
|
.map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE)));
|
|
let otel_logger_layer = otel.as_ref().and_then(|o| o.logger_layer());
|
|
let otel_tracing_layer = otel.as_ref().and_then(|o| o.tracing_layer());
|
|
let _ = tracing_subscriber::registry()
|
|
.with(stderr_fmt)
|
|
.with(feedback_layer)
|
|
.with(feedback_metadata_layer)
|
|
.with(log_db_layer)
|
|
.with(otel_logger_layer)
|
|
.with(otel_tracing_layer)
|
|
.try_init();
|
|
for warning in &config_warnings {
|
|
match &warning.details {
|
|
Some(details) => error!("{} {}", warning.summary, details),
|
|
None => error!("{}", warning.summary),
|
|
}
|
|
}
|
|
|
|
let outbound_handle = tokio::spawn(async move {
|
|
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
|
|
loop {
|
|
tokio::select! {
|
|
biased;
|
|
event = outbound_control_rx.recv() => {
|
|
let Some(event) = event else {
|
|
break;
|
|
};
|
|
match event {
|
|
OutboundControlEvent::Opened {
|
|
connection_id,
|
|
writer,
|
|
allow_legacy_notifications,
|
|
disconnect_sender,
|
|
initialized,
|
|
experimental_api_enabled,
|
|
opted_out_notification_methods,
|
|
} => {
|
|
outbound_connections.insert(
|
|
connection_id,
|
|
OutboundConnectionState::new(
|
|
writer,
|
|
initialized,
|
|
experimental_api_enabled,
|
|
opted_out_notification_methods,
|
|
allow_legacy_notifications,
|
|
disconnect_sender,
|
|
),
|
|
);
|
|
}
|
|
OutboundControlEvent::Closed { connection_id } => {
|
|
outbound_connections.remove(&connection_id);
|
|
}
|
|
OutboundControlEvent::DisconnectAll => {
|
|
info!(
|
|
"disconnecting {} outbound websocket connection(s) for graceful restart",
|
|
outbound_connections.len()
|
|
);
|
|
for connection_state in outbound_connections.values() {
|
|
connection_state.request_disconnect();
|
|
}
|
|
outbound_connections.clear();
|
|
}
|
|
}
|
|
}
|
|
envelope = outgoing_rx.recv() => {
|
|
let Some(envelope) = envelope else {
|
|
break;
|
|
};
|
|
route_outgoing_envelope(&mut outbound_connections, envelope).await;
|
|
}
|
|
}
|
|
}
|
|
info!("outbound router task exited (channel closed)");
|
|
});
|
|
|
|
let processor_handle = tokio::spawn({
|
|
let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
|
let outbound_control_tx = outbound_control_tx;
|
|
let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone();
|
|
let loader_overrides = loader_overrides_for_config_api;
|
|
let mut processor = MessageProcessor::new(MessageProcessorArgs {
|
|
outgoing: outgoing_message_sender,
|
|
arg0_paths,
|
|
config: Arc::new(config),
|
|
cli_overrides,
|
|
loader_overrides,
|
|
cloud_requirements: cloud_requirements.clone(),
|
|
auth_manager: None,
|
|
thread_manager: None,
|
|
feedback: feedback.clone(),
|
|
log_db,
|
|
config_warnings,
|
|
session_source,
|
|
enable_codex_api_key_env: false,
|
|
});
|
|
let mut thread_created_rx = processor.thread_created_receiver();
|
|
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
|
|
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
|
|
let websocket_accept_shutdown = match &transport_runtime {
|
|
TransportRuntime::WebSocket { shutdown_token, .. } => Some(shutdown_token.clone()),
|
|
TransportRuntime::Stdio => None,
|
|
};
|
|
async move {
|
|
let mut listen_for_threads = true;
|
|
let mut shutdown_state = ShutdownState::default();
|
|
loop {
|
|
let running_turn_count = {
|
|
let running_turn_count = running_turn_count_rx.borrow();
|
|
*running_turn_count
|
|
};
|
|
if matches!(
|
|
shutdown_state.update(running_turn_count, connections.len()),
|
|
ShutdownAction::Finish
|
|
) {
|
|
if let Some(shutdown_token) = &websocket_accept_shutdown {
|
|
shutdown_token.cancel();
|
|
}
|
|
let _ = outbound_control_tx
|
|
.send(OutboundControlEvent::DisconnectAll)
|
|
.await;
|
|
break;
|
|
}
|
|
|
|
tokio::select! {
|
|
shutdown_signal_result = shutdown_signal(), if graceful_signal_restart_enabled && !shutdown_state.forced() => {
|
|
if let Err(err) = shutdown_signal_result {
|
|
warn!("failed to listen for shutdown signal during graceful restart drain: {err}");
|
|
}
|
|
let running_turn_count = *running_turn_count_rx.borrow();
|
|
shutdown_state.on_signal(connections.len(), running_turn_count);
|
|
}
|
|
changed = running_turn_count_rx.changed(), if graceful_signal_restart_enabled && shutdown_state.requested() => {
|
|
if changed.is_err() {
|
|
warn!("running-turn watcher closed during graceful restart drain");
|
|
}
|
|
}
|
|
event = transport_event_rx.recv() => {
|
|
let Some(event) = event else {
|
|
break;
|
|
};
|
|
match event {
|
|
TransportEvent::ConnectionOpened {
|
|
connection_id,
|
|
writer,
|
|
allow_legacy_notifications,
|
|
disconnect_sender,
|
|
} => {
|
|
let outbound_initialized = Arc::new(AtomicBool::new(false));
|
|
let outbound_experimental_api_enabled =
|
|
Arc::new(AtomicBool::new(false));
|
|
let outbound_opted_out_notification_methods =
|
|
Arc::new(RwLock::new(HashSet::new()));
|
|
if outbound_control_tx
|
|
.send(OutboundControlEvent::Opened {
|
|
connection_id,
|
|
writer,
|
|
allow_legacy_notifications,
|
|
disconnect_sender,
|
|
initialized: Arc::clone(&outbound_initialized),
|
|
experimental_api_enabled: Arc::clone(
|
|
&outbound_experimental_api_enabled,
|
|
),
|
|
opted_out_notification_methods: Arc::clone(
|
|
&outbound_opted_out_notification_methods,
|
|
),
|
|
})
|
|
.await
|
|
.is_err()
|
|
{
|
|
break;
|
|
}
|
|
connections.insert(
|
|
connection_id,
|
|
ConnectionState::new(
|
|
outbound_initialized,
|
|
outbound_experimental_api_enabled,
|
|
outbound_opted_out_notification_methods,
|
|
),
|
|
);
|
|
}
|
|
TransportEvent::ConnectionClosed { connection_id } => {
|
|
if connections.remove(&connection_id).is_none() {
|
|
continue;
|
|
}
|
|
if outbound_control_tx
|
|
.send(OutboundControlEvent::Closed { connection_id })
|
|
.await
|
|
.is_err()
|
|
{
|
|
break;
|
|
}
|
|
processor.connection_closed(connection_id).await;
|
|
if shutdown_when_no_connections && connections.is_empty() {
|
|
break;
|
|
}
|
|
}
|
|
TransportEvent::IncomingMessage { connection_id, message } => {
|
|
match message {
|
|
JSONRPCMessage::Request(request) => {
|
|
let Some(connection_state) = connections.get_mut(&connection_id) else {
|
|
warn!("dropping request from unknown connection: {connection_id:?}");
|
|
continue;
|
|
};
|
|
let was_initialized = connection_state.session.initialized;
|
|
processor
|
|
.process_request(
|
|
connection_id,
|
|
request,
|
|
transport,
|
|
&mut connection_state.session,
|
|
)
|
|
.await;
|
|
if let Ok(mut opted_out_notification_methods) = connection_state
|
|
.outbound_opted_out_notification_methods
|
|
.write()
|
|
{
|
|
*opted_out_notification_methods = connection_state
|
|
.session
|
|
.opted_out_notification_methods
|
|
.clone();
|
|
} else {
|
|
warn!(
|
|
"failed to update outbound opted-out notifications"
|
|
);
|
|
}
|
|
connection_state
|
|
.outbound_experimental_api_enabled
|
|
.store(
|
|
connection_state.session.experimental_api_enabled,
|
|
std::sync::atomic::Ordering::Release,
|
|
);
|
|
if !was_initialized && connection_state.session.initialized {
|
|
processor
|
|
.send_initialize_notifications_to_connection(
|
|
connection_id,
|
|
)
|
|
.await;
|
|
processor.connection_initialized(connection_id).await;
|
|
connection_state
|
|
.outbound_initialized
|
|
.store(true, std::sync::atomic::Ordering::Release);
|
|
}
|
|
}
|
|
JSONRPCMessage::Response(response) => {
|
|
if !connections.contains_key(&connection_id) {
|
|
warn!("dropping response from unknown connection: {connection_id:?}");
|
|
continue;
|
|
}
|
|
processor.process_response(response).await;
|
|
}
|
|
JSONRPCMessage::Notification(notification) => {
|
|
if !connections.contains_key(&connection_id) {
|
|
warn!("dropping notification from unknown connection: {connection_id:?}");
|
|
continue;
|
|
}
|
|
processor.process_notification(notification).await;
|
|
}
|
|
JSONRPCMessage::Error(err) => {
|
|
if !connections.contains_key(&connection_id) {
|
|
warn!("dropping error from unknown connection: {connection_id:?}");
|
|
continue;
|
|
}
|
|
processor.process_error(err).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
created = thread_created_rx.recv(), if listen_for_threads => {
|
|
match created {
|
|
Ok(thread_id) => {
|
|
let initialized_connection_ids: Vec<ConnectionId> = connections
|
|
.iter()
|
|
.filter_map(|(connection_id, connection_state)| {
|
|
connection_state.session.initialized.then_some(*connection_id)
|
|
})
|
|
.collect();
|
|
processor
|
|
.try_attach_thread_listener(
|
|
thread_id,
|
|
initialized_connection_ids,
|
|
)
|
|
.await;
|
|
}
|
|
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
|
|
// TODO(jif) handle lag.
|
|
// Assumes thread creation volume is low enough that lag never happens.
|
|
// If it does, we log and continue without resyncing to avoid attaching
|
|
// listeners for threads that should remain unsubscribed.
|
|
warn!("thread_created receiver lagged; skipping resync");
|
|
}
|
|
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
|
|
listen_for_threads = false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if !shutdown_state.forced() {
|
|
processor.drain_background_tasks().await;
|
|
processor.shutdown_threads().await;
|
|
}
|
|
info!("processor task exited (channel closed)");
|
|
}
|
|
});
|
|
|
|
drop(transport_event_tx);
|
|
|
|
let _ = processor_handle.await;
|
|
let _ = outbound_handle.await;
|
|
|
|
if let TransportRuntime::WebSocket {
|
|
accept_handle,
|
|
shutdown_token,
|
|
} = transport_runtime
|
|
{
|
|
shutdown_token.cancel();
|
|
let _ = accept_handle.await;
|
|
}
|
|
|
|
for handle in stdio_handles {
|
|
let _ = handle.await;
|
|
}
|
|
|
|
if let Some(otel) = otel {
|
|
otel.shutdown();
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::LogFormat;
    use pretty_assertions::assert_eq;

    /// `json` is recognized regardless of ASCII case and surrounding whitespace.
    #[test]
    fn log_format_from_env_value_matches_json_values_case_insensitively() {
        for raw in ["json", "JSON", " Json "] {
            assert_eq!(LogFormat::from_env_value(Some(raw)), LogFormat::Json);
        }
    }

    /// A missing value and any value other than `json` select the default format.
    #[test]
    fn log_format_from_env_value_defaults_for_non_json_values() {
        assert_eq!(LogFormat::from_env_value(None), LogFormat::Default);
        for raw in ["", "text", "jsonl"] {
            assert_eq!(LogFormat::from_env_value(Some(raw)), LogFormat::Default);
        }
    }
}
|