Reapply "Add app-server transport layer with websocket support" (#11370)
Reapply "Add app-server transport layer with websocket support" with
additional fixes from https://github.com/openai/codex/pull/11313/changes
to avoid deadlocking.
This reverts commit 47356ff83c.
## Summary
To avoid deadlocking when queues are full, we maintain separate tokio
tasks dedicated to incoming vs outgoing event handling
- split the app-server main loop into two tasks in
`run_main_with_transport`
- inbound handling (`transport_event_rx`)
- outbound handling (`outgoing_rx` + `thread_created_rx`)
- separate incoming and outgoing websocket tasks
## Validation
Integration tests, testing thoroughly e2e in codex app w/ >10 concurrent
requests
<img width="1365" height="979" alt="Screenshot 2026-02-10 at 2 54 22 PM"
src="https://github.com/user-attachments/assets/47ca2c13-f322-4e5c-bedd-25859cbdc45f"
/>
---------
Co-authored-by: jif-oai <jif@openai.com>
This commit is contained in:
parent
577a416f9a
commit
7053aa5457
19 changed files with 1940 additions and 388 deletions
5
codex-rs/Cargo.lock
generated
5
codex-rs/Cargo.lock
generated
|
|
@ -1348,6 +1348,7 @@ dependencies = [
|
|||
"axum",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-app-server-protocol",
|
||||
"codex-arg0",
|
||||
"codex-backend-client",
|
||||
|
|
@ -1361,10 +1362,13 @@ dependencies = [
|
|||
"codex-protocol",
|
||||
"codex-rmcp-client",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-cargo-bin",
|
||||
"codex-utils-cli",
|
||||
"codex-utils-json-to-toml",
|
||||
"core_test_support",
|
||||
"futures",
|
||||
"os_info",
|
||||
"owo-colors",
|
||||
"pretty_assertions",
|
||||
"rmcp",
|
||||
"serde",
|
||||
|
|
@ -1374,6 +1378,7 @@ dependencies = [
|
|||
"tempfile",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"toml 0.9.12+spec-1.1.0",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
|
|
|||
|
|
@ -30,8 +30,12 @@ codex-protocol = { workspace = true }
|
|||
codex-app-server-protocol = { workspace = true }
|
||||
codex-feedback = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
futures = { workspace = true }
|
||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
|
@ -44,6 +48,7 @@ tokio = { workspace = true, features = [
|
|||
"rt-multi-thread",
|
||||
"signal",
|
||||
] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] }
|
||||
uuid = { workspace = true, features = ["serde", "v7"] }
|
||||
|
|
@ -57,8 +62,8 @@ axum = { workspace = true, default-features = false, features = [
|
|||
] }
|
||||
base64 = { workspace = true }
|
||||
codex-execpolicy = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
core_test_support = { workspace = true }
|
||||
codex-utils-cargo-bin = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
rmcp = { workspace = true, default-features = false, features = [
|
||||
|
|
@ -66,5 +71,6 @@ rmcp = { workspace = true, default-features = false, features = [
|
|||
"transport-streamable-http-server",
|
||||
] }
|
||||
serial_test = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
wiremock = { workspace = true }
|
||||
shlex = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -19,7 +19,20 @@
|
|||
|
||||
## Protocol
|
||||
|
||||
Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication, streaming JSONL over stdio. The protocol is JSON-RPC 2.0, though the `"jsonrpc":"2.0"` header is omitted.
|
||||
Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication using JSON-RPC 2.0 messages (with the `"jsonrpc":"2.0"` header omitted on the wire).
|
||||
|
||||
Supported transports:
|
||||
|
||||
- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL)
|
||||
- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**)
|
||||
|
||||
Websocket transport is currently experimental and unsupported. Do not rely on it for production workloads.
|
||||
|
||||
Backpressure behavior:
|
||||
|
||||
- The server uses bounded queues between transport ingress, request processing, and outbound writes.
|
||||
- When request ingress is saturated, new requests are rejected with a JSON-RPC error code `-32001` and message `"Server overloaded; retry later."`.
|
||||
- Clients should treat this as retryable and use exponential backoff with jitter.
|
||||
|
||||
## Message Schema
|
||||
|
||||
|
|
@ -42,7 +55,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat
|
|||
|
||||
## Lifecycle Overview
|
||||
|
||||
- Initialize once: Immediately after launching the codex app-server process, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request before this handshake gets rejected.
|
||||
- Initialize once per connection: Immediately after opening a transport connection, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request on that connection before this handshake gets rejected.
|
||||
- Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you’ll also get a `thread/started` notification. If you’re continuing an existing conversation, call `thread/resume` with its ID instead. If you want to branch from an existing conversation, call `thread/fork` to create a new thread id with copied history.
|
||||
- Begin a turn: To send user input, call `turn/start` with the target `threadId` and the user's input. Optional fields let you override model, cwd, sandbox policy, etc. This immediately returns the new turn object and triggers a `turn/started` notification.
|
||||
- Stream events: After `turn/start`, keep reading JSON-RPC notifications on stdout. You’ll see `item/started`, `item/completed`, deltas like `item/agentMessage/delta`, tool progress, etc. These represent streaming model output plus any side effects (commands, tool calls, reasoning notes).
|
||||
|
|
@ -50,7 +63,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat
|
|||
|
||||
## Initialization
|
||||
|
||||
Clients must send a single `initialize` request before invoking any other method, then acknowledge with an `initialized` notification. The server returns the user agent string it will present to upstream services; subsequent requests issued before initialization receive a `"Not initialized"` error, and repeated `initialize` calls receive an `"Already initialized"` error.
|
||||
Clients must send a single `initialize` request per transport connection before invoking any other method on that connection, then acknowledge with an `initialized` notification. The server returns the user agent string it will present to upstream services; subsequent requests issued before initialization receive a `"Not initialized"` error, and repeated `initialize` calls on the same connection receive an `"Already initialized"` error.
|
||||
|
||||
`initialize.params.capabilities` also supports per-connection notification opt-out via `optOutNotificationMethods`, which is a list of exact method names to suppress for that connection. Matching is exact (no wildcards/prefixes). Unknown method names are accepted and ignored.
|
||||
|
||||
|
|
|
|||
|
|
@ -1115,7 +1115,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
),
|
||||
data: None,
|
||||
};
|
||||
outgoing.send_error(request_id, error).await;
|
||||
outgoing.send_error(request_id.clone(), error).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -1129,7 +1129,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
),
|
||||
data: None,
|
||||
};
|
||||
outgoing.send_error(request_id, error).await;
|
||||
outgoing.send_error(request_id.clone(), error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
@ -1894,6 +1894,7 @@ async fn construct_mcp_tool_call_end_notification(
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::CHANNEL_CAPACITY;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use anyhow::Result;
|
||||
|
|
@ -1923,6 +1924,21 @@ mod tests {
|
|||
Arc::new(Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
async fn recv_broadcast_message(
|
||||
rx: &mut mpsc::Receiver<OutgoingEnvelope>,
|
||||
) -> Result<OutgoingMessage> {
|
||||
let envelope = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one message"))?;
|
||||
match envelope {
|
||||
OutgoingEnvelope::Broadcast { message } => Ok(message),
|
||||
OutgoingEnvelope::ToConnection { connection_id, .. } => {
|
||||
bail!("unexpected targeted message for connection {connection_id:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_change_accept_for_session_maps_to_approved_for_session() {
|
||||
let (decision, completion_status) =
|
||||
|
|
@ -2024,10 +2040,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, event_turn_id);
|
||||
|
|
@ -2066,10 +2079,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, event_turn_id);
|
||||
|
|
@ -2108,10 +2118,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, event_turn_id);
|
||||
|
|
@ -2160,10 +2167,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnPlanUpdated(n)) => {
|
||||
assert_eq!(n.thread_id, conversation_id.to_string());
|
||||
|
|
@ -2233,10 +2237,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
|
||||
let first = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("expected usage notification"))?;
|
||||
let first = recv_broadcast_message(&mut rx).await?;
|
||||
match first {
|
||||
OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ThreadTokenUsageUpdated(payload),
|
||||
|
|
@ -2252,10 +2253,7 @@ mod tests {
|
|||
other => bail!("unexpected notification: {other:?}"),
|
||||
}
|
||||
|
||||
let second = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("expected rate limit notification"))?;
|
||||
let second = recv_broadcast_message(&mut rx).await?;
|
||||
match second {
|
||||
OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::AccountRateLimitsUpdated(payload),
|
||||
|
|
@ -2394,10 +2392,7 @@ mod tests {
|
|||
.await;
|
||||
|
||||
// Verify: A turn 1
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send first notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, a_turn1);
|
||||
|
|
@ -2415,10 +2410,7 @@ mod tests {
|
|||
}
|
||||
|
||||
// Verify: B turn 1
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send second notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, b_turn1);
|
||||
|
|
@ -2436,10 +2428,7 @@ mod tests {
|
|||
}
|
||||
|
||||
// Verify: A turn 2
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send third notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => {
|
||||
assert_eq!(n.turn.id, a_turn2);
|
||||
|
|
@ -2605,10 +2594,7 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
|
||||
let msg = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("should send one notification"))?;
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
match msg {
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::TurnDiffUpdated(
|
||||
notification,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,2 +1,3 @@
|
|||
pub(crate) const INVALID_REQUEST_ERROR_CODE: i64 = -32600;
|
||||
pub(crate) const INTERNAL_ERROR_CODE: i64 = -32603;
|
||||
pub(crate) const OVERLOADED_ERROR_CODE: i64 = -32001;
|
||||
|
|
|
|||
|
|
@ -8,14 +8,29 @@ use codex_core::config_loader::CloudRequirementsLoader;
|
|||
use codex_core::config_loader::ConfigLayerStackOrdering;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::message_processor::MessageProcessorArgs;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionState;
|
||||
use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::has_initialized_connections;
|
||||
use crate::transport::route_outgoing_envelope;
|
||||
use crate::transport::start_stdio_connection;
|
||||
use crate::transport::start_websocket_acceptor;
|
||||
use codex_app_server_protocol::ConfigLayerSource;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
|
|
@ -26,13 +41,9 @@ use codex_core::check_execpolicy_for_warnings;
|
|||
use codex_core::config_loader::ConfigLoadError;
|
||||
use codex_core::config_loader::TextRange as CoreTextRange;
|
||||
use codex_feedback::CodexFeedback;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::{self};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use toml::Value as TomlValue;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
|
@ -51,11 +62,30 @@ mod fuzzy_file_search;
|
|||
mod message_processor;
|
||||
mod models;
|
||||
mod outgoing_message;
|
||||
mod transport;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage – 128 messages should be
|
||||
/// plenty for an interactive CLI.
|
||||
const CHANNEL_CAPACITY: usize = 128;
|
||||
pub use crate::transport::AppServerTransport;
|
||||
|
||||
/// Control-plane messages from the processor/transport side to the outbound router task.
|
||||
///
|
||||
/// `run_main_with_transport` now uses two loops/tasks:
|
||||
/// - processor loop: handles incoming JSON-RPC and request dispatch
|
||||
/// - outbound loop: performs potentially slow writes to per-connection writers
|
||||
///
|
||||
/// `OutboundControlEvent` keeps those loops coordinated without sharing mutable
|
||||
/// connection state directly. In particular, the outbound loop needs to know
|
||||
/// when a connection opens/closes so it can route messages correctly.
|
||||
enum OutboundControlEvent {
|
||||
/// Register a new writer for an opened connection.
|
||||
Opened {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<crate::outgoing_message::OutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
},
|
||||
/// Remove state for a closed/disconnected connection.
|
||||
Closed { connection_id: ConnectionId },
|
||||
}
|
||||
|
||||
fn config_warning_from_error(
|
||||
summary: impl Into<String>,
|
||||
|
|
@ -173,32 +203,41 @@ pub async fn run_main(
|
|||
loader_overrides: LoaderOverrides,
|
||||
default_analytics_enabled: bool,
|
||||
) -> IoResult<()> {
|
||||
// Set up channels.
|
||||
let (incoming_tx, mut incoming_rx) = mpsc::channel::<JSONRPCMessage>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
run_main_with_transport(
|
||||
codex_linux_sandbox_exe,
|
||||
cli_config_overrides,
|
||||
loader_overrides,
|
||||
default_analytics_enabled,
|
||||
AppServerTransport::Stdio,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Task: read from stdin, push to `incoming_tx`.
|
||||
let stdin_reader_handle = tokio::spawn({
|
||||
async move {
|
||||
let stdin = io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
let mut lines = reader.lines();
|
||||
pub async fn run_main_with_transport(
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
cli_config_overrides: CliConfigOverrides,
|
||||
loader_overrides: LoaderOverrides,
|
||||
default_analytics_enabled: bool,
|
||||
transport: AppServerTransport,
|
||||
) -> IoResult<()> {
|
||||
let (transport_event_tx, mut transport_event_rx) =
|
||||
mpsc::channel::<TransportEvent>(CHANNEL_CAPACITY);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingEnvelope>(CHANNEL_CAPACITY);
|
||||
let (outbound_control_tx, mut outbound_control_rx) =
|
||||
mpsc::channel::<OutboundControlEvent>(CHANNEL_CAPACITY);
|
||||
|
||||
while let Some(line) = lines.next_line().await.unwrap_or_default() {
|
||||
match serde_json::from_str::<JSONRPCMessage>(&line) {
|
||||
Ok(msg) => {
|
||||
if incoming_tx.send(msg).await.is_err() {
|
||||
// Receiver gone – nothing left to do.
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => error!("Failed to deserialize JSONRPCMessage: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
debug!("stdin reader finished (EOF)");
|
||||
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
|
||||
let mut websocket_accept_handle = None;
|
||||
match transport {
|
||||
AppServerTransport::Stdio => {
|
||||
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
|
||||
}
|
||||
});
|
||||
AppServerTransport::WebSocket { bind_address } => {
|
||||
websocket_accept_handle =
|
||||
Some(start_websocket_acceptor(bind_address, transport_event_tx.clone()).await?);
|
||||
}
|
||||
}
|
||||
let shutdown_when_no_connections = matches!(transport, AppServerTransport::Stdio);
|
||||
|
||||
// Parse CLI overrides once and derive the base Config eagerly so later
|
||||
// components do not need to work with raw TOML values.
|
||||
|
|
@ -329,15 +368,76 @@ pub async fn run_main(
|
|||
}
|
||||
}
|
||||
|
||||
// Task: process incoming messages.
|
||||
let transport_event_tx_for_outbound = transport_event_tx.clone();
|
||||
let outbound_handle = tokio::spawn(async move {
|
||||
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
|
||||
let mut pending_closed_connections = VecDeque::<ConnectionId>::new();
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
event = outbound_control_rx.recv() => {
|
||||
let Some(event) = event else {
|
||||
break;
|
||||
};
|
||||
match event {
|
||||
OutboundControlEvent::Opened {
|
||||
connection_id,
|
||||
writer,
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
} => {
|
||||
outbound_connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer,
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
),
|
||||
);
|
||||
}
|
||||
OutboundControlEvent::Closed { connection_id } => {
|
||||
outbound_connections.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
envelope = outgoing_rx.recv() => {
|
||||
let Some(envelope) = envelope else {
|
||||
break;
|
||||
};
|
||||
let disconnected_connections =
|
||||
route_outgoing_envelope(&mut outbound_connections, envelope).await;
|
||||
pending_closed_connections.extend(disconnected_connections);
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(connection_id) = pending_closed_connections.front().copied() {
|
||||
match transport_event_tx_for_outbound
|
||||
.try_send(TransportEvent::ConnectionClosed { connection_id })
|
||||
{
|
||||
Ok(()) => {
|
||||
pending_closed_connections.pop_front();
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
break;
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("outbound router task exited (channel closed)");
|
||||
});
|
||||
|
||||
let processor_handle = tokio::spawn({
|
||||
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
|
||||
let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
||||
let outbound_control_tx = outbound_control_tx;
|
||||
let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone();
|
||||
let loader_overrides = loader_overrides_for_config_api;
|
||||
let mut processor = MessageProcessor::new(MessageProcessorArgs {
|
||||
outgoing: outgoing_message_sender,
|
||||
codex_linux_sandbox_exe,
|
||||
config: std::sync::Arc::new(config),
|
||||
config: Arc::new(config),
|
||||
cli_overrides,
|
||||
loader_overrides,
|
||||
cloud_requirements: cloud_requirements.clone(),
|
||||
|
|
@ -345,25 +445,107 @@ pub async fn run_main(
|
|||
config_warnings,
|
||||
});
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
|
||||
async move {
|
||||
let mut listen_for_threads = true;
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = incoming_rx.recv() => {
|
||||
let Some(msg) = msg else {
|
||||
event = transport_event_rx.recv() => {
|
||||
let Some(event) = event else {
|
||||
break;
|
||||
};
|
||||
match msg {
|
||||
JSONRPCMessage::Request(r) => processor.process_request(r).await,
|
||||
JSONRPCMessage::Response(r) => processor.process_response(r).await,
|
||||
JSONRPCMessage::Notification(n) => processor.process_notification(n).await,
|
||||
JSONRPCMessage::Error(e) => processor.process_error(e).await,
|
||||
match event {
|
||||
TransportEvent::ConnectionOpened { connection_id, writer } => {
|
||||
let outbound_initialized = Arc::new(AtomicBool::new(false));
|
||||
let outbound_opted_out_notification_methods =
|
||||
Arc::new(RwLock::new(HashSet::new()));
|
||||
if outbound_control_tx
|
||||
.send(OutboundControlEvent::Opened {
|
||||
connection_id,
|
||||
writer,
|
||||
initialized: Arc::clone(&outbound_initialized),
|
||||
opted_out_notification_methods: Arc::clone(
|
||||
&outbound_opted_out_notification_methods,
|
||||
),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
connections.insert(
|
||||
connection_id,
|
||||
ConnectionState::new(
|
||||
outbound_initialized,
|
||||
outbound_opted_out_notification_methods,
|
||||
),
|
||||
);
|
||||
}
|
||||
TransportEvent::ConnectionClosed { connection_id } => {
|
||||
if outbound_control_tx
|
||||
.send(OutboundControlEvent::Closed { connection_id })
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
connections.remove(&connection_id);
|
||||
if shutdown_when_no_connections && connections.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
TransportEvent::IncomingMessage { connection_id, message } => {
|
||||
match message {
|
||||
JSONRPCMessage::Request(request) => {
|
||||
let Some(connection_state) = connections.get_mut(&connection_id) else {
|
||||
warn!("dropping request from unknown connection: {:?}", connection_id);
|
||||
continue;
|
||||
};
|
||||
let was_initialized = connection_state.session.initialized;
|
||||
processor
|
||||
.process_request(
|
||||
connection_id,
|
||||
request,
|
||||
&mut connection_state.session,
|
||||
&connection_state.outbound_initialized,
|
||||
)
|
||||
.await;
|
||||
if let Ok(mut opted_out_notification_methods) = connection_state
|
||||
.outbound_opted_out_notification_methods
|
||||
.write()
|
||||
{
|
||||
*opted_out_notification_methods = connection_state
|
||||
.session
|
||||
.opted_out_notification_methods
|
||||
.clone();
|
||||
} else {
|
||||
warn!(
|
||||
"failed to update outbound opted-out notifications"
|
||||
);
|
||||
}
|
||||
if !was_initialized && connection_state.session.initialized {
|
||||
processor.send_initialize_notifications().await;
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Response(response) => {
|
||||
processor.process_response(response).await;
|
||||
}
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
processor.process_notification(notification).await;
|
||||
}
|
||||
JSONRPCMessage::Error(err) => {
|
||||
processor.process_error(err).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
created = thread_created_rx.recv(), if listen_for_threads => {
|
||||
match created {
|
||||
Ok(thread_id) => {
|
||||
processor.try_attach_thread_listener(thread_id).await;
|
||||
if has_initialized_connections(&connections) {
|
||||
processor.try_attach_thread_listener(thread_id).await;
|
||||
}
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
|
||||
// TODO(jif) handle lag.
|
||||
|
|
@ -384,33 +566,18 @@ pub async fn run_main(
|
|||
}
|
||||
});
|
||||
|
||||
// Task: write outgoing messages to stdout.
|
||||
let stdout_writer_handle = tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(outgoing_message) = outgoing_rx.recv().await {
|
||||
let Ok(value) = serde_json::to_value(outgoing_message) else {
|
||||
error!("Failed to convert OutgoingMessage to JSON value");
|
||||
continue;
|
||||
};
|
||||
match serde_json::to_string(&value) {
|
||||
Ok(mut json) => {
|
||||
json.push('\n');
|
||||
if let Err(e) = stdout.write_all(json.as_bytes()).await {
|
||||
error!("Failed to write to stdout: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => error!("Failed to serialize JSONRPCMessage: {e}"),
|
||||
}
|
||||
}
|
||||
drop(transport_event_tx);
|
||||
|
||||
info!("stdout writer exited (channel closed)");
|
||||
});
|
||||
let _ = processor_handle.await;
|
||||
let _ = outbound_handle.await;
|
||||
|
||||
// Wait for all tasks to finish. The typical exit path is the stdin reader
|
||||
// hitting EOF which, once it drops `incoming_tx`, propagates shutdown to
|
||||
// the processor and then to the stdout task.
|
||||
let _ = tokio::join!(stdin_reader_handle, processor_handle, stdout_writer_handle);
|
||||
if let Some(handle) = websocket_accept_handle {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
for handle in stdio_handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
use codex_app_server::run_main;
|
||||
use clap::Parser;
|
||||
use codex_app_server::AppServerTransport;
|
||||
use codex_app_server::run_main_with_transport;
|
||||
use codex_arg0::arg0_dispatch_or_else;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
|
|
@ -8,19 +10,34 @@ use std::path::PathBuf;
|
|||
// managed config file without writing to /etc.
|
||||
const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH";
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
struct AppServerArgs {
|
||||
/// Transport endpoint URL. Supported values: `stdio://` (default),
|
||||
/// `ws://IP:PORT`.
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
default_value = AppServerTransport::DEFAULT_LISTEN_URL
|
||||
)]
|
||||
listen: AppServerTransport,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move {
|
||||
let args = AppServerArgs::parse();
|
||||
let managed_config_path = managed_config_path_from_debug_env();
|
||||
let loader_overrides = LoaderOverrides {
|
||||
managed_config_path,
|
||||
..Default::default()
|
||||
};
|
||||
let transport = args.listen;
|
||||
|
||||
run_main(
|
||||
run_main_with_transport(
|
||||
codex_linux_sandbox_exe,
|
||||
CliConfigOverrides::default(),
|
||||
loader_overrides,
|
||||
false,
|
||||
transport,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
|
|
@ -8,6 +9,8 @@ use crate::codex_message_processor::CodexMessageProcessor;
|
|||
use crate::codex_message_processor::CodexMessageProcessorArgs;
|
||||
use crate::config_api::ConfigApi;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::ConnectionRequestId;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use async_trait::async_trait;
|
||||
use codex_app_server_protocol::ChatgptAuthTokensRefreshParams;
|
||||
|
|
@ -26,7 +29,6 @@ use codex_app_server_protocol::JSONRPCErrorError;
|
|||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ServerRequestPayload;
|
||||
use codex_app_server_protocol::experimental_required_message;
|
||||
|
|
@ -112,13 +114,18 @@ pub(crate) struct MessageProcessor {
|
|||
codex_message_processor: CodexMessageProcessor,
|
||||
config_api: ConfigApi,
|
||||
config: Arc<Config>,
|
||||
initialized: bool,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
config_warnings: Vec<ConfigWarningNotification>,
|
||||
config_warnings: Arc<Vec<ConfigWarningNotification>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub(crate) struct ConnectionSessionState {
|
||||
pub(crate) initialized: bool,
|
||||
experimental_api_enabled: bool,
|
||||
pub(crate) opted_out_notification_methods: HashSet<String>,
|
||||
}
|
||||
|
||||
pub(crate) struct MessageProcessorArgs {
|
||||
pub(crate) outgoing: OutgoingMessageSender,
|
||||
pub(crate) outgoing: Arc<OutgoingMessageSender>,
|
||||
pub(crate) codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub(crate) config: Arc<Config>,
|
||||
pub(crate) cli_overrides: Vec<(String, TomlValue)>,
|
||||
|
|
@ -142,8 +149,6 @@ impl MessageProcessor {
|
|||
feedback,
|
||||
config_warnings,
|
||||
} = args;
|
||||
let outgoing = Arc::new(outgoing);
|
||||
let experimental_api_enabled = Arc::new(AtomicBool::new(false));
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
false,
|
||||
|
|
@ -181,14 +186,21 @@ impl MessageProcessor {
|
|||
codex_message_processor,
|
||||
config_api,
|
||||
config,
|
||||
initialized: false,
|
||||
experimental_api_enabled,
|
||||
config_warnings,
|
||||
config_warnings: Arc::new(config_warnings),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
|
||||
let request_id = request.id.clone();
|
||||
pub(crate) async fn process_request(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
request: JSONRPCRequest,
|
||||
session: &mut ConnectionSessionState,
|
||||
outbound_initialized: &AtomicBool,
|
||||
) {
|
||||
let request_id = ConnectionRequestId {
|
||||
connection_id,
|
||||
request_id: request.id.clone(),
|
||||
};
|
||||
let request_json = match serde_json::to_value(&request) {
|
||||
Ok(request_json) => request_json,
|
||||
Err(err) => {
|
||||
|
|
@ -219,7 +231,11 @@ impl MessageProcessor {
|
|||
// Handle Initialize internally so CodexMessageProcessor does not have to concern
|
||||
// itself with the `initialized` bool.
|
||||
ClientRequest::Initialize { request_id, params } => {
|
||||
if self.initialized {
|
||||
let request_id = ConnectionRequestId {
|
||||
connection_id,
|
||||
request_id,
|
||||
};
|
||||
if session.initialized {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "Already initialized".to_string(),
|
||||
|
|
@ -228,6 +244,12 @@ impl MessageProcessor {
|
|||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
} else {
|
||||
// TODO(maxj): Revisit capability scoping for `experimental_api_enabled`.
|
||||
// Current behavior is per-connection. Reviewer feedback notes this can
|
||||
// create odd cross-client behavior (for example dynamic tool calls on a
|
||||
// shared thread when another connected client did not opt into
|
||||
// experimental API). Proposed direction is instance-global first-write-wins
|
||||
// with initialize-time mismatch rejection.
|
||||
let (experimental_api_enabled, opt_out_notification_methods) =
|
||||
match params.capabilities {
|
||||
Some(capabilities) => (
|
||||
|
|
@ -238,11 +260,9 @@ impl MessageProcessor {
|
|||
),
|
||||
None => (false, Vec::new()),
|
||||
};
|
||||
self.experimental_api_enabled
|
||||
.store(experimental_api_enabled, Ordering::Relaxed);
|
||||
self.outgoing
|
||||
.set_opted_out_notification_methods(opt_out_notification_methods)
|
||||
.await;
|
||||
session.experimental_api_enabled = experimental_api_enabled;
|
||||
session.opted_out_notification_methods =
|
||||
opt_out_notification_methods.into_iter().collect();
|
||||
let ClientInfo {
|
||||
name,
|
||||
title: _title,
|
||||
|
|
@ -258,7 +278,7 @@ impl MessageProcessor {
|
|||
),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
self.outgoing.send_error(request_id.clone(), error).await;
|
||||
return;
|
||||
}
|
||||
SetOriginatorError::AlreadyInitialized => {
|
||||
|
|
@ -279,22 +299,13 @@ impl MessageProcessor {
|
|||
let response = InitializeResponse { user_agent };
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
|
||||
self.initialized = true;
|
||||
if !self.config_warnings.is_empty() {
|
||||
for notification in self.config_warnings.drain(..) {
|
||||
self.outgoing
|
||||
.send_server_notification(ServerNotification::ConfigWarning(
|
||||
notification,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
session.initialized = true;
|
||||
outbound_initialized.store(true, Ordering::Release);
|
||||
return;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if !self.initialized {
|
||||
if !session.initialized {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "Not initialized".to_string(),
|
||||
|
|
@ -307,7 +318,7 @@ impl MessageProcessor {
|
|||
}
|
||||
|
||||
if let Some(reason) = codex_request.experimental_reason()
|
||||
&& !self.experimental_api_enabled.load(Ordering::Relaxed)
|
||||
&& !session.experimental_api_enabled
|
||||
{
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
|
|
@ -320,22 +331,49 @@ impl MessageProcessor {
|
|||
|
||||
match codex_request {
|
||||
ClientRequest::ConfigRead { request_id, params } => {
|
||||
self.handle_config_read(request_id, params).await;
|
||||
self.handle_config_read(
|
||||
ConnectionRequestId {
|
||||
connection_id,
|
||||
request_id,
|
||||
},
|
||||
params,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
ClientRequest::ConfigValueWrite { request_id, params } => {
|
||||
self.handle_config_value_write(request_id, params).await;
|
||||
self.handle_config_value_write(
|
||||
ConnectionRequestId {
|
||||
connection_id,
|
||||
request_id,
|
||||
},
|
||||
params,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
ClientRequest::ConfigBatchWrite { request_id, params } => {
|
||||
self.handle_config_batch_write(request_id, params).await;
|
||||
self.handle_config_batch_write(
|
||||
ConnectionRequestId {
|
||||
connection_id,
|
||||
request_id,
|
||||
},
|
||||
params,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
ClientRequest::ConfigRequirementsRead {
|
||||
request_id,
|
||||
params: _,
|
||||
} => {
|
||||
self.handle_config_requirements_read(request_id).await;
|
||||
self.handle_config_requirements_read(ConnectionRequestId {
|
||||
connection_id,
|
||||
request_id,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
other => {
|
||||
self.codex_message_processor.process_request(other).await;
|
||||
self.codex_message_processor
|
||||
.process_request(connection_id, other)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -350,10 +388,15 @@ impl MessageProcessor {
|
|||
self.codex_message_processor.thread_created_receiver()
|
||||
}
|
||||
|
||||
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
|
||||
if !self.initialized {
|
||||
return;
|
||||
pub(crate) async fn send_initialize_notifications(&self) {
|
||||
for notification in self.config_warnings.iter().cloned() {
|
||||
self.outgoing
|
||||
.send_server_notification(ServerNotification::ConfigWarning(notification))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) {
|
||||
self.codex_message_processor
|
||||
.try_attach_thread_listener(thread_id)
|
||||
.await;
|
||||
|
|
@ -372,7 +415,7 @@ impl MessageProcessor {
|
|||
self.outgoing.notify_client_error(err.id, err.error).await;
|
||||
}
|
||||
|
||||
async fn handle_config_read(&self, request_id: RequestId, params: ConfigReadParams) {
|
||||
async fn handle_config_read(&self, request_id: ConnectionRequestId, params: ConfigReadParams) {
|
||||
match self.config_api.read(params).await {
|
||||
Ok(response) => self.outgoing.send_response(request_id, response).await,
|
||||
Err(error) => self.outgoing.send_error(request_id, error).await,
|
||||
|
|
@ -381,7 +424,7 @@ impl MessageProcessor {
|
|||
|
||||
async fn handle_config_value_write(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
request_id: ConnectionRequestId,
|
||||
params: ConfigValueWriteParams,
|
||||
) {
|
||||
match self.config_api.write_value(params).await {
|
||||
|
|
@ -392,7 +435,7 @@ impl MessageProcessor {
|
|||
|
||||
async fn handle_config_batch_write(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
request_id: ConnectionRequestId,
|
||||
params: ConfigBatchWriteParams,
|
||||
) {
|
||||
match self.config_api.batch_write(params).await {
|
||||
|
|
@ -401,7 +444,7 @@ impl MessageProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
async fn handle_config_requirements_read(&self, request_id: RequestId) {
|
||||
async fn handle_config_requirements_read(&self, request_id: ConnectionRequestId) {
|
||||
match self.config_api.config_requirements_read().await {
|
||||
Ok(response) => self.outgoing.send_response(request_id, response).await,
|
||||
Err(error) => self.outgoing.send_error(request_id, error).await,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
|
|
@ -20,35 +19,44 @@ use crate::error_code::INTERNAL_ERROR_CODE;
|
|||
#[cfg(test)]
|
||||
use codex_protocol::account::PlanType;
|
||||
|
||||
/// Stable identifier for a transport connection.
|
||||
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
|
||||
pub(crate) struct ConnectionId(pub(crate) u64);
|
||||
|
||||
/// Stable identifier for a client request scoped to a transport connection.
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub(crate) struct ConnectionRequestId {
|
||||
pub(crate) connection_id: ConnectionId,
|
||||
pub(crate) request_id: RequestId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum OutgoingEnvelope {
|
||||
ToConnection {
|
||||
connection_id: ConnectionId,
|
||||
message: OutgoingMessage,
|
||||
},
|
||||
Broadcast {
|
||||
message: OutgoingMessage,
|
||||
},
|
||||
}
|
||||
|
||||
/// Sends messages to the client and manages request callbacks.
|
||||
pub(crate) struct OutgoingMessageSender {
|
||||
next_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingMessage>,
|
||||
next_server_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingEnvelope>,
|
||||
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
|
||||
opted_out_notification_methods: Mutex<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl OutgoingMessageSender {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingEnvelope>) -> Self {
|
||||
Self {
|
||||
next_request_id: AtomicI64::new(0),
|
||||
next_server_request_id: AtomicI64::new(0),
|
||||
sender,
|
||||
request_id_to_callback: Mutex::new(HashMap::new()),
|
||||
opted_out_notification_methods: Mutex::new(HashSet::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn set_opted_out_notification_methods(&self, methods: Vec<String>) {
|
||||
let mut opted_out = self.opted_out_notification_methods.lock().await;
|
||||
opted_out.clear();
|
||||
opted_out.extend(methods);
|
||||
}
|
||||
|
||||
async fn should_skip_notification(&self, method: &str) -> bool {
|
||||
let opted_out = self.opted_out_notification_methods.lock().await;
|
||||
opted_out.contains(method)
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
request: ServerRequestPayload,
|
||||
|
|
@ -61,7 +69,7 @@ impl OutgoingMessageSender {
|
|||
&self,
|
||||
request: ServerRequestPayload,
|
||||
) -> (RequestId, oneshot::Receiver<Result>) {
|
||||
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let id = RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let outgoing_message_id = id.clone();
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
{
|
||||
|
|
@ -71,7 +79,13 @@ impl OutgoingMessageSender {
|
|||
|
||||
let outgoing_message =
|
||||
OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone()));
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send request {outgoing_message_id:?} to client: {err:?}");
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove(&outgoing_message_id);
|
||||
|
|
@ -121,17 +135,31 @@ impl OutgoingMessageSender {
|
|||
entry.is_some()
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response<T: Serialize>(&self, id: RequestId, response: T) {
|
||||
pub(crate) async fn send_response<T: Serialize>(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
response: T,
|
||||
) {
|
||||
match serde_json::to_value(response) {
|
||||
Ok(result) => {
|
||||
let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
let outgoing_message = OutgoingMessage::Response(OutgoingResponse {
|
||||
id: request_id.request_id,
|
||||
result,
|
||||
});
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: request_id.connection_id,
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send response to client: {err:?}");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
self.send_error(
|
||||
id,
|
||||
request_id,
|
||||
JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to serialize response: {err}"),
|
||||
|
|
@ -144,13 +172,11 @@ impl OutgoingMessageSender {
|
|||
}
|
||||
|
||||
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
|
||||
let method = notification.to_string();
|
||||
if self.should_skip_notification(&method).await {
|
||||
return;
|
||||
}
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingMessage::AppServerNotification(notification))
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: OutgoingMessage::AppServerNotification(notification),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send server notification to client: {err:?}");
|
||||
|
|
@ -160,21 +186,35 @@ impl OutgoingMessageSender {
|
|||
/// All notifications should be migrated to [`ServerNotification`] and
|
||||
/// [`OutgoingMessage::Notification`] should be removed.
|
||||
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
|
||||
if self
|
||||
.should_skip_notification(notification.method.as_str())
|
||||
let outgoing_message = OutgoingMessage::Notification(notification);
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
let outgoing_message = OutgoingMessage::Notification(notification);
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
warn!("failed to send notification to client: {err:?}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
pub(crate) async fn send_error(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
error: JSONRPCErrorError,
|
||||
) {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError {
|
||||
id: request_id.request_id,
|
||||
error,
|
||||
});
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: request_id.connection_id,
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send error to client: {err:?}");
|
||||
}
|
||||
}
|
||||
|
|
@ -214,6 +254,8 @@ pub(crate) struct OutgoingError {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::AccountLoginCompletedNotification;
|
||||
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
|
||||
use codex_app_server_protocol::AccountUpdatedNotification;
|
||||
|
|
@ -224,6 +266,7 @@ mod tests {
|
|||
use codex_app_server_protocol::RateLimitWindow;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::time::timeout;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
|
|
@ -364,4 +407,75 @@ mod tests {
|
|||
"ensure the notification serializes correctly"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_response_routes_to_target_connection() {
|
||||
let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let request_id = ConnectionRequestId {
|
||||
connection_id: ConnectionId(42),
|
||||
request_id: RequestId::Integer(7),
|
||||
};
|
||||
|
||||
outgoing
|
||||
.send_response(request_id.clone(), json!({ "ok": true }))
|
||||
.await;
|
||||
|
||||
let envelope = timeout(Duration::from_secs(1), rx.recv())
|
||||
.await
|
||||
.expect("should receive envelope before timeout")
|
||||
.expect("channel should contain one message");
|
||||
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(connection_id, ConnectionId(42));
|
||||
let OutgoingMessage::Response(response) = message else {
|
||||
panic!("expected response message");
|
||||
};
|
||||
assert_eq!(response.id, request_id.request_id);
|
||||
assert_eq!(response.result, json!({ "ok": true }));
|
||||
}
|
||||
other => panic!("expected targeted response envelope, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_error_routes_to_target_connection() {
|
||||
let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let request_id = ConnectionRequestId {
|
||||
connection_id: ConnectionId(9),
|
||||
request_id: RequestId::Integer(3),
|
||||
};
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: "boom".to_string(),
|
||||
data: None,
|
||||
};
|
||||
|
||||
outgoing.send_error(request_id.clone(), error.clone()).await;
|
||||
|
||||
let envelope = timeout(Duration::from_secs(1), rx.recv())
|
||||
.await
|
||||
.expect("should receive envelope before timeout")
|
||||
.expect("channel should contain one message");
|
||||
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(connection_id, ConnectionId(9));
|
||||
let OutgoingMessage::Error(outgoing_error) = message else {
|
||||
panic!("expected error message");
|
||||
};
|
||||
assert_eq!(outgoing_error.id, RequestId::Integer(3));
|
||||
assert_eq!(outgoing_error.error, error);
|
||||
}
|
||||
other => panic!("expected targeted error envelope, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
749
codex-rs/app-server/src/transport.rs
Normal file
749
codex-rs/app-server/src/transport.rs
Normal file
|
|
@ -0,0 +1,749 @@
|
|||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use crate::message_processor::ConnectionSessionState;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingError;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Stream;
|
||||
use owo_colors::Style;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::{self};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage - 128 messages should be
|
||||
/// plenty for an interactive CLI.
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
fn colorize(text: &str, style: Style) -> String {
|
||||
text.if_supports_color(Stream::Stderr, |value| value.style(style))
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
fn print_websocket_startup_banner(addr: SocketAddr) {
|
||||
let title = colorize("codex app-server (WebSockets)", Style::new().bold().cyan());
|
||||
let listening_label = colorize("listening on:", Style::new().dimmed());
|
||||
let listen_url = colorize(&format!("ws://{addr}"), Style::new().green());
|
||||
let note_label = colorize("note:", Style::new().dimmed());
|
||||
eprintln!("{title}");
|
||||
eprintln!(" {listening_label} {listen_url}");
|
||||
if addr.ip().is_loopback() {
|
||||
eprintln!(
|
||||
" {note_label} binds localhost only (use SSH port-forwarding for remote access)"
|
||||
);
|
||||
} else {
|
||||
eprintln!(
|
||||
" {note_label} this is a raw WS server; consider running behind TLS/auth for real remote use"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
fn print_websocket_connection(peer_addr: SocketAddr) {
|
||||
let connected_label = colorize("websocket client connected from", Style::new().dimmed());
|
||||
eprintln!("{connected_label} {peer_addr}");
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum AppServerTransport {
|
||||
Stdio,
|
||||
WebSocket { bind_address: SocketAddr },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum AppServerTransportParseError {
|
||||
UnsupportedListenUrl(String),
|
||||
InvalidWebSocketListenUrl(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AppServerTransportParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
|
||||
),
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AppServerTransportParseError {}
|
||||
|
||||
impl AppServerTransport {
|
||||
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
|
||||
|
||||
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
|
||||
if listen_url == Self::DEFAULT_LISTEN_URL {
|
||||
return Ok(Self::Stdio);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
})?;
|
||||
return Ok(Self::WebSocket { bind_address });
|
||||
}
|
||||
|
||||
Err(AppServerTransportParseError::UnsupportedListenUrl(
|
||||
listen_url.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for AppServerTransport {
|
||||
type Err = AppServerTransportParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Self::from_listen_url(s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TransportEvent {
|
||||
ConnectionOpened {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<OutgoingMessage>,
|
||||
},
|
||||
ConnectionClosed {
|
||||
connection_id: ConnectionId,
|
||||
},
|
||||
IncomingMessage {
|
||||
connection_id: ConnectionId,
|
||||
message: JSONRPCMessage,
|
||||
},
|
||||
}
|
||||
|
||||
pub(crate) struct ConnectionState {
|
||||
pub(crate) outbound_initialized: Arc<AtomicBool>,
|
||||
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) session: ConnectionSessionState,
|
||||
}
|
||||
|
||||
impl ConnectionState {
|
||||
pub(crate) fn new(
|
||||
outbound_initialized: Arc<AtomicBool>,
|
||||
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outbound_initialized,
|
||||
outbound_opted_out_notification_methods,
|
||||
session: ConnectionSessionState::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct OutboundConnectionState {
|
||||
pub(crate) initialized: Arc<AtomicBool>,
|
||||
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) writer: mpsc::Sender<OutgoingMessage>,
|
||||
}
|
||||
|
||||
impl OutboundConnectionState {
|
||||
pub(crate) fn new(
|
||||
writer: mpsc::Sender<OutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
initialized,
|
||||
opted_out_notification_methods,
|
||||
writer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn start_stdio_connection(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
) -> IoResult<()> {
|
||||
let connection_id = ConnectionId(0);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
|
||||
|
||||
let transport_event_tx_for_reader = transport_event_tx.clone();
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let stdin = io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
let mut lines = reader.lines();
|
||||
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx_for_reader,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
&line,
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(err) => {
|
||||
error!("Failed reading stdin: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx_for_reader
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
debug!("stdin reader finished (EOF)");
|
||||
}));
|
||||
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(outgoing_message) = writer_rx.recv().await {
|
||||
let Some(mut json) = serialize_outgoing_message(outgoing_message) else {
|
||||
continue;
|
||||
};
|
||||
json.push('\n');
|
||||
if let Err(err) = stdout.write_all(json.as_bytes()).await {
|
||||
error!("Failed to write to stdout: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
info!("stdout writer exited (channel closed)");
|
||||
}));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn start_websocket_acceptor(
|
||||
bind_address: SocketAddr,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
) -> IoResult<JoinHandle<()>> {
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
print_websocket_startup_banner(local_addr);
|
||||
info!("app-server websocket listening on ws://{local_addr}");
|
||||
|
||||
let connection_counter = Arc::new(AtomicU64::new(1));
|
||||
Ok(tokio::spawn(async move {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, peer_addr)) => {
|
||||
print_websocket_connection(peer_addr);
|
||||
let connection_id =
|
||||
ConnectionId(connection_counter.fetch_add(1, Ordering::Relaxed));
|
||||
let transport_event_tx_for_connection = transport_event_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
run_websocket_connection(
|
||||
connection_id,
|
||||
stream,
|
||||
transport_event_tx_for_connection,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
Err(err) => {
|
||||
error!("failed to accept websocket connection: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
async fn run_websocket_connection(
|
||||
connection_id: ConnectionId,
|
||||
stream: TcpStream,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
) {
|
||||
let websocket_stream = match accept_async(stream).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
warn!("failed to complete websocket handshake: {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let (mut websocket_writer, mut websocket_reader) = websocket_stream.split();
|
||||
loop {
|
||||
tokio::select! {
|
||||
outgoing_message = writer_rx.recv() => {
|
||||
let Some(outgoing_message) = outgoing_message else {
|
||||
break;
|
||||
};
|
||||
let Some(json) = serialize_outgoing_message(outgoing_message) else {
|
||||
continue;
|
||||
};
|
||||
if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
incoming_message = websocket_reader.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(WebSocketMessage::Text(text))) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
&text,
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Ping(payload))) => {
|
||||
if websocket_writer.send(WebSocketMessage::Pong(payload)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Pong(_))) => {}
|
||||
Some(Ok(WebSocketMessage::Close(_))) | None => break,
|
||||
Some(Ok(WebSocketMessage::Binary(_))) => {
|
||||
warn!("dropping unsupported binary websocket message");
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Frame(_))) => {}
|
||||
Some(Err(err)) => {
|
||||
warn!("websocket receive error: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn forward_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<OutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
payload: &str,
|
||||
) -> bool {
|
||||
match serde_json::from_str::<JSONRPCMessage>(payload) {
|
||||
Ok(message) => {
|
||||
enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Failed to deserialize JSONRPCMessage: {err}");
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn enqueue_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<OutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
message: JSONRPCMessage,
|
||||
) -> bool {
|
||||
let event = TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
};
|
||||
match transport_event_tx.try_send(event) {
|
||||
Ok(()) => true,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||
Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: JSONRPCMessage::Request(request),
|
||||
})) => {
|
||||
let overload_error = OutgoingMessage::Error(OutgoingError {
|
||||
id: request.id,
|
||||
error: JSONRPCErrorError {
|
||||
code: OVERLOADED_ERROR_CODE,
|
||||
message: "Server overloaded; retry later.".to_string(),
|
||||
data: None,
|
||||
},
|
||||
});
|
||||
match writer.try_send(overload_error) {
|
||||
Ok(()) => true,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_overload_error)) => {
|
||||
warn!(
|
||||
"dropping overload response for connection {:?}: outbound queue is full",
|
||||
connection_id
|
||||
);
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<String> {
|
||||
let value = match serde_json::to_value(outgoing_message) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("Failed to convert OutgoingMessage to JSON value: {err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
match serde_json::to_string(&value) {
|
||||
Ok(json) => Some(json),
|
||||
Err(err) => {
|
||||
error!("Failed to serialize JSONRPCMessage: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_skip_notification_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: &OutgoingMessage,
|
||||
) -> bool {
|
||||
let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read()
|
||||
else {
|
||||
warn!("failed to read outbound opted-out notifications");
|
||||
return false;
|
||||
};
|
||||
match message {
|
||||
OutgoingMessage::AppServerNotification(notification) => {
|
||||
let method = notification.to_string();
|
||||
opted_out_notification_methods.contains(method.as_str())
|
||||
}
|
||||
OutgoingMessage::Notification(notification) => {
|
||||
opted_out_notification_methods.contains(notification.method.as_str())
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn route_outgoing_envelope(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
envelope: OutgoingEnvelope,
|
||||
) -> Vec<ConnectionId> {
|
||||
let mut disconnected = Vec::new();
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
} => {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!(
|
||||
"dropping message for disconnected connection: {:?}",
|
||||
connection_id
|
||||
);
|
||||
return disconnected;
|
||||
};
|
||||
if connection_state.writer.send(message).await.is_err() {
|
||||
connections.remove(&connection_id);
|
||||
disconnected.push(connection_id);
|
||||
}
|
||||
}
|
||||
OutgoingEnvelope::Broadcast { message } => {
|
||||
let target_connections: Vec<ConnectionId> = connections
|
||||
.iter()
|
||||
.filter_map(|(connection_id, connection_state)| {
|
||||
if connection_state.initialized.load(Ordering::Acquire)
|
||||
&& !should_skip_notification_for_connection(connection_state, &message)
|
||||
{
|
||||
Some(*connection_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for connection_id in target_connections {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
continue;
|
||||
};
|
||||
if connection_state.writer.send(message.clone()).await.is_err() {
|
||||
connections.remove(&connection_id);
|
||||
disconnected.push(connection_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
disconnected
|
||||
}
|
||||
|
||||
pub(crate) fn has_initialized_connections(
|
||||
connections: &HashMap<ConnectionId, ConnectionState>,
|
||||
) -> bool {
|
||||
connections
|
||||
.values()
|
||||
.any(|connection| connection.session.initialized)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_parses_stdio_listen_url() {
|
||||
let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL)
|
||||
.expect("stdio listen URL should parse");
|
||||
assert_eq!(transport, AppServerTransport::Stdio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_parses_websocket_listen_url() {
|
||||
let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234")
|
||||
.expect("websocket listen URL should parse");
|
||||
assert_eq!(
|
||||
transport,
|
||||
AppServerTransport::WebSocket {
|
||||
bind_address: "127.0.0.1:1234".parse().expect("valid socket address"),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_rejects_invalid_websocket_listen_url() {
|
||||
let err = AppServerTransport::from_listen_url("ws://localhost:1234")
|
||||
.expect_err("hostname bind address should be rejected");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_rejects_unsupported_listen_url() {
|
||||
let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234")
|
||||
.expect_err("unsupported scheme should fail");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() {
|
||||
let connection_id = ConnectionId(42);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message =
|
||||
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: first_message.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
});
|
||||
assert!(
|
||||
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await
|
||||
);
|
||||
|
||||
let queued_event = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("first event should stay queued");
|
||||
match queued_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(message, first_message);
|
||||
}
|
||||
_ => panic!("expected queued incoming message"),
|
||||
}
|
||||
|
||||
let overload = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("request should receive overload error");
|
||||
let overload_json = serde_json::to_value(overload).expect("serialize overload error");
|
||||
assert_eq!(
|
||||
overload_json,
|
||||
json!({
|
||||
"id": 7,
|
||||
"error": {
|
||||
"code": OVERLOADED_ERROR_CODE,
|
||||
"message": "Server overloaded; retry later."
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() {
|
||||
let connection_id = ConnectionId(42);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, _writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message =
|
||||
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: first_message.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
result: json!({"ok": true}),
|
||||
});
|
||||
let transport_event_tx_for_enqueue = transport_event_tx.clone();
|
||||
let writer_tx_for_enqueue = writer_tx.clone();
|
||||
let enqueue_handle = tokio::spawn(async move {
|
||||
enqueue_incoming_message(
|
||||
&transport_event_tx_for_enqueue,
|
||||
&writer_tx_for_enqueue,
|
||||
connection_id,
|
||||
response,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let queued_event = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("first event should be dequeued");
|
||||
match queued_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(message, first_message);
|
||||
}
|
||||
_ => panic!("expected queued incoming message"),
|
||||
}
|
||||
|
||||
let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic");
|
||||
assert!(enqueue_result);
|
||||
|
||||
let forwarded_event = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("response should be forwarded instead of dropped");
|
||||
match forwarded_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message:
|
||||
JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }),
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7));
|
||||
assert_eq!(result, json!({"ok": true}));
|
||||
}
|
||||
_ => panic!("expected forwarded response message"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() {
|
||||
let connection_id = ConnectionId(42);
|
||||
let (transport_event_tx, _transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: JSONRPCMessage::Notification(
|
||||
codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
},
|
||||
),
|
||||
})
|
||||
.await
|
||||
.expect("transport queue should accept first message");
|
||||
|
||||
writer_tx
|
||||
.send(OutgoingMessage::Notification(
|
||||
crate::outgoing_message::OutgoingNotification {
|
||||
method: "queued".to_string(),
|
||||
params: None,
|
||||
},
|
||||
))
|
||||
.await
|
||||
.expect("writer queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
});
|
||||
|
||||
let enqueue_result = tokio::time::timeout(
|
||||
std::time::Duration::from_millis(100),
|
||||
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request),
|
||||
)
|
||||
.await
|
||||
.expect("enqueue should not block while writer queue is full");
|
||||
assert!(enqueue_result);
|
||||
|
||||
let queued_outgoing = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("writer queue should still contain original message");
|
||||
let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message");
|
||||
assert_eq!(queued_json, json!({ "method": "queued" }));
|
||||
}
|
||||
}
|
||||
|
|
@ -174,7 +174,7 @@ impl McpProcess {
|
|||
client_info,
|
||||
Some(InitializeCapabilities {
|
||||
experimental_api: true,
|
||||
opt_out_notification_methods: None,
|
||||
..Default::default()
|
||||
}),
|
||||
)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -36,8 +36,9 @@ async fn app_server_default_analytics_disabled_without_flag() -> Result<()> {
|
|||
.map_err(|err| anyhow::anyhow!(err.to_string()))?;
|
||||
|
||||
// With analytics unset in the config and the default flag is false, metrics are disabled.
|
||||
// No provider is built.
|
||||
assert_eq!(provider.is_none(), true);
|
||||
// A provider may still exist for non-metrics telemetry, so check metrics specifically.
|
||||
let has_metrics = provider.as_ref().and_then(|otel| otel.metrics()).is_some();
|
||||
assert_eq!(has_metrics, false);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -560,9 +560,22 @@ fn assert_layers_user_then_optional_system(
|
|||
layers: &[codex_app_server_protocol::ConfigLayer],
|
||||
user_file: AbsolutePathBuf,
|
||||
) -> Result<()> {
|
||||
assert_eq!(layers.len(), 2);
|
||||
assert_eq!(layers[0].name, ConfigLayerSource::User { file: user_file });
|
||||
assert!(matches!(layers[1].name, ConfigLayerSource::System { .. }));
|
||||
let mut first_index = 0;
|
||||
if matches!(
|
||||
layers.first().map(|layer| &layer.name),
|
||||
Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm)
|
||||
) {
|
||||
first_index = 1;
|
||||
}
|
||||
assert_eq!(layers.len(), first_index + 2);
|
||||
assert_eq!(
|
||||
layers[first_index].name,
|
||||
ConfigLayerSource::User { file: user_file }
|
||||
);
|
||||
assert!(matches!(
|
||||
layers[first_index + 1].name,
|
||||
ConfigLayerSource::System { .. }
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -571,12 +584,25 @@ fn assert_layers_managed_user_then_optional_system(
|
|||
managed_file: AbsolutePathBuf,
|
||||
user_file: AbsolutePathBuf,
|
||||
) -> Result<()> {
|
||||
assert_eq!(layers.len(), 3);
|
||||
let mut first_index = 0;
|
||||
if matches!(
|
||||
layers.first().map(|layer| &layer.name),
|
||||
Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm)
|
||||
) {
|
||||
first_index = 1;
|
||||
}
|
||||
assert_eq!(layers.len(), first_index + 3);
|
||||
assert_eq!(
|
||||
layers[0].name,
|
||||
layers[first_index].name,
|
||||
ConfigLayerSource::LegacyManagedConfigTomlFromFile { file: managed_file }
|
||||
);
|
||||
assert_eq!(layers[1].name, ConfigLayerSource::User { file: user_file });
|
||||
assert!(matches!(layers[2].name, ConfigLayerSource::System { .. }));
|
||||
assert_eq!(
|
||||
layers[first_index + 1].name,
|
||||
ConfigLayerSource::User { file: user_file }
|
||||
);
|
||||
assert!(matches!(
|
||||
layers[first_index + 2].name,
|
||||
ConfigLayerSource::System { .. }
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,263 @@
|
|||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::bail;
|
||||
use app_test_support::create_mock_responses_server_sequence_unchecked;
|
||||
use codex_app_server_protocol::ClientInfo;
|
||||
use codex_app_server_protocol::InitializeParams;
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use serde_json::json;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use tempfile::TempDir;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
type WsClient = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_transport_routes_per_connection_handshake_and_responses() -> Result<()> {
|
||||
let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri(), "never")?;
|
||||
|
||||
let bind_addr = reserve_local_addr()?;
|
||||
let mut process = spawn_websocket_server(codex_home.path(), bind_addr).await?;
|
||||
|
||||
let mut ws1 = connect_websocket(bind_addr).await?;
|
||||
let mut ws2 = connect_websocket(bind_addr).await?;
|
||||
|
||||
send_initialize_request(&mut ws1, 1, "ws_client_one").await?;
|
||||
let first_init = read_response_for_id(&mut ws1, 1).await?;
|
||||
assert_eq!(first_init.id, RequestId::Integer(1));
|
||||
|
||||
// Initialize responses are request-scoped and must not leak to other
|
||||
// connections.
|
||||
assert_no_message(&mut ws2, Duration::from_millis(250)).await?;
|
||||
|
||||
send_config_read_request(&mut ws2, 2).await?;
|
||||
let not_initialized = read_error_for_id(&mut ws2, 2).await?;
|
||||
assert_eq!(not_initialized.error.message, "Not initialized");
|
||||
|
||||
send_initialize_request(&mut ws2, 3, "ws_client_two").await?;
|
||||
let second_init = read_response_for_id(&mut ws2, 3).await?;
|
||||
assert_eq!(second_init.id, RequestId::Integer(3));
|
||||
|
||||
// Same request-id on different connections must route independently.
|
||||
send_config_read_request(&mut ws1, 77).await?;
|
||||
send_config_read_request(&mut ws2, 77).await?;
|
||||
let ws1_config = read_response_for_id(&mut ws1, 77).await?;
|
||||
let ws2_config = read_response_for_id(&mut ws2, 77).await?;
|
||||
|
||||
assert_eq!(ws1_config.id, RequestId::Integer(77));
|
||||
assert_eq!(ws2_config.id, RequestId::Integer(77));
|
||||
assert!(ws1_config.result.get("config").is_some());
|
||||
assert!(ws2_config.result.get("config").is_some());
|
||||
|
||||
process
|
||||
.kill()
|
||||
.await
|
||||
.context("failed to stop websocket app-server process")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn spawn_websocket_server(codex_home: &Path, bind_addr: SocketAddr) -> Result<Child> {
|
||||
let program = codex_utils_cargo_bin::cargo_bin("codex-app-server")
|
||||
.context("should find app-server binary")?;
|
||||
let mut cmd = Command::new(program);
|
||||
cmd.arg("--listen")
|
||||
.arg(format!("ws://{bind_addr}"))
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::piped())
|
||||
.env("CODEX_HOME", codex_home)
|
||||
.env("RUST_LOG", "debug");
|
||||
let mut process = cmd
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.context("failed to spawn websocket app-server process")?;
|
||||
|
||||
if let Some(stderr) = process.stderr.take() {
|
||||
let mut stderr_reader = tokio::io::BufReader::new(stderr).lines();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(Some(line)) = stderr_reader.next_line().await {
|
||||
eprintln!("[websocket app-server stderr] {line}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(process)
|
||||
}
|
||||
|
||||
fn reserve_local_addr() -> Result<SocketAddr> {
|
||||
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
|
||||
let addr = listener.local_addr()?;
|
||||
drop(listener);
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
async fn connect_websocket(bind_addr: SocketAddr) -> Result<WsClient> {
|
||||
let url = format!("ws://{bind_addr}");
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
loop {
|
||||
match connect_async(&url).await {
|
||||
Ok((stream, _response)) => return Ok(stream),
|
||||
Err(err) => {
|
||||
if Instant::now() >= deadline {
|
||||
bail!("failed to connect websocket to {url}: {err}");
|
||||
}
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_initialize_request(stream: &mut WsClient, id: i64, client_name: &str) -> Result<()> {
|
||||
let params = InitializeParams {
|
||||
client_info: ClientInfo {
|
||||
name: client_name.to_string(),
|
||||
title: Some("WebSocket Test Client".to_string()),
|
||||
version: "0.1.0".to_string(),
|
||||
},
|
||||
capabilities: None,
|
||||
};
|
||||
send_request(
|
||||
stream,
|
||||
"initialize",
|
||||
id,
|
||||
Some(serde_json::to_value(params)?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_config_read_request(stream: &mut WsClient, id: i64) -> Result<()> {
|
||||
send_request(
|
||||
stream,
|
||||
"config/read",
|
||||
id,
|
||||
Some(json!({ "includeLayers": false })),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
stream: &mut WsClient,
|
||||
method: &str,
|
||||
id: i64,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> Result<()> {
|
||||
let message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(id),
|
||||
method: method.to_string(),
|
||||
params,
|
||||
});
|
||||
send_jsonrpc(stream, message).await
|
||||
}
|
||||
|
||||
async fn send_jsonrpc(stream: &mut WsClient, message: JSONRPCMessage) -> Result<()> {
|
||||
let payload = serde_json::to_string(&message)?;
|
||||
stream
|
||||
.send(WebSocketMessage::Text(payload.into()))
|
||||
.await
|
||||
.context("failed to send websocket frame")
|
||||
}
|
||||
|
||||
async fn read_response_for_id(stream: &mut WsClient, id: i64) -> Result<JSONRPCResponse> {
|
||||
let target_id = RequestId::Integer(id);
|
||||
loop {
|
||||
let message = read_jsonrpc_message(stream).await?;
|
||||
if let JSONRPCMessage::Response(response) = message
|
||||
&& response.id == target_id
|
||||
{
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_error_for_id(stream: &mut WsClient, id: i64) -> Result<JSONRPCError> {
|
||||
let target_id = RequestId::Integer(id);
|
||||
loop {
|
||||
let message = read_jsonrpc_message(stream).await?;
|
||||
if let JSONRPCMessage::Error(err) = message
|
||||
&& err.id == target_id
|
||||
{
|
||||
return Ok(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_jsonrpc_message(stream: &mut WsClient) -> Result<JSONRPCMessage> {
|
||||
loop {
|
||||
let frame = timeout(DEFAULT_READ_TIMEOUT, stream.next())
|
||||
.await
|
||||
.context("timed out waiting for websocket frame")?
|
||||
.context("websocket stream ended unexpectedly")?
|
||||
.context("failed to read websocket frame")?;
|
||||
|
||||
match frame {
|
||||
WebSocketMessage::Text(text) => return Ok(serde_json::from_str(text.as_ref())?),
|
||||
WebSocketMessage::Ping(payload) => {
|
||||
stream.send(WebSocketMessage::Pong(payload)).await?;
|
||||
}
|
||||
WebSocketMessage::Pong(_) => {}
|
||||
WebSocketMessage::Close(frame) => {
|
||||
bail!("websocket closed unexpectedly: {frame:?}")
|
||||
}
|
||||
WebSocketMessage::Binary(_) => bail!("unexpected binary websocket frame"),
|
||||
WebSocketMessage::Frame(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn assert_no_message(stream: &mut WsClient, wait_for: Duration) -> Result<()> {
|
||||
match timeout(wait_for, stream.next()).await {
|
||||
Ok(Some(Ok(frame))) => bail!("unexpected frame while waiting for silence: {frame:?}"),
|
||||
Ok(Some(Err(err))) => bail!("unexpected websocket read error: {err}"),
|
||||
Ok(None) => bail!("websocket closed unexpectedly while waiting for silence"),
|
||||
Err(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_config_toml(
|
||||
codex_home: &Path,
|
||||
server_uri: &str,
|
||||
approval_policy: &str,
|
||||
) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "{approval_policy}"
|
||||
sandbox_mode = "read-only"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "responses"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ mod app_list;
|
|||
mod collaboration_mode_list;
|
||||
mod compaction;
|
||||
mod config_rpc;
|
||||
mod connection_handling_websocket;
|
||||
mod dynamic_tools;
|
||||
mod experimental_api;
|
||||
mod experimental_feature_list;
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@ use app_test_support::create_mock_responses_server_repeating_assistant;
|
|||
use app_test_support::create_mock_responses_server_sequence;
|
||||
use app_test_support::create_shell_command_sse_response;
|
||||
use app_test_support::to_response;
|
||||
use codex_app_server_protocol::CommandExecutionApprovalDecision;
|
||||
use codex_app_server_protocol::CommandExecutionRequestApprovalResponse;
|
||||
use codex_app_server_protocol::ItemCompletedNotification;
|
||||
use codex_app_server_protocol::ItemStartedNotification;
|
||||
use codex_app_server_protocol::JSONRPCError;
|
||||
|
|
@ -211,9 +209,7 @@ async fn review_start_exec_approval_item_id_matches_command_execution_item() ->
|
|||
|
||||
mcp.send_response(
|
||||
request_id,
|
||||
serde_json::to_value(CommandExecutionRequestApprovalResponse {
|
||||
decision: CommandExecutionApprovalDecision::Accept,
|
||||
})?,
|
||||
serde_json::json!({ "decision": codex_core::protocol::ReviewDecision::Approved }),
|
||||
)
|
||||
.await?;
|
||||
timeout(
|
||||
|
|
|
|||
|
|
@ -306,6 +306,15 @@ struct AppServerCommand {
|
|||
#[command(subcommand)]
|
||||
subcommand: Option<AppServerSubcommand>,
|
||||
|
||||
/// Transport endpoint URL. Supported values: `stdio://` (default),
|
||||
/// `ws://IP:PORT`.
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
default_value = codex_app_server::AppServerTransport::DEFAULT_LISTEN_URL
|
||||
)]
|
||||
listen: codex_app_server::AppServerTransport,
|
||||
|
||||
/// Controls whether analytics are enabled by default.
|
||||
///
|
||||
/// Analytics are disabled by default for app-server. Users have to explicitly opt in
|
||||
|
|
@ -587,11 +596,13 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
|
|||
}
|
||||
Some(Subcommand::AppServer(app_server_cli)) => match app_server_cli.subcommand {
|
||||
None => {
|
||||
codex_app_server::run_main(
|
||||
let transport = app_server_cli.listen;
|
||||
codex_app_server::run_main_with_transport(
|
||||
codex_linux_sandbox_exe,
|
||||
root_config_overrides,
|
||||
codex_core::config_loader::LoaderOverrides::default(),
|
||||
app_server_cli.analytics_default_enabled,
|
||||
transport,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
|
@ -1328,6 +1339,10 @@ mod tests {
|
|||
fn app_server_analytics_default_disabled_without_flag() {
|
||||
let app_server = app_server_from_args(["codex", "app-server"].as_ref());
|
||||
assert!(!app_server.analytics_default_enabled);
|
||||
assert_eq!(
|
||||
app_server.listen,
|
||||
codex_app_server::AppServerTransport::Stdio
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1337,6 +1352,36 @@ mod tests {
|
|||
assert!(app_server.analytics_default_enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_listen_websocket_url_parses() {
|
||||
let app_server = app_server_from_args(
|
||||
["codex", "app-server", "--listen", "ws://127.0.0.1:4500"].as_ref(),
|
||||
);
|
||||
assert_eq!(
|
||||
app_server.listen,
|
||||
codex_app_server::AppServerTransport::WebSocket {
|
||||
bind_address: "127.0.0.1:4500".parse().expect("valid socket address"),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_listen_stdio_url_parses() {
|
||||
let app_server =
|
||||
app_server_from_args(["codex", "app-server", "--listen", "stdio://"].as_ref());
|
||||
assert_eq!(
|
||||
app_server.listen,
|
||||
codex_app_server::AppServerTransport::Stdio
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_listen_invalid_url_fails_to_parse() {
|
||||
let parse_result =
|
||||
MultitoolCli::try_parse_from(["codex", "app-server", "--listen", "http://foo"]);
|
||||
assert!(parse_result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn features_enable_parses_feature_name() {
|
||||
let cli = MultitoolCli::try_parse_from(["codex", "features", "enable", "unified_exec"])
|
||||
|
|
|
|||
|
|
@ -371,25 +371,6 @@ async fn review_does_not_emit_agent_message_on_structured_output() {
|
|||
_ => false,
|
||||
})
|
||||
.await;
|
||||
// On slower CI hosts, the final AgentMessage can arrive immediately after
|
||||
// TurnComplete. Drain a brief tail window to make ordering nondeterminism
|
||||
// harmless while still enforcing "exactly one final AgentMessage".
|
||||
while let Ok(Ok(event)) =
|
||||
tokio::time::timeout(std::time::Duration::from_millis(200), codex.next_event()).await
|
||||
{
|
||||
match event.msg {
|
||||
EventMsg::AgentMessage(_) => agent_messages += 1,
|
||||
EventMsg::EnteredReviewMode(_) => saw_entered = true,
|
||||
EventMsg::ExitedReviewMode(_) => saw_exited = true,
|
||||
EventMsg::AgentMessageContentDelta(_) => {
|
||||
panic!("unexpected AgentMessageContentDelta surfaced during review")
|
||||
}
|
||||
EventMsg::AgentMessageDelta(_) => {
|
||||
panic!("unexpected AgentMessageDelta surfaced during review")
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
assert_eq!(1, agent_messages, "expected exactly one AgentMessage event");
|
||||
assert!(saw_entered && saw_exited, "missing review lifecycle events");
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue