fix: codex delegate cancellation (#7092)
This commit is contained in:
parent
99bcb90353
commit
920239f272
2 changed files with 170 additions and 45 deletions
|
|
@ -2436,6 +2436,9 @@ use crate::features::Features;
|
|||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context_with_rx;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -2712,7 +2715,7 @@ mod tests {
|
|||
|
||||
// Like make_session_and_context, but returns Arc<Session> and the event receiver
|
||||
// so tests can assert on emitted events.
|
||||
fn make_session_and_context_with_rx() -> (
|
||||
pub(crate) fn make_session_and_context_with_rx() -> (
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
async_channel::Receiver<Event>,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ use codex_protocol::protocol::SessionSource;
|
|||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::protocol::Submission;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::AuthManager;
|
||||
|
|
@ -60,14 +62,13 @@ pub(crate) async fn run_codex_conversation_interactive(
|
|||
let parent_ctx_clone = Arc::clone(&parent_ctx);
|
||||
let codex_for_events = Arc::clone(&codex);
|
||||
tokio::spawn(async move {
|
||||
let _ = forward_events(
|
||||
forward_events(
|
||||
codex_for_events,
|
||||
tx_sub,
|
||||
parent_session_clone,
|
||||
parent_ctx_clone,
|
||||
cancel_token_events.clone(),
|
||||
cancel_token_events,
|
||||
)
|
||||
.or_cancel(&cancel_token_events)
|
||||
.await;
|
||||
});
|
||||
|
||||
|
|
@ -156,53 +157,92 @@ async fn forward_events(
|
|||
parent_ctx: Arc<TurnContext>,
|
||||
cancel_token: CancellationToken,
|
||||
) {
|
||||
while let Ok(event) = codex.next_event().await {
|
||||
match event {
|
||||
// ignore all legacy delta events
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
|
||||
} => continue,
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::SessionConfigured(_),
|
||||
} => continue,
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ExecApprovalRequest(event),
|
||||
} => {
|
||||
// Initiate approval via parent session; do not surface to consumer.
|
||||
handle_exec_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
let cancelled = cancel_token.cancelled();
|
||||
tokio::pin!(cancelled);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = &mut cancelled => {
|
||||
shutdown_delegate(&codex).await;
|
||||
break;
|
||||
}
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ApplyPatchApprovalRequest(event),
|
||||
} => {
|
||||
handle_patch_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
other => {
|
||||
let _ = tx_sub.send(other).await;
|
||||
event = codex.next_event() => {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(_) => break,
|
||||
};
|
||||
match event {
|
||||
// ignore all legacy delta events
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
|
||||
} => {}
|
||||
Event {
|
||||
id: _,
|
||||
msg: EventMsg::SessionConfigured(_),
|
||||
} => {}
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ExecApprovalRequest(event),
|
||||
} => {
|
||||
// Initiate approval via parent session; do not surface to consumer.
|
||||
handle_exec_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::ApplyPatchApprovalRequest(event),
|
||||
} => {
|
||||
handle_patch_approval(
|
||||
&codex,
|
||||
id,
|
||||
&parent_session,
|
||||
&parent_ctx,
|
||||
event,
|
||||
&cancel_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
other => {
|
||||
match tx_sub.send(other).or_cancel(&cancel_token).await {
|
||||
Ok(Ok(())) => {}
|
||||
_ => {
|
||||
shutdown_delegate(&codex).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ask the delegate to stop and drain its events so background sends do not hit a closed channel.
|
||||
async fn shutdown_delegate(codex: &Codex) {
|
||||
let _ = codex.submit(Op::Interrupt).await;
|
||||
let _ = codex.submit(Op::Shutdown {}).await;
|
||||
|
||||
let _ = timeout(Duration::from_millis(500), async {
|
||||
while let Ok(event) = codex.next_event().await {
|
||||
if matches!(
|
||||
event.msg,
|
||||
EventMsg::TurnAborted(_) | EventMsg::TaskComplete(_)
|
||||
) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Forward ops from a caller to a sub-agent, respecting cancellation.
|
||||
async fn forward_ops(
|
||||
codex: Arc<Codex>,
|
||||
|
|
@ -298,3 +338,85 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_channel::bounded;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RawResponseItemEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() {
|
||||
let (tx_events, rx_events) = bounded(1);
|
||||
let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY);
|
||||
let codex = Arc::new(Codex {
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub,
|
||||
rx_event: rx_events,
|
||||
});
|
||||
|
||||
let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx();
|
||||
|
||||
let (tx_out, rx_out) = bounded(1);
|
||||
tx_out
|
||||
.send(Event {
|
||||
id: "full".to_string(),
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent {
|
||||
reason: TurnAbortReason::Interrupted,
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
let forward = tokio::spawn(forward_events(
|
||||
Arc::clone(&codex),
|
||||
tx_out.clone(),
|
||||
session,
|
||||
ctx,
|
||||
cancel.clone(),
|
||||
));
|
||||
|
||||
tx_events
|
||||
.send(Event {
|
||||
id: "evt".to_string(),
|
||||
msg: EventMsg::RawResponseItem(RawResponseItemEvent {
|
||||
item: ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "call-1".to_string(),
|
||||
name: "tool".to_string(),
|
||||
input: "{}".to_string(),
|
||||
},
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
drop(tx_events);
|
||||
cancel.cancel();
|
||||
timeout(std::time::Duration::from_millis(1000), forward)
|
||||
.await
|
||||
.expect("forward_events hung")
|
||||
.expect("forward_events join error");
|
||||
|
||||
let received = rx_out.recv().await.expect("prefilled event missing");
|
||||
assert_eq!("full", received.id);
|
||||
let mut ops = Vec::new();
|
||||
while let Ok(sub) = rx_sub.try_recv() {
|
||||
ops.push(sub.op);
|
||||
}
|
||||
assert!(
|
||||
ops.iter().any(|op| matches!(op, Op::Interrupt)),
|
||||
"expected Interrupt op after cancellation"
|
||||
);
|
||||
assert!(
|
||||
ops.iter().any(|op| matches!(op, Op::Shutdown)),
|
||||
"expected Shutdown op after cancellation"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue