Pump pings (#11413)
Keep processing ping even when the agent isn't actively running. Otherwise the connection will drop.
This commit is contained in:
parent
b5339a591d
commit
d73de9c8ba
1 changed files with 123 additions and 10 deletions
|
|
@ -26,6 +26,7 @@ use std::time::Duration;
|
|||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
|
|
@ -41,7 +42,124 @@ use tungstenite::extensions::compression::deflate::DeflateConfig;
|
|||
use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
struct WsStream {
|
||||
tx_command: mpsc::Sender<WsCommand>,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
pump_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
enum WsCommand {
|
||||
Send {
|
||||
message: Message,
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
Close {
|
||||
tx_result: oneshot::Sender<Result<(), WsError>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl WsStream {
|
||||
fn new(inner: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
|
||||
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
|
||||
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();
|
||||
|
||||
let pump_task = tokio::spawn(async move {
|
||||
let mut inner = inner;
|
||||
loop {
|
||||
tokio::select! {
|
||||
command = rx_command.recv() => {
|
||||
let Some(command) = command else {
|
||||
break;
|
||||
};
|
||||
match command {
|
||||
WsCommand::Send { message, tx_result } => {
|
||||
let result = inner.send(message).await;
|
||||
let should_break = result.is_err();
|
||||
let _ = tx_result.send(result);
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
}
|
||||
WsCommand::Close { tx_result } => {
|
||||
let result = inner.close(None).await;
|
||||
let _ = tx_result.send(result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
message = inner.next() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
match message {
|
||||
Ok(Message::Ping(payload)) => {
|
||||
if let Err(err) = inner.send(Message::Pong(payload)).await {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Message::Pong(_)) => {}
|
||||
Ok(message @ (Message::Text(_)
|
||||
| Message::Binary(_)
|
||||
| Message::Close(_)
|
||||
| Message::Frame(_))) => {
|
||||
let is_close = matches!(message, Message::Close(_));
|
||||
if tx_message.send(Ok(message)).is_err() {
|
||||
break;
|
||||
}
|
||||
if is_close {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx_message.send(Err(err));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
tx_command,
|
||||
rx_message,
|
||||
pump_task,
|
||||
}
|
||||
}
|
||||
|
||||
async fn request(
|
||||
&self,
|
||||
make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
|
||||
) -> Result<(), WsError> {
|
||||
let (tx_result, rx_result) = oneshot::channel();
|
||||
if self.tx_command.send(make_command(tx_result)).await.is_err() {
|
||||
return Err(WsError::ConnectionClosed);
|
||||
}
|
||||
rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
|
||||
}
|
||||
|
||||
async fn send(&self, message: Message) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Send { message, tx_result })
|
||||
.await
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), WsError> {
|
||||
self.request(|tx_result| WsCommand::Close { tx_result })
|
||||
.await
|
||||
}
|
||||
|
||||
async fn next(&mut self) -> Option<Result<Message, WsError>> {
|
||||
self.rx_message.recv().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WsStream {
|
||||
fn drop(&mut self) {
|
||||
self.pump_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
|
|
@ -119,7 +237,7 @@ impl ResponsesWebsocketConnection {
|
|||
)
|
||||
.await
|
||||
{
|
||||
let _ = ws_stream.close(None).await;
|
||||
let _ = ws_stream.close().await;
|
||||
*guard = None;
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
|
|
@ -215,7 +333,7 @@ async fn connect_websocket(
|
|||
{
|
||||
let _ = turn_state.set(header_value.to_string());
|
||||
}
|
||||
Ok((stream, reasoning_included, models_etag))
|
||||
Ok((WsStream::new(stream), reasoning_included, models_etag))
|
||||
}
|
||||
|
||||
fn websocket_config() -> WebSocketConfig {
|
||||
|
|
@ -419,18 +537,13 @@ async fn run_websocket_response_stream(
|
|||
Message::Binary(_) => {
|
||||
return Err(ApiError::Stream("unexpected binary websocket event".into()));
|
||||
}
|
||||
Message::Ping(payload) => {
|
||||
if ws_stream.send(Message::Pong(payload)).await.is_err() {
|
||||
return Err(ApiError::Stream("websocket ping failed".into()));
|
||||
}
|
||||
}
|
||||
Message::Pong(_) => {}
|
||||
Message::Close(_) => {
|
||||
return Err(ApiError::Stream(
|
||||
"websocket closed by server before response.completed".into(),
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
Message::Frame(_) => {}
|
||||
Message::Ping(_) | Message::Pong(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue