feat(app-server): thread/rollback API (#8454)
Add `thread/rollback` to app-server to support IDEs undo-ing the last N turns of a thread. For context, an IDE partner will be supporting an "undo" capability where the IDE (the app-server client) will be responsible for reverting the local changes made during the last turn. To support this well, we also need a way to drop the last turn (or more generally, the last N turns) from the agent's context. This is what `thread/rollback` does. **Core idea**: A Thread rollback is represented as a persisted event message (EventMsg::ThreadRollback) in the rollout JSONL file, not by rewriting history. On resume, both the model's context (core replay) and the UI turn list (app-server v2's thread history builder) apply these markers so the pruned history is consistent across live conversations and `thread/resume`. Implementation notes: - Rollback only affects agent context and appends to the rollout file; clients are responsible for reverting files on disk. - If a thread rollback is currently in progress, subsequent `thread/rollback` calls are rejected. - Because we use `CodexConversation::submit` and codex core tracks active turns, returning an error on concurrent rollbacks is communicated via an `EventMsg::Error` with a new variant `CodexErrorInfo::ThreadRollbackFailed`. app-server watches for that and sends the BAD_REQUEST RPC response. Tests cover thread rollbacks in both core and app-server, including when `num_turns` > existing turns (which clears all turns). **Note**: this explicitly does **not** behave like `/undo` which we just removed from the CLI, which does the opposite of what `thread/rollback` does. `/undo` reverts local changes via ghost commits/snapshots and does not modify the agent's context / conversation history.
This commit is contained in:
parent
188f79afee
commit
8b7ec31ba7
21 changed files with 1187 additions and 30 deletions
|
|
@ -113,6 +113,10 @@ client_request_definitions! {
|
|||
params: v2::ThreadArchiveParams,
|
||||
response: v2::ThreadArchiveResponse,
|
||||
},
|
||||
ThreadRollback => "thread/rollback" {
|
||||
params: v2::ThreadRollbackParams,
|
||||
response: v2::ThreadRollbackResponse,
|
||||
},
|
||||
ThreadList => "thread/list" {
|
||||
params: v2::ThreadListParams,
|
||||
response: v2::ThreadListResponse,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ use crate::protocol::v2::UserInput;
|
|||
use codex_protocol::protocol::AgentReasoningEvent;
|
||||
use codex_protocol::protocol::AgentReasoningRawContentEvent;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::ThreadRolledBackEvent;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::UserMessageEvent;
|
||||
|
||||
|
|
@ -57,6 +58,7 @@ impl ThreadHistoryBuilder {
|
|||
EventMsg::TokenCount(_) => {}
|
||||
EventMsg::EnteredReviewMode(_) => {}
|
||||
EventMsg::ExitedReviewMode(_) => {}
|
||||
EventMsg::ThreadRolledBack(payload) => self.handle_thread_rollback(payload),
|
||||
EventMsg::UndoCompleted(_) => {}
|
||||
EventMsg::TurnAborted(payload) => self.handle_turn_aborted(payload),
|
||||
_ => {}
|
||||
|
|
@ -130,6 +132,23 @@ impl ThreadHistoryBuilder {
|
|||
turn.status = TurnStatus::Interrupted;
|
||||
}
|
||||
|
||||
fn handle_thread_rollback(&mut self, payload: &ThreadRolledBackEvent) {
|
||||
self.finish_current_turn();
|
||||
|
||||
let n = usize::try_from(payload.num_turns).unwrap_or(usize::MAX);
|
||||
if n >= self.turns.len() {
|
||||
self.turns.clear();
|
||||
} else {
|
||||
self.turns.truncate(self.turns.len().saturating_sub(n));
|
||||
}
|
||||
|
||||
// Re-number subsequent synthetic ids so the pruned history is consistent.
|
||||
self.next_turn_index =
|
||||
i64::try_from(self.turns.len().saturating_add(1)).unwrap_or(i64::MAX);
|
||||
let item_count: usize = self.turns.iter().map(|t| t.items.len()).sum();
|
||||
self.next_item_index = i64::try_from(item_count.saturating_add(1)).unwrap_or(i64::MAX);
|
||||
}
|
||||
|
||||
fn finish_current_turn(&mut self) {
|
||||
if let Some(turn) = self.current_turn.take() {
|
||||
if turn.items.is_empty() {
|
||||
|
|
@ -213,6 +232,7 @@ mod tests {
|
|||
use codex_protocol::protocol::AgentMessageEvent;
|
||||
use codex_protocol::protocol::AgentReasoningEvent;
|
||||
use codex_protocol::protocol::AgentReasoningRawContentEvent;
|
||||
use codex_protocol::protocol::ThreadRolledBackEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::UserMessageEvent;
|
||||
|
|
@ -410,4 +430,95 @@ mod tests {
|
|||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drops_last_turns_on_thread_rollback() {
|
||||
let events = vec![
|
||||
EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "First".into(),
|
||||
images: None,
|
||||
}),
|
||||
EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "A1".into(),
|
||||
}),
|
||||
EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "Second".into(),
|
||||
images: None,
|
||||
}),
|
||||
EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "A2".into(),
|
||||
}),
|
||||
EventMsg::ThreadRolledBack(ThreadRolledBackEvent { num_turns: 1 }),
|
||||
EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "Third".into(),
|
||||
images: None,
|
||||
}),
|
||||
EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "A3".into(),
|
||||
}),
|
||||
];
|
||||
|
||||
let turns = build_turns_from_event_msgs(&events);
|
||||
let expected = vec![
|
||||
Turn {
|
||||
id: "turn-1".into(),
|
||||
status: TurnStatus::Completed,
|
||||
error: None,
|
||||
items: vec![
|
||||
ThreadItem::UserMessage {
|
||||
id: "item-1".into(),
|
||||
content: vec![UserInput::Text {
|
||||
text: "First".into(),
|
||||
}],
|
||||
},
|
||||
ThreadItem::AgentMessage {
|
||||
id: "item-2".into(),
|
||||
text: "A1".into(),
|
||||
},
|
||||
],
|
||||
},
|
||||
Turn {
|
||||
id: "turn-2".into(),
|
||||
status: TurnStatus::Completed,
|
||||
error: None,
|
||||
items: vec![
|
||||
ThreadItem::UserMessage {
|
||||
id: "item-3".into(),
|
||||
content: vec![UserInput::Text {
|
||||
text: "Third".into(),
|
||||
}],
|
||||
},
|
||||
ThreadItem::AgentMessage {
|
||||
id: "item-4".into(),
|
||||
text: "A3".into(),
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
assert_eq!(turns, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thread_rollback_clears_all_turns_when_num_turns_exceeds_history() {
|
||||
let events = vec![
|
||||
EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "One".into(),
|
||||
images: None,
|
||||
}),
|
||||
EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "A1".into(),
|
||||
}),
|
||||
EventMsg::UserMessage(UserMessageEvent {
|
||||
message: "Two".into(),
|
||||
images: None,
|
||||
}),
|
||||
EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "A2".into(),
|
||||
}),
|
||||
EventMsg::ThreadRolledBack(ThreadRolledBackEvent { num_turns: 99 }),
|
||||
];
|
||||
|
||||
let turns = build_turns_from_event_msgs(&events);
|
||||
assert_eq!(turns, Vec::<Turn>::new());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ pub enum CodexErrorInfo {
|
|||
InternalServerError,
|
||||
Unauthorized,
|
||||
BadRequest,
|
||||
ThreadRollbackFailed,
|
||||
SandboxError,
|
||||
/// The response SSE stream disconnected in the middle of a turn before completion.
|
||||
ResponseStreamDisconnected {
|
||||
|
|
@ -119,6 +120,7 @@ impl From<CoreCodexErrorInfo> for CodexErrorInfo {
|
|||
CoreCodexErrorInfo::InternalServerError => CodexErrorInfo::InternalServerError,
|
||||
CoreCodexErrorInfo::Unauthorized => CodexErrorInfo::Unauthorized,
|
||||
CoreCodexErrorInfo::BadRequest => CodexErrorInfo::BadRequest,
|
||||
CoreCodexErrorInfo::ThreadRollbackFailed => CodexErrorInfo::ThreadRollbackFailed,
|
||||
CoreCodexErrorInfo::SandboxError => CodexErrorInfo::SandboxError,
|
||||
CoreCodexErrorInfo::ResponseStreamDisconnected { http_status_code } => {
|
||||
CodexErrorInfo::ResponseStreamDisconnected { http_status_code }
|
||||
|
|
@ -1055,6 +1057,30 @@ pub struct ThreadArchiveParams {
|
|||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadArchiveResponse {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadRollbackParams {
|
||||
pub thread_id: String,
|
||||
/// The number of turns to drop from the end of the thread. Must be >= 1.
|
||||
///
|
||||
/// This only modifies the thread's history and does not revert local file changes
|
||||
/// that have been made by the agent. Clients are responsible for reverting these changes.
|
||||
pub num_turns: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadRollbackResponse {
|
||||
/// The updated thread after applying the rollback, with `turns` populated.
|
||||
///
|
||||
/// The ThreadItems stored in each Turn are lossy since we explicitly do not
|
||||
/// persist all agent interactions, such as command executions. This is the same
|
||||
/// behavior as `thread/resume`.
|
||||
pub thread: Thread,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
|
|
@ -1193,7 +1219,7 @@ pub struct Thread {
|
|||
pub source: SessionSource,
|
||||
/// Optional Git metadata captured when the thread was created.
|
||||
pub git_info: Option<GitInfo>,
|
||||
/// Only populated on a `thread/resume` response.
|
||||
/// Only populated on `thread/resume` and `thread/rollback` responses.
|
||||
/// For all other responses and notifications returning a Thread,
|
||||
/// the turns field will be an empty list.
|
||||
pub turns: Vec<Turn>,
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ Example (from OpenAI's official VSCode extension):
|
|||
- `thread/resume` — reopen an existing thread by id so subsequent `turn/start` calls append to it.
|
||||
- `thread/list` — page through stored rollouts; supports cursor-based pagination and optional `modelProviders` filtering.
|
||||
- `thread/archive` — move a thread’s rollout file into the archived directory; returns `{}` on success.
|
||||
- `thread/rollback` — drop the last N turns from the agent’s in-memory context and persist a rollback marker in the rollout so future resumes see the pruned history; returns the updated `thread` (with `turns` populated) on success.
|
||||
- `turn/start` — add user input to a thread and begin Codex generation; responds with the initial `turn` object and streams `turn/started`, `item/*`, and `turn/completed` notifications.
|
||||
- `turn/interrupt` — request cancellation of an in-flight turn by `(thread_id, turn_id)`; success is an empty `{}` response and the turn finishes with `status: "interrupted"`.
|
||||
- `review/start` — kick off Codex’s automated reviewer for a thread; responds like `turn/start` and emits `item/started`/`item/completed` notifications with `enteredReviewMode` and `exitedReviewMode` items, plus a final assistant `agentMessage` containing the review.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
use crate::codex_message_processor::ApiVersion;
|
||||
use crate::codex_message_processor::PendingInterrupts;
|
||||
use crate::codex_message_processor::PendingRollbacks;
|
||||
use crate::codex_message_processor::TurnSummary;
|
||||
use crate::codex_message_processor::TurnSummaryStore;
|
||||
use crate::codex_message_processor::read_event_msgs_from_rollout;
|
||||
use crate::codex_message_processor::read_summary_from_rollout;
|
||||
use crate::codex_message_processor::summary_to_thread;
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
|
||||
use codex_app_server_protocol::AgentMessageDeltaNotification;
|
||||
|
|
@ -27,6 +33,7 @@ use codex_app_server_protocol::FileUpdateChange;
|
|||
use codex_app_server_protocol::InterruptConversationResponse;
|
||||
use codex_app_server_protocol::ItemCompletedNotification;
|
||||
use codex_app_server_protocol::ItemStartedNotification;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::McpToolCallError;
|
||||
use codex_app_server_protocol::McpToolCallResult;
|
||||
use codex_app_server_protocol::McpToolCallStatus;
|
||||
|
|
@ -40,6 +47,7 @@ use codex_app_server_protocol::ServerNotification;
|
|||
use codex_app_server_protocol::ServerRequestPayload;
|
||||
use codex_app_server_protocol::TerminalInteractionNotification;
|
||||
use codex_app_server_protocol::ThreadItem;
|
||||
use codex_app_server_protocol::ThreadRollbackResponse;
|
||||
use codex_app_server_protocol::ThreadTokenUsage;
|
||||
use codex_app_server_protocol::ThreadTokenUsageUpdatedNotification;
|
||||
use codex_app_server_protocol::Turn;
|
||||
|
|
@ -50,9 +58,11 @@ use codex_app_server_protocol::TurnInterruptResponse;
|
|||
use codex_app_server_protocol::TurnPlanStep;
|
||||
use codex_app_server_protocol::TurnPlanUpdatedNotification;
|
||||
use codex_app_server_protocol::TurnStatus;
|
||||
use codex_app_server_protocol::build_turns_from_event_msgs;
|
||||
use codex_core::CodexConversation;
|
||||
use codex_core::parse_command::shlex_join;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::CodexErrorInfo as CoreCodexErrorInfo;
|
||||
use codex_core::protocol::Event;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
|
|
@ -78,14 +88,17 @@ use tracing::error;
|
|||
|
||||
type JsonValue = serde_json::Value;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn apply_bespoke_event_handling(
|
||||
event: Event,
|
||||
conversation_id: ConversationId,
|
||||
conversation: Arc<CodexConversation>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
pending_interrupts: PendingInterrupts,
|
||||
pending_rollbacks: PendingRollbacks,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
api_version: ApiVersion,
|
||||
fallback_model_provider: String,
|
||||
) {
|
||||
let Event {
|
||||
id: event_turn_id,
|
||||
|
|
@ -337,6 +350,26 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
.await;
|
||||
}
|
||||
EventMsg::Error(ev) => {
|
||||
let message = ev.message.clone();
|
||||
let codex_error_info = ev.codex_error_info.clone();
|
||||
|
||||
// If this error belongs to an in-flight `thread/rollback` request, fail that request
|
||||
// (and clear pending state) so subsequent rollbacks are unblocked.
|
||||
//
|
||||
// Don't send a notification for this error.
|
||||
if matches!(
|
||||
codex_error_info,
|
||||
Some(CoreCodexErrorInfo::ThreadRollbackFailed)
|
||||
) {
|
||||
return handle_thread_rollback_failed(
|
||||
conversation_id,
|
||||
message,
|
||||
&pending_rollbacks,
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
};
|
||||
|
||||
let turn_error = TurnError {
|
||||
message: ev.message,
|
||||
codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from),
|
||||
|
|
@ -345,7 +378,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await;
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::Error(ErrorNotification {
|
||||
error: turn_error,
|
||||
error: turn_error.clone(),
|
||||
will_retry: false,
|
||||
thread_id: conversation_id.to_string(),
|
||||
turn_id: event_turn_id.clone(),
|
||||
|
|
@ -690,6 +723,58 @@ pub(crate) async fn apply_bespoke_event_handling(
|
|||
)
|
||||
.await;
|
||||
}
|
||||
EventMsg::ThreadRolledBack(_rollback_event) => {
|
||||
let pending = {
|
||||
let mut map = pending_rollbacks.lock().await;
|
||||
map.remove(&conversation_id)
|
||||
};
|
||||
|
||||
if let Some(request_id) = pending {
|
||||
let rollout_path = conversation.rollout_path();
|
||||
let response = match read_summary_from_rollout(
|
||||
rollout_path.as_path(),
|
||||
fallback_model_provider.as_str(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(summary) => {
|
||||
let mut thread = summary_to_thread(summary);
|
||||
match read_event_msgs_from_rollout(rollout_path.as_path()).await {
|
||||
Ok(events) => {
|
||||
thread.turns = build_turns_from_event_msgs(&events);
|
||||
ThreadRollbackResponse { thread }
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!(
|
||||
"failed to load rollout `{}`: {err}",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!(
|
||||
"failed to load rollout `{}`: {err}",
|
||||
rollout_path.display()
|
||||
),
|
||||
data: None,
|
||||
};
|
||||
outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
}
|
||||
EventMsg::TurnDiff(turn_diff_event) => {
|
||||
handle_turn_diff(
|
||||
conversation_id,
|
||||
|
|
@ -906,6 +991,31 @@ async fn handle_turn_interrupted(
|
|||
.await;
|
||||
}
|
||||
|
||||
async fn handle_thread_rollback_failed(
|
||||
conversation_id: ConversationId,
|
||||
message: String,
|
||||
pending_rollbacks: &PendingRollbacks,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
) {
|
||||
let pending_rollback = {
|
||||
let mut map = pending_rollbacks.lock().await;
|
||||
map.remove(&conversation_id)
|
||||
};
|
||||
|
||||
if let Some(request_id) = pending_rollback {
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id,
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: message.clone(),
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_token_count_event(
|
||||
conversation_id: ConversationId,
|
||||
turn_id: String,
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ use codex_app_server_protocol::ThreadListParams;
|
|||
use codex_app_server_protocol::ThreadListResponse;
|
||||
use codex_app_server_protocol::ThreadResumeParams;
|
||||
use codex_app_server_protocol::ThreadResumeResponse;
|
||||
use codex_app_server_protocol::ThreadRollbackParams;
|
||||
use codex_app_server_protocol::ThreadStartParams;
|
||||
use codex_app_server_protocol::ThreadStartResponse;
|
||||
use codex_app_server_protocol::ThreadStartedNotification;
|
||||
|
|
@ -178,6 +179,8 @@ use uuid::Uuid;
|
|||
type PendingInterruptQueue = Vec<(RequestId, ApiVersion)>;
|
||||
pub(crate) type PendingInterrupts = Arc<Mutex<HashMap<ConversationId, PendingInterruptQueue>>>;
|
||||
|
||||
pub(crate) type PendingRollbacks = Arc<Mutex<HashMap<ConversationId, RequestId>>>;
|
||||
|
||||
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
|
||||
#[derive(Default, Clone)]
|
||||
pub(crate) struct TurnSummary {
|
||||
|
|
@ -220,6 +223,8 @@ pub(crate) struct CodexMessageProcessor {
|
|||
active_login: Arc<Mutex<Option<ActiveLogin>>>,
|
||||
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
|
||||
pending_interrupts: PendingInterrupts,
|
||||
// Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives.
|
||||
pending_rollbacks: PendingRollbacks,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
|
||||
feedback: CodexFeedback,
|
||||
|
|
@ -275,6 +280,7 @@ impl CodexMessageProcessor {
|
|||
conversation_listeners: HashMap::new(),
|
||||
active_login: Arc::new(Mutex::new(None)),
|
||||
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
|
||||
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
|
||||
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
|
||||
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
|
||||
feedback,
|
||||
|
|
@ -365,6 +371,9 @@ impl CodexMessageProcessor {
|
|||
ClientRequest::ThreadArchive { request_id, params } => {
|
||||
self.thread_archive(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ThreadRollback { request_id, params } => {
|
||||
self.thread_rollback(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ThreadList { request_id, params } => {
|
||||
self.thread_list(request_id, params).await;
|
||||
}
|
||||
|
|
@ -1506,6 +1515,52 @@ impl CodexMessageProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
async fn thread_rollback(&mut self, request_id: RequestId, params: ThreadRollbackParams) {
|
||||
let ThreadRollbackParams {
|
||||
thread_id,
|
||||
num_turns,
|
||||
} = params;
|
||||
|
||||
if num_turns == 0 {
|
||||
self.send_invalid_request_error(request_id, "numTurns must be >= 1".to_string())
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let (conversation_id, conversation) =
|
||||
match self.conversation_from_thread_id(&thread_id).await {
|
||||
Ok(v) => v,
|
||||
Err(error) => {
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
let mut map = self.pending_rollbacks.lock().await;
|
||||
if map.contains_key(&conversation_id) {
|
||||
self.send_invalid_request_error(
|
||||
request_id,
|
||||
"rollback already in progress for this thread".to_string(),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
map.insert(conversation_id, request_id.clone());
|
||||
}
|
||||
|
||||
if let Err(err) = conversation.submit(Op::ThreadRollback { num_turns }).await {
|
||||
// No ThreadRollback event will arrive if an error occurs.
|
||||
// Clean up and reply immediately.
|
||||
let mut map = self.pending_rollbacks.lock().await;
|
||||
map.remove(&conversation_id);
|
||||
|
||||
self.send_internal_error(request_id, format!("failed to start rollback: {err}"))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn thread_list(&self, request_id: RequestId, params: ThreadListParams) {
|
||||
let ThreadListParams {
|
||||
cursor,
|
||||
|
|
@ -3095,8 +3150,10 @@ impl CodexMessageProcessor {
|
|||
|
||||
let outgoing_for_task = self.outgoing.clone();
|
||||
let pending_interrupts = self.pending_interrupts.clone();
|
||||
let pending_rollbacks = self.pending_rollbacks.clone();
|
||||
let turn_summary_store = self.turn_summary_store.clone();
|
||||
let api_version_for_task = api_version;
|
||||
let fallback_model_provider = self.config.model_provider_id.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
|
|
@ -3152,8 +3209,10 @@ impl CodexMessageProcessor {
|
|||
conversation.clone(),
|
||||
outgoing_for_task.clone(),
|
||||
pending_interrupts.clone(),
|
||||
pending_rollbacks.clone(),
|
||||
turn_summary_store.clone(),
|
||||
api_version_for_task,
|
||||
fallback_model_provider.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -3354,7 +3413,7 @@ async fn derive_config_from_params(
|
|||
Config::load_with_cli_overrides_and_harness_overrides(cli_overrides, overrides).await
|
||||
}
|
||||
|
||||
async fn read_summary_from_rollout(
|
||||
pub(crate) async fn read_summary_from_rollout(
|
||||
path: &Path,
|
||||
fallback_provider: &str,
|
||||
) -> std::io::Result<ConversationSummary> {
|
||||
|
|
@ -3413,6 +3472,24 @@ async fn read_summary_from_rollout(
|
|||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn read_event_msgs_from_rollout(
|
||||
path: &Path,
|
||||
) -> std::io::Result<Vec<codex_protocol::protocol::EventMsg>> {
|
||||
let items = match RolloutRecorder::get_rollout_history(path).await? {
|
||||
InitialHistory::New => Vec::new(),
|
||||
InitialHistory::Forked(items) => items,
|
||||
InitialHistory::Resumed(resumed) => resumed.history,
|
||||
};
|
||||
|
||||
Ok(items
|
||||
.into_iter()
|
||||
.filter_map(|item| match item {
|
||||
RolloutItem::EventMsg(event) => Some(event),
|
||||
_ => None,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn extract_conversation_summary(
|
||||
path: PathBuf,
|
||||
head: &[serde_json::Value],
|
||||
|
|
@ -3474,7 +3551,7 @@ fn parse_datetime(timestamp: Option<&str>) -> Option<DateTime<Utc>> {
|
|||
})
|
||||
}
|
||||
|
||||
fn summary_to_thread(summary: ConversationSummary) -> Thread {
|
||||
pub(crate) fn summary_to_thread(summary: ConversationSummary) -> Thread {
|
||||
let ConversationSummary {
|
||||
conversation_id,
|
||||
path,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ use codex_app_server_protocol::SetDefaultModelParams;
|
|||
use codex_app_server_protocol::ThreadArchiveParams;
|
||||
use codex_app_server_protocol::ThreadListParams;
|
||||
use codex_app_server_protocol::ThreadResumeParams;
|
||||
use codex_app_server_protocol::ThreadRollbackParams;
|
||||
use codex_app_server_protocol::ThreadStartParams;
|
||||
use codex_app_server_protocol::TurnInterruptParams;
|
||||
use codex_app_server_protocol::TurnStartParams;
|
||||
|
|
@ -316,6 +317,15 @@ impl McpProcess {
|
|||
self.send_request("thread/archive", params).await
|
||||
}
|
||||
|
||||
/// Send a `thread/rollback` JSON-RPC request.
|
||||
pub async fn send_thread_rollback_request(
|
||||
&mut self,
|
||||
params: ThreadRollbackParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("thread/rollback", params).await
|
||||
}
|
||||
|
||||
/// Send a `thread/list` JSON-RPC request.
|
||||
pub async fn send_thread_list_request(
|
||||
&mut self,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ mod review;
|
|||
mod thread_archive;
|
||||
mod thread_list;
|
||||
mod thread_resume;
|
||||
mod thread_rollback;
|
||||
mod thread_start;
|
||||
mod turn_interrupt;
|
||||
mod turn_start;
|
||||
|
|
|
|||
177
codex-rs/app-server/tests/suite/v2/thread_rollback.rs
Normal file
177
codex-rs/app-server/tests/suite/v2/thread_rollback.rs
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
use anyhow::Result;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::create_final_assistant_message_sse_response;
|
||||
use app_test_support::create_mock_chat_completions_server_unchecked;
|
||||
use app_test_support::to_response;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ThreadItem;
|
||||
use codex_app_server_protocol::ThreadResumeParams;
|
||||
use codex_app_server_protocol::ThreadResumeResponse;
|
||||
use codex_app_server_protocol::ThreadRollbackParams;
|
||||
use codex_app_server_protocol::ThreadRollbackResponse;
|
||||
use codex_app_server_protocol::ThreadStartParams;
|
||||
use codex_app_server_protocol::ThreadStartResponse;
|
||||
use codex_app_server_protocol::TurnStartParams;
|
||||
use codex_app_server_protocol::UserInput as V2UserInput;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_drops_last_turns_and_persists_to_rollout() -> Result<()> {
|
||||
// Three Codex turns hit the mock model (session start + two turn/start calls).
|
||||
let responses = vec![
|
||||
create_final_assistant_message_sse_response("Done")?,
|
||||
create_final_assistant_message_sse_response("Done")?,
|
||||
create_final_assistant_message_sse_response("Done")?,
|
||||
];
|
||||
let server = create_mock_chat_completions_server_unchecked(responses).await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
|
||||
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
// Start a thread.
|
||||
let start_id = mcp
|
||||
.send_thread_start_request(ThreadStartParams {
|
||||
model: Some("mock-model".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let start_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(start_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(start_resp)?;
|
||||
|
||||
// Two turns.
|
||||
let first_text = "First";
|
||||
let turn1_id = mcp
|
||||
.send_turn_start_request(TurnStartParams {
|
||||
thread_id: thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: first_text.to_string(),
|
||||
}],
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let _turn1_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn1_id)),
|
||||
)
|
||||
.await??;
|
||||
let _completed1 = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("turn/completed"),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let turn2_id = mcp
|
||||
.send_turn_start_request(TurnStartParams {
|
||||
thread_id: thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: "Second".to_string(),
|
||||
}],
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let _turn2_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn2_id)),
|
||||
)
|
||||
.await??;
|
||||
let _completed2 = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("turn/completed"),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// Roll back the last turn.
|
||||
let rollback_id = mcp
|
||||
.send_thread_rollback_request(ThreadRollbackParams {
|
||||
thread_id: thread.id.clone(),
|
||||
num_turns: 1,
|
||||
})
|
||||
.await?;
|
||||
let rollback_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(rollback_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadRollbackResponse {
|
||||
thread: rolled_back_thread,
|
||||
} = to_response::<ThreadRollbackResponse>(rollback_resp)?;
|
||||
|
||||
assert_eq!(rolled_back_thread.turns.len(), 1);
|
||||
assert_eq!(rolled_back_thread.turns[0].items.len(), 2);
|
||||
match &rolled_back_thread.turns[0].items[0] {
|
||||
ThreadItem::UserMessage { content, .. } => {
|
||||
assert_eq!(
|
||||
content,
|
||||
&vec![V2UserInput::Text {
|
||||
text: first_text.to_string()
|
||||
}]
|
||||
);
|
||||
}
|
||||
other => panic!("expected user message item, got {other:?}"),
|
||||
}
|
||||
|
||||
// Resume and confirm the history is pruned.
|
||||
let resume_id = mcp
|
||||
.send_thread_resume_request(ThreadResumeParams {
|
||||
thread_id: thread.id,
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let resume_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(resume_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadResumeResponse { thread, .. } = to_response::<ThreadResumeResponse>(resume_resp)?;
|
||||
|
||||
assert_eq!(thread.turns.len(), 1);
|
||||
assert_eq!(thread.turns[0].items.len(), 2);
|
||||
match &thread.turns[0].items[0] {
|
||||
ThreadItem::UserMessage { content, .. } => {
|
||||
assert_eq!(
|
||||
content,
|
||||
&vec![V2UserInput::Text {
|
||||
text: first_text.to_string()
|
||||
}]
|
||||
);
|
||||
}
|
||||
other => panic!("expected user message item, got {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_config_toml(codex_home: &std::path::Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "never"
|
||||
sandbox_mode = "read-only"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
@ -1028,6 +1028,25 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
/// Persist the event to the rollout file, flush it, and only then deliver it to clients.
|
||||
///
|
||||
/// Most events can be delivered immediately after queueing the rollout write, but some
|
||||
/// clients (e.g. app-server thread/rollback) re-read the rollout file synchronously on
|
||||
/// receipt of the event and depend on the marker already being visible on disk.
|
||||
pub(crate) async fn send_event_raw_flushed(&self, event: Event) {
|
||||
// Record the last known agent status.
|
||||
if let Some(status) = agent_status_from_event(&event.msg) {
|
||||
let mut guard = self.agent_status.write().await;
|
||||
*guard = status;
|
||||
}
|
||||
self.persist_rollout_items(&[RolloutItem::EventMsg(event.msg.clone())])
|
||||
.await;
|
||||
self.flush_rollout().await;
|
||||
if let Err(e) = self.tx_event.send(event).await {
|
||||
error!("failed to send tool call event: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn emit_turn_item_started(&self, turn_context: &TurnContext, item: &TurnItem) {
|
||||
self.send_event(
|
||||
turn_context,
|
||||
|
|
@ -1245,6 +1264,9 @@ impl Session {
|
|||
history.replace(rebuilt);
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
|
||||
history.drop_last_n_user_turns(rollback.num_turns);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
|
@ -1684,6 +1706,9 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||
Op::Compact => {
|
||||
handlers::compact(&sess, sub.id.clone()).await;
|
||||
}
|
||||
Op::ThreadRollback { num_turns } => {
|
||||
handlers::thread_rollback(&sess, sub.id.clone(), num_turns).await;
|
||||
}
|
||||
Op::RunUserShellCommand { command } => {
|
||||
handlers::run_user_shell_command(
|
||||
&sess,
|
||||
|
|
@ -1741,6 +1766,7 @@ mod handlers {
|
|||
use codex_protocol::protocol::ReviewDecision;
|
||||
use codex_protocol::protocol::ReviewRequest;
|
||||
use codex_protocol::protocol::SkillsListEntry;
|
||||
use codex_protocol::protocol::ThreadRolledBackEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::WarningEvent;
|
||||
|
||||
|
|
@ -2060,6 +2086,46 @@ mod handlers {
|
|||
.await;
|
||||
}
|
||||
|
||||
pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32) {
|
||||
if num_turns == 0 {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "num_turns must be >= 1".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let has_active_turn = { sess.active_turn.lock().await.is_some() };
|
||||
if has_active_turn {
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: "Cannot rollback while a turn is in progress.".to_string(),
|
||||
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
|
||||
let mut history = sess.clone_history().await;
|
||||
history.drop_last_n_user_turns(num_turns);
|
||||
sess.replace_history(history.get_history()).await;
|
||||
sess.recompute_token_usage(turn_context.as_ref()).await;
|
||||
|
||||
sess.send_event_raw_flushed(Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::ThreadRolledBack(ThreadRolledBackEvent { num_turns }),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
sess.services
|
||||
|
|
@ -2973,6 +3039,131 @@ mod tests {
|
|||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_drops_last_turn_from_history() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let initial_context = sess.build_initial_context(tc.as_ref());
|
||||
sess.record_into_history(&initial_context, tc.as_ref())
|
||||
.await;
|
||||
|
||||
let turn_1 = vec![
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "turn 1 user".to_string(),
|
||||
}],
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "turn 1 assistant".to_string(),
|
||||
}],
|
||||
},
|
||||
];
|
||||
sess.record_into_history(&turn_1, tc.as_ref()).await;
|
||||
|
||||
let turn_2 = vec![
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "turn 2 user".to_string(),
|
||||
}],
|
||||
},
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "turn 2 assistant".to_string(),
|
||||
}],
|
||||
},
|
||||
];
|
||||
sess.record_into_history(&turn_2, tc.as_ref()).await;
|
||||
|
||||
handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await;
|
||||
|
||||
let rollback_event = wait_for_thread_rolled_back(&rx).await;
|
||||
assert_eq!(rollback_event.num_turns, 1);
|
||||
|
||||
let mut expected = Vec::new();
|
||||
expected.extend(initial_context);
|
||||
expected.extend(turn_1);
|
||||
|
||||
let actual = sess.clone_history().await.get_history();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let initial_context = sess.build_initial_context(tc.as_ref());
|
||||
sess.record_into_history(&initial_context, tc.as_ref())
|
||||
.await;
|
||||
|
||||
let turn_1 = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "turn 1 user".to_string(),
|
||||
}],
|
||||
}];
|
||||
sess.record_into_history(&turn_1, tc.as_ref()).await;
|
||||
|
||||
handlers::thread_rollback(&sess, "sub-1".to_string(), 99).await;
|
||||
|
||||
let rollback_event = wait_for_thread_rolled_back(&rx).await;
|
||||
assert_eq!(rollback_event.num_turns, 99);
|
||||
|
||||
let actual = sess.clone_history().await.get_history();
|
||||
assert_eq!(initial_context, actual);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_fails_when_turn_in_progress() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let initial_context = sess.build_initial_context(tc.as_ref());
|
||||
sess.record_into_history(&initial_context, tc.as_ref())
|
||||
.await;
|
||||
|
||||
*sess.active_turn.lock().await = Some(crate::state::ActiveTurn::default());
|
||||
handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await;
|
||||
|
||||
let error_event = wait_for_thread_rollback_failed(&rx).await;
|
||||
assert_eq!(
|
||||
error_event.codex_error_info,
|
||||
Some(CodexErrorInfo::ThreadRollbackFailed)
|
||||
);
|
||||
|
||||
let actual = sess.clone_history().await.get_history();
|
||||
assert_eq!(initial_context, actual);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_fails_when_num_turns_is_zero() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let initial_context = sess.build_initial_context(tc.as_ref());
|
||||
sess.record_into_history(&initial_context, tc.as_ref())
|
||||
.await;
|
||||
|
||||
handlers::thread_rollback(&sess, "sub-1".to_string(), 0).await;
|
||||
|
||||
let error_event = wait_for_thread_rollback_failed(&rx).await;
|
||||
assert_eq!(error_event.message, "num_turns must be >= 1");
|
||||
assert_eq!(
|
||||
error_event.codex_error_info,
|
||||
Some(CodexErrorInfo::ThreadRollbackFailed)
|
||||
);
|
||||
|
||||
let actual = sess.clone_history().await.get_history();
|
||||
assert_eq!(initial_context, actual);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn set_rate_limits_retains_previous_credits() {
|
||||
let codex_home = tempfile::tempdir().expect("create temp dir");
|
||||
|
|
@ -3206,6 +3397,44 @@ mod tests {
|
|||
assert_eq!(expected, got);
|
||||
}
|
||||
|
||||
async fn wait_for_thread_rolled_back(
|
||||
rx: &async_channel::Receiver<Event>,
|
||||
) -> crate::protocol::ThreadRolledBackEvent {
|
||||
let deadline = StdDuration::from_secs(2);
|
||||
let start = std::time::Instant::now();
|
||||
loop {
|
||||
let remaining = deadline.saturating_sub(start.elapsed());
|
||||
let evt = tokio::time::timeout(remaining, rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("event");
|
||||
match evt.msg {
|
||||
EventMsg::ThreadRolledBack(payload) => return payload,
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_thread_rollback_failed(rx: &async_channel::Receiver<Event>) -> ErrorEvent {
|
||||
let deadline = StdDuration::from_secs(2);
|
||||
let start = std::time::Instant::now();
|
||||
loop {
|
||||
let remaining = deadline.saturating_sub(start.elapsed());
|
||||
let evt = tokio::time::timeout(remaining, rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("event");
|
||||
match evt.msg {
|
||||
EventMsg::Error(payload)
|
||||
if payload.codex_error_info == Some(CodexErrorInfo::ThreadRollbackFailed) =>
|
||||
{
|
||||
return payload;
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn text_block(s: &str) -> ContentBlock {
|
||||
ContentBlock::TextContent(TextContent {
|
||||
annotations: None,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@ use crate::truncate::approx_token_count;
|
|||
use crate::truncate::approx_tokens_from_byte_count;
|
||||
use crate::truncate::truncate_function_output_items_with_policy;
|
||||
use crate::truncate::truncate_text;
|
||||
use crate::user_instructions::SkillInstructions;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
use crate::user_shell_command::is_user_shell_command_text;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
|
|
@ -152,6 +155,39 @@ impl ContextManager {
|
|||
}
|
||||
}
|
||||
|
||||
/// Drop the last `num_turns` user turns from this history.
|
||||
///
|
||||
/// "User turns" are identified as `ResponseItem::Message` entries whose role is `"user"`.
|
||||
///
|
||||
/// This mirrors thread-rollback semantics:
|
||||
/// - `num_turns == 0` is a no-op
|
||||
/// - if there are no user turns, this is a no-op
|
||||
/// - if `num_turns` exceeds the number of user turns, all user turns are dropped while
|
||||
/// preserving any items that occurred before the first user message.
|
||||
pub(crate) fn drop_last_n_user_turns(&mut self, num_turns: u32) {
|
||||
if num_turns == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
// Keep behavior consistent with call sites that previously operated on `get_history()`:
|
||||
// normalize first (call/output invariants), then truncate based on the normalized view.
|
||||
let snapshot = self.get_history();
|
||||
let user_positions = user_message_positions(&snapshot);
|
||||
let Some(&first_user_idx) = user_positions.first() else {
|
||||
self.replace(snapshot);
|
||||
return;
|
||||
};
|
||||
|
||||
let n_from_end = usize::try_from(num_turns).unwrap_or(usize::MAX);
|
||||
let cut_idx = if n_from_end >= user_positions.len() {
|
||||
first_user_idx
|
||||
} else {
|
||||
user_positions[user_positions.len() - n_from_end]
|
||||
};
|
||||
|
||||
self.replace(snapshot[..cut_idx].to_vec());
|
||||
}
|
||||
|
||||
pub(crate) fn update_token_info(
|
||||
&mut self,
|
||||
usage: &TokenUsage,
|
||||
|
|
@ -291,6 +327,56 @@ fn estimate_reasoning_length(encoded_len: usize) -> usize {
|
|||
.saturating_sub(650)
|
||||
}
|
||||
|
||||
fn is_session_prefix(text: &str) -> bool {
|
||||
let trimmed = text.trim_start();
|
||||
let lowered = trimmed.to_ascii_lowercase();
|
||||
lowered.starts_with("<environment_context>")
|
||||
}
|
||||
|
||||
fn is_user_turn_boundary(item: &ResponseItem) -> bool {
|
||||
let ResponseItem::Message { role, content, .. } = item else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if role != "user" {
|
||||
return false;
|
||||
}
|
||||
|
||||
if UserInstructions::is_user_instructions(content)
|
||||
|| SkillInstructions::is_skill_instructions(content)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for content_item in content {
|
||||
match content_item {
|
||||
ContentItem::InputText { text } => {
|
||||
if is_session_prefix(text) || is_user_shell_command_text(text) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ContentItem::OutputText { text } => {
|
||||
if is_session_prefix(text) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ContentItem::InputImage { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn user_message_positions(items: &[ResponseItem]) -> Vec<usize> {
|
||||
let mut positions = Vec::new();
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
if is_user_turn_boundary(item) {
|
||||
positions.push(idx);
|
||||
}
|
||||
}
|
||||
positions
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "history_tests.rs"]
|
||||
mod tests;
|
||||
|
|
|
|||
|
|
@ -43,6 +43,16 @@ fn user_msg(text: &str) -> ResponseItem {
|
|||
}
|
||||
}
|
||||
|
||||
fn user_input_text_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
fn reasoning_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
|
|
@ -227,6 +237,127 @@ fn remove_first_item_handles_local_shell_pair() {
|
|||
assert_eq!(h.contents(), vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drop_last_n_user_turns_preserves_prefix() {
|
||||
let items = vec![
|
||||
assistant_msg("session prefix item"),
|
||||
user_msg("u1"),
|
||||
assistant_msg("a1"),
|
||||
user_msg("u2"),
|
||||
assistant_msg("a2"),
|
||||
];
|
||||
|
||||
let mut history = create_history_with_items(items);
|
||||
history.drop_last_n_user_turns(1);
|
||||
assert_eq!(
|
||||
history.get_history(),
|
||||
vec![
|
||||
assistant_msg("session prefix item"),
|
||||
user_msg("u1"),
|
||||
assistant_msg("a1"),
|
||||
]
|
||||
);
|
||||
|
||||
let mut history = create_history_with_items(vec![
|
||||
assistant_msg("session prefix item"),
|
||||
user_msg("u1"),
|
||||
assistant_msg("a1"),
|
||||
user_msg("u2"),
|
||||
assistant_msg("a2"),
|
||||
]);
|
||||
history.drop_last_n_user_turns(99);
|
||||
assert_eq!(
|
||||
history.get_history(),
|
||||
vec![assistant_msg("session prefix item")]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
|
||||
let items = vec![
|
||||
user_input_text_msg("<environment_context>ctx</environment_context>"),
|
||||
user_input_text_msg("<user_instructions>do the thing</user_instructions>"),
|
||||
user_input_text_msg(
|
||||
"# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>",
|
||||
),
|
||||
user_input_text_msg(
|
||||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
user_input_text_msg("turn 2 user"),
|
||||
assistant_msg("turn 2 assistant"),
|
||||
];
|
||||
|
||||
let mut history = create_history_with_items(items);
|
||||
history.drop_last_n_user_turns(1);
|
||||
|
||||
let expected_prefix_and_first_turn = vec![
|
||||
user_input_text_msg("<environment_context>ctx</environment_context>"),
|
||||
user_input_text_msg("<user_instructions>do the thing</user_instructions>"),
|
||||
user_input_text_msg(
|
||||
"# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>",
|
||||
),
|
||||
user_input_text_msg(
|
||||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
];
|
||||
|
||||
assert_eq!(history.get_history(), expected_prefix_and_first_turn);
|
||||
|
||||
let expected_prefix_only = vec![
|
||||
user_input_text_msg("<environment_context>ctx</environment_context>"),
|
||||
user_input_text_msg("<user_instructions>do the thing</user_instructions>"),
|
||||
user_input_text_msg(
|
||||
"# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>",
|
||||
),
|
||||
user_input_text_msg(
|
||||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
];
|
||||
|
||||
let mut history = create_history_with_items(vec![
|
||||
user_input_text_msg("<environment_context>ctx</environment_context>"),
|
||||
user_input_text_msg("<user_instructions>do the thing</user_instructions>"),
|
||||
user_input_text_msg(
|
||||
"# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>",
|
||||
),
|
||||
user_input_text_msg(
|
||||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
user_input_text_msg("turn 2 user"),
|
||||
assistant_msg("turn 2 assistant"),
|
||||
]);
|
||||
history.drop_last_n_user_turns(2);
|
||||
assert_eq!(history.get_history(), expected_prefix_only);
|
||||
|
||||
let mut history = create_history_with_items(vec![
|
||||
user_input_text_msg("<environment_context>ctx</environment_context>"),
|
||||
user_input_text_msg("<user_instructions>do the thing</user_instructions>"),
|
||||
user_input_text_msg(
|
||||
"# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>",
|
||||
),
|
||||
user_input_text_msg(
|
||||
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
|
||||
),
|
||||
user_input_text_msg("<user_shell_command>echo 42</user_shell_command>"),
|
||||
user_input_text_msg("turn 1 user"),
|
||||
assistant_msg("turn 1 assistant"),
|
||||
user_input_text_msg("turn 2 user"),
|
||||
assistant_msg("turn 2 assistant"),
|
||||
]);
|
||||
history.drop_last_n_user_turns(3);
|
||||
assert_eq!(history.get_history(), expected_prefix_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_first_item_handles_custom_tool_pair() {
|
||||
let items = vec![
|
||||
|
|
|
|||
|
|
@ -16,10 +16,9 @@ use crate::protocol::Event;
|
|||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::rollout::truncation;
|
||||
use crate::skills::SkillsManager;
|
||||
use codex_protocol::ConversationId;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::openai_models::ModelPreset;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::Op;
|
||||
|
|
@ -307,30 +306,8 @@ impl ConversationManagerState {
|
|||
/// Return a prefix of `items` obtained by cutting strictly before the nth user message
|
||||
/// (0-based) and all items that follow it.
|
||||
fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory {
|
||||
// Work directly on rollout items, and cut the vector at the nth user message input.
|
||||
let items: Vec<RolloutItem> = history.get_rollout_items();
|
||||
|
||||
// Find indices of user message inputs in rollout order.
|
||||
let mut user_positions: Vec<usize> = Vec::new();
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(item @ ResponseItem::Message { .. }) = item
|
||||
&& matches!(
|
||||
crate::event_mapping::parse_turn_item(item),
|
||||
Some(TurnItem::UserMessage(_))
|
||||
)
|
||||
{
|
||||
user_positions.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// If fewer than or equal to n user messages exist, treat as empty (out of range).
|
||||
if user_positions.len() <= n {
|
||||
return InitialHistory::New;
|
||||
}
|
||||
|
||||
// Cut strictly before the nth user message (do not keep the nth itself).
|
||||
let cut_idx = user_positions[n];
|
||||
let rolled: Vec<RolloutItem> = items.into_iter().take(cut_idx).collect();
|
||||
let rolled = truncation::truncate_rollout_before_nth_user_message_from_start(&items, n);
|
||||
|
||||
if rolled.is_empty() {
|
||||
InitialHistory::New
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ pub(crate) mod error;
|
|||
pub mod list;
|
||||
pub(crate) mod policy;
|
||||
pub mod recorder;
|
||||
pub(crate) mod truncation;
|
||||
|
||||
pub use codex_protocol::protocol::SessionMeta;
|
||||
pub(crate) use error::map_session_init_error;
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
|
|||
| EventMsg::ContextCompacted(_)
|
||||
| EventMsg::EnteredReviewMode(_)
|
||||
| EventMsg::ExitedReviewMode(_)
|
||||
| EventMsg::ThreadRolledBack(_)
|
||||
| EventMsg::UndoCompleted(_)
|
||||
| EventMsg::TurnAborted(_) => true,
|
||||
EventMsg::Error(_)
|
||||
|
|
|
|||
195
codex-rs/core/src/rollout/truncation.rs
Normal file
195
codex-rs/core/src/rollout/truncation.rs
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
//! Helpers for truncating rollouts based on "user turn" boundaries.
|
||||
//!
|
||||
//! In core, "user turns" are detected by scanning `ResponseItem::Message` items and
|
||||
//! interpreting them via `event_mapping::parse_turn_item(...)`.
|
||||
|
||||
use crate::event_mapping;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
|
||||
/// Return the indices of user message boundaries in a rollout.
|
||||
///
|
||||
/// A user message boundary is a `RolloutItem::ResponseItem(ResponseItem::Message { .. })`
|
||||
/// whose parsed turn item is `TurnItem::UserMessage`.
|
||||
///
|
||||
/// Rollouts can contain `ThreadRolledBack` markers. Those markers indicate that the
|
||||
/// last N user turns were removed from the effective thread history; we apply them here so
|
||||
/// indexing uses the post-rollback history rather than the raw stream.
|
||||
pub(crate) fn user_message_positions_in_rollout(items: &[RolloutItem]) -> Vec<usize> {
|
||||
let mut user_positions = Vec::new();
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(item @ ResponseItem::Message { .. })
|
||||
if matches!(
|
||||
event_mapping::parse_turn_item(item),
|
||||
Some(TurnItem::UserMessage(_))
|
||||
) =>
|
||||
{
|
||||
user_positions.push(idx);
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
|
||||
let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
|
||||
let new_len = user_positions.len().saturating_sub(num_turns);
|
||||
user_positions.truncate(new_len);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
user_positions
|
||||
}
|
||||
|
||||
/// Return a prefix of `items` obtained by cutting strictly before the nth user message.
|
||||
///
|
||||
/// The boundary index is 0-based from the start of `items` (so `n_from_start = 0` returns
|
||||
/// a prefix that excludes the first user message and everything after it).
|
||||
///
|
||||
/// If fewer than or equal to `n_from_start` user messages exist, this returns an empty
|
||||
/// vector (out of range).
|
||||
pub(crate) fn truncate_rollout_before_nth_user_message_from_start(
|
||||
items: &[RolloutItem],
|
||||
n_from_start: usize,
|
||||
) -> Vec<RolloutItem> {
|
||||
let user_positions = user_message_positions_in_rollout(items);
|
||||
|
||||
// If fewer than or equal to n user messages exist, treat as empty (out of range).
|
||||
if user_positions.len() <= n_from_start {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// Cut strictly before the nth user message (do not keep the nth itself).
|
||||
let cut_idx = user_positions[n_from_start];
|
||||
items[..cut_idx].to_vec()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codex::make_session_and_context;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemReasoningSummary;
|
||||
use codex_protocol::protocol::ThreadRolledBackEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn user_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_rollout_from_start_before_nth_user_only() {
|
||||
let items = [
|
||||
user_msg("u1"),
|
||||
assistant_msg("a1"),
|
||||
assistant_msg("a2"),
|
||||
user_msg("u2"),
|
||||
assistant_msg("a3"),
|
||||
ResponseItem::Reasoning {
|
||||
id: "r1".to_string(),
|
||||
summary: vec![ReasoningItemReasoningSummary::SummaryText {
|
||||
text: "s".to_string(),
|
||||
}],
|
||||
content: None,
|
||||
encrypted_content: None,
|
||||
},
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "tool".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "c1".to_string(),
|
||||
},
|
||||
assistant_msg("a4"),
|
||||
];
|
||||
|
||||
let rollout: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
|
||||
let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, 1);
|
||||
let expected = vec![
|
||||
RolloutItem::ResponseItem(items[0].clone()),
|
||||
RolloutItem::ResponseItem(items[1].clone()),
|
||||
RolloutItem::ResponseItem(items[2].clone()),
|
||||
];
|
||||
assert_eq!(
|
||||
serde_json::to_value(&truncated).unwrap(),
|
||||
serde_json::to_value(&expected).unwrap()
|
||||
);
|
||||
|
||||
let truncated2 = truncate_rollout_before_nth_user_message_from_start(&rollout, 2);
|
||||
assert_matches!(truncated2.as_slice(), []);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_rollout_from_start_applies_thread_rollback_markers() {
|
||||
let rollout_items = vec![
|
||||
RolloutItem::ResponseItem(user_msg("u1")),
|
||||
RolloutItem::ResponseItem(assistant_msg("a1")),
|
||||
RolloutItem::ResponseItem(user_msg("u2")),
|
||||
RolloutItem::ResponseItem(assistant_msg("a2")),
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent {
|
||||
num_turns: 1,
|
||||
})),
|
||||
RolloutItem::ResponseItem(user_msg("u3")),
|
||||
RolloutItem::ResponseItem(assistant_msg("a3")),
|
||||
RolloutItem::ResponseItem(user_msg("u4")),
|
||||
RolloutItem::ResponseItem(assistant_msg("a4")),
|
||||
];
|
||||
|
||||
// Effective user history after applying rollback(1) is: u1, u3, u4.
|
||||
// So n_from_start=2 should cut before u4 (not u3).
|
||||
let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 2);
|
||||
let expected = rollout_items[..7].to_vec();
|
||||
assert_eq!(
|
||||
serde_json::to_value(&truncated).unwrap(),
|
||||
serde_json::to_value(&expected).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
let mut items = session.build_initial_context(&turn_context);
|
||||
items.push(user_msg("feature request"));
|
||||
items.push(assistant_msg("ack"));
|
||||
items.push(user_msg("second question"));
|
||||
items.push(assistant_msg("answer"));
|
||||
|
||||
let rollout_items: Vec<RolloutItem> = items
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(RolloutItem::ResponseItem)
|
||||
.collect();
|
||||
|
||||
let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1);
|
||||
let expected: Vec<RolloutItem> = vec![
|
||||
RolloutItem::ResponseItem(items[0].clone()),
|
||||
RolloutItem::ResponseItem(items[1].clone()),
|
||||
RolloutItem::ResponseItem(items[2].clone()),
|
||||
];
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&truncated).unwrap(),
|
||||
serde_json::to_value(&expected).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -595,7 +595,8 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
|||
| EventMsg::ReasoningRawContentDelta(_)
|
||||
| EventMsg::SkillsUpdateAvailable
|
||||
| EventMsg::UndoCompleted(_)
|
||||
| EventMsg::UndoStarted(_) => {}
|
||||
| EventMsg::UndoStarted(_)
|
||||
| EventMsg::ThreadRolledBack(_) => {}
|
||||
}
|
||||
CodexStatus::Running
|
||||
}
|
||||
|
|
|
|||
|
|
@ -311,6 +311,7 @@ async fn run_codex_tool_session_inner(
|
|||
| EventMsg::UndoCompleted(_)
|
||||
| EventMsg::ExitedReviewMode(_)
|
||||
| EventMsg::ContextCompacted(_)
|
||||
| EventMsg::ThreadRolledBack(_)
|
||||
| EventMsg::DeprecationNotice(_) => {
|
||||
// For now, we do not do anything extra for these
|
||||
// events. Note that
|
||||
|
|
|
|||
|
|
@ -210,6 +210,12 @@ pub enum Op {
|
|||
/// Request Codex to undo a turn (turn are stacked so it is the same effect as CMD + Z).
|
||||
Undo,
|
||||
|
||||
/// Request Codex to drop the last N user turns from in-memory context.
|
||||
///
|
||||
/// This does not attempt to revert local filesystem changes. Clients are
|
||||
/// responsible for undoing any edits on disk.
|
||||
ThreadRollback { num_turns: u32 },
|
||||
|
||||
/// Request a code review from the agent.
|
||||
Review { review_request: ReviewRequest },
|
||||
|
||||
|
|
@ -541,6 +547,9 @@ pub enum EventMsg {
|
|||
/// Conversation history was compacted (either automatically or manually).
|
||||
ContextCompacted(ContextCompactedEvent),
|
||||
|
||||
/// Conversation history was rolled back by dropping the last N user turns.
|
||||
ThreadRolledBack(ThreadRolledBackEvent),
|
||||
|
||||
/// Agent has started a task
|
||||
TaskStarted(TaskStartedEvent),
|
||||
|
||||
|
|
@ -718,6 +727,7 @@ pub enum CodexErrorInfo {
|
|||
ResponseTooManyFailedAttempts {
|
||||
http_status_code: Option<u16>,
|
||||
},
|
||||
ThreadRollbackFailed,
|
||||
Other,
|
||||
}
|
||||
|
||||
|
|
@ -1618,6 +1628,12 @@ pub struct UndoCompletedEvent {
|
|||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)]
|
||||
pub struct ThreadRolledBackEvent {
|
||||
/// Number of user turns that were removed from context.
|
||||
pub num_turns: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)]
|
||||
pub struct StreamErrorEvent {
|
||||
pub message: String,
|
||||
|
|
|
|||
|
|
@ -2114,6 +2114,7 @@ impl ChatWidget {
|
|||
}
|
||||
EventMsg::ExitedReviewMode(review) => self.on_exited_review_mode(review),
|
||||
EventMsg::ContextCompacted(_) => self.on_agent_message("Context compacted".to_owned()),
|
||||
EventMsg::ThreadRolledBack(_) => {}
|
||||
EventMsg::RawResponseItem(_)
|
||||
| EventMsg::ItemStarted(_)
|
||||
| EventMsg::ItemCompleted(_)
|
||||
|
|
|
|||
|
|
@ -1921,6 +1921,7 @@ impl ChatWidget {
|
|||
EventMsg::ExitedReviewMode(review) => self.on_exited_review_mode(review),
|
||||
EventMsg::ContextCompacted(_) => self.on_agent_message("Context compacted".to_owned()),
|
||||
EventMsg::RawResponseItem(_)
|
||||
| EventMsg::ThreadRolledBack(_)
|
||||
| EventMsg::ItemStarted(_)
|
||||
| EventMsg::ItemCompleted(_)
|
||||
| EventMsg::AgentMessageContentDelta(_)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue