diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index e3369f822..c6e11b137 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -815,7 +815,7 @@ impl Session { match conversation_history { InitialHistory::New => { // Build and record initial items (user instructions + environment context) - let items = self.build_initial_context(&turn_context); + let items = self.build_initial_context(&turn_context).await; self.record_conversation_items(&turn_context, &items).await; // Ensure initial items are visible to immediate readers (e.g., tests, forks). self.flush_rollout().await; @@ -853,8 +853,9 @@ impl Session { } // Always add response items to conversation history - let reconstructed_history = - self.reconstruct_history_from_rollout(&turn_context, &rollout_items); + let reconstructed_history = self + .reconstruct_history_from_rollout(&turn_context, &rollout_items) + .await; if !reconstructed_history.is_empty() { self.record_into_history(&reconstructed_history, &turn_context) .await; @@ -873,7 +874,7 @@ impl Session { } // Append the current session's initial context after the reconstructed history. - let initial_context = self.build_initial_context(&turn_context); + let initial_context = self.build_initial_context(&turn_context).await; self.record_conversation_items(&turn_context, &initial_context) .await; // Flush after seeding history and any persisted rollout copy. @@ -1061,6 +1062,50 @@ impl Session { ) } + fn build_collaboration_mode_update_item( + &self, + previous_collaboration_mode: &CollaborationMode, + next_collaboration_mode: Option<&CollaborationMode>, + ) -> Option { + if let Some(next_mode) = next_collaboration_mode { + if previous_collaboration_mode == next_mode { + return None; + } + // If the next mode has empty developer instructions, this returns None and we emit no + // update, so prior collaboration instructions remain in the prompt history. 
+ Some(DeveloperInstructions::from_collaboration_mode(next_mode)?.into()) + } else { + None + } + } + + fn build_settings_update_items( + &self, + previous_context: Option<&Arc>, + current_context: &TurnContext, + previous_collaboration_mode: &CollaborationMode, + next_collaboration_mode: Option<&CollaborationMode>, + ) -> Vec { + let mut update_items = Vec::new(); + if let Some(env_item) = + self.build_environment_update_item(previous_context, current_context) + { + update_items.push(env_item); + } + if let Some(permissions_item) = + self.build_permissions_update_item(previous_context, current_context) + { + update_items.push(permissions_item); + } + if let Some(collaboration_mode_item) = self.build_collaboration_mode_update_item( + previous_collaboration_mode, + next_collaboration_mode, + ) { + update_items.push(collaboration_mode_item); + } + update_items + } + /// Persist the event to rollout and send it to clients. pub(crate) async fn send_event(&self, turn_context: &TurnContext, msg: EventMsg) { let legacy_source = msg.clone(); @@ -1299,7 +1344,7 @@ impl Session { self.send_raw_response_items(turn_context, items).await; } - fn reconstruct_history_from_rollout( + async fn reconstruct_history_from_rollout( &self, turn_context: &TurnContext, rollout_items: &[RolloutItem], @@ -1319,7 +1364,7 @@ impl Session { } else { let user_messages = collect_user_messages(history.raw_items()); let rebuilt = compact::build_compacted_history( - self.build_initial_context(turn_context), + self.build_initial_context(turn_context).await, &user_messages, &compacted.message, ); @@ -1389,7 +1434,10 @@ impl Session { } } - pub(crate) fn build_initial_context(&self, turn_context: &TurnContext) -> Vec { + pub(crate) async fn build_initial_context( + &self, + turn_context: &TurnContext, + ) -> Vec { let mut items = Vec::::with_capacity(4); let shell = self.user_shell(); items.push( @@ -1403,6 +1451,16 @@ impl Session { if let Some(developer_instructions) = 
turn_context.developer_instructions.as_deref() { items.push(DeveloperInstructions::new(developer_instructions.to_string()).into()); } + // Add developer instructions from collaboration_mode if they exist and are non-empty + let collaboration_mode = { + let state = self.state.lock().await; + state.session_configuration.collaboration_mode.clone() + }; + if let Some(collab_instructions) = + DeveloperInstructions::from_collaboration_mode(&collaboration_mode) + { + items.push(collab_instructions.into()); + } if let Some(user_instructions) = turn_context.user_instructions.as_deref() { items.push( UserInstructions { @@ -1984,6 +2042,18 @@ mod handlers { sub_id: String, updates: SessionSettingsUpdate, ) { + let previous_context = sess + .new_default_turn_with_sub_id(sess.next_internal_sub_id()) + .await; + let previous_collaboration_mode = sess + .state + .lock() + .await + .session_configuration + .collaboration_mode + .clone(); + let next_collaboration_mode = updates.collaboration_mode.clone(); + if let Err(err) = sess.update_settings(updates).await { sess.send_event_raw(Event { id: sub_id, @@ -1993,6 +2063,19 @@ mod handlers { }), }) .await; + return; + } + + let current_context = sess.new_default_turn_with_sub_id(sub_id).await; + let update_items = sess.build_settings_update_items( + Some(&previous_context), + &current_context, + &previous_collaboration_mode, + next_collaboration_mode.as_ref(), + ); + if !update_items.is_empty() { + sess.record_conversation_items(&current_context, &update_items) + .await; } } @@ -2046,6 +2129,14 @@ mod handlers { _ => unreachable!(), }; + let previous_collaboration_mode = sess + .state + .lock() + .await + .session_configuration + .collaboration_mode + .clone(); + let next_collaboration_mode = updates.collaboration_mode.clone(); let Ok(current_context) = sess.new_turn_with_sub_id(sub_id, updates).await else { // new_turn_with_sub_id already emits the error event.
return; @@ -2057,17 +2148,12 @@ mod handlers { // Attempt to inject input into current task if let Err(items) = sess.inject_input(items).await { - let mut update_items = Vec::new(); - if let Some(env_item) = - sess.build_environment_update_item(previous_context.as_ref(), &current_context) - { - update_items.push(env_item); - } - if let Some(permissions_item) = - sess.build_permissions_update_item(previous_context.as_ref(), &current_context) - { - update_items.push(permissions_item); - } + let update_items = sess.build_settings_update_items( + previous_context.as_ref(), + &current_context, + &previous_collaboration_mode, + next_collaboration_mode.as_ref(), + ); if !update_items.is_empty() { sess.record_conversation_items(&current_context, &update_items) .await; @@ -3178,9 +3264,11 @@ mod tests { #[tokio::test] async fn reconstruct_history_matches_live_compactions() { let (session, turn_context) = make_session_and_context().await; - let (rollout_items, expected) = sample_rollout(&session, &turn_context); + let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; - let reconstructed = session.reconstruct_history_from_rollout(&turn_context, &rollout_items); + let reconstructed = session + .reconstruct_history_from_rollout(&turn_context, &rollout_items) + .await; assert_eq!(expected, reconstructed); } @@ -3188,7 +3276,7 @@ mod tests { #[tokio::test] async fn record_initial_history_reconstructs_resumed_transcript() { let (session, turn_context) = make_session_and_context().await; - let (rollout_items, mut expected) = sample_rollout(&session, &turn_context); + let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; session .record_initial_history(InitialHistory::Resumed(ResumedHistory { @@ -3198,7 +3286,7 @@ mod tests { })) .await; - expected.extend(session.build_initial_context(&turn_context)); + expected.extend(session.build_initial_context(&turn_context).await); let history = session.state.lock().await.clone_history(); assert_eq!(expected,
history.raw_items()); } @@ -3206,7 +3294,7 @@ mod tests { #[tokio::test] async fn record_initial_history_seeds_token_info_from_rollout() { let (session, turn_context) = make_session_and_context().await; - let (mut rollout_items, _expected) = sample_rollout(&session, &turn_context); + let (mut rollout_items, _expected) = sample_rollout(&session, &turn_context).await; let info1 = TokenUsageInfo { total_token_usage: TokenUsage { @@ -3283,13 +3371,13 @@ mod tests { #[tokio::test] async fn record_initial_history_reconstructs_forked_transcript() { let (session, turn_context) = make_session_and_context().await; - let (rollout_items, mut expected) = sample_rollout(&session, &turn_context); + let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; session .record_initial_history(InitialHistory::Forked(rollout_items)) .await; - expected.extend(session.build_initial_context(&turn_context)); + expected.extend(session.build_initial_context(&turn_context).await); let history = session.state.lock().await.clone_history(); assert_eq!(expected, history.raw_items()); } @@ -3298,7 +3386,7 @@ mod tests { async fn thread_rollback_drops_last_turn_from_history() { let (sess, tc, rx) = make_session_and_context_with_rx().await; - let initial_context = sess.build_initial_context(tc.as_ref()); + let initial_context = sess.build_initial_context(tc.as_ref()).await; sess.record_into_history(&initial_context, tc.as_ref()) .await; @@ -3355,7 +3443,7 @@ mod tests { async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() { let (sess, tc, rx) = make_session_and_context_with_rx().await; - let initial_context = sess.build_initial_context(tc.as_ref()); + let initial_context = sess.build_initial_context(tc.as_ref()).await; sess.record_into_history(&initial_context, tc.as_ref()) .await; @@ -3381,7 +3469,7 @@ mod tests { async fn thread_rollback_fails_when_turn_in_progress() { let (sess, tc, rx) = make_session_and_context_with_rx().await; - let 
initial_context = sess.build_initial_context(tc.as_ref()); + let initial_context = sess.build_initial_context(tc.as_ref()).await; sess.record_into_history(&initial_context, tc.as_ref()) .await; @@ -3402,7 +3490,7 @@ mod tests { async fn thread_rollback_fails_when_num_turns_is_zero() { let (sess, tc, rx) = make_session_and_context_with_rx().await; - let initial_context = sess.build_initial_context(tc.as_ref()); + let initial_context = sess.build_initial_context(tc.as_ref()).await; sess.record_into_history(&initial_context, tc.as_ref()) .await; @@ -4188,14 +4276,14 @@ mod tests { } } - fn sample_rollout( + async fn sample_rollout( session: &Session, turn_context: &TurnContext, ) -> (Vec, Vec) { let mut rollout_items = Vec::new(); let mut live_history = ContextManager::new(); - let initial_context = session.build_initial_context(turn_context); + let initial_context = session.build_initial_context(turn_context).await; for item in &initial_context { rollout_items.push(RolloutItem::ResponseItem(item.clone())); } @@ -4225,7 +4313,7 @@ mod tests { let snapshot1 = live_history.clone().for_prompt(); let user_messages1 = collect_user_messages(&snapshot1); let rebuilt1 = compact::build_compacted_history( - session.build_initial_context(turn_context), + session.build_initial_context(turn_context).await, &user_messages1, summary1, ); @@ -4259,7 +4347,7 @@ mod tests { let snapshot2 = live_history.clone().for_prompt(); let user_messages2 = collect_user_messages(&snapshot2); let rebuilt2 = compact::build_compacted_history( - session.build_initial_context(turn_context), + session.build_initial_context(turn_context).await, &user_messages2, summary2, ); @@ -4277,7 +4365,7 @@ mod tests { }], }; live_history.record_items(std::iter::once(&user3), turn_context.truncation_policy); - rollout_items.push(RolloutItem::ResponseItem(user3.clone())); + rollout_items.push(RolloutItem::ResponseItem(user3)); let assistant3 = ResponseItem::Message { id: None, @@ -4287,7 +4375,7 @@ mod tests { }], }; 
live_history.record_items(std::iter::once(&assistant3), turn_context.truncation_policy); - rollout_items.push(RolloutItem::ResponseItem(assistant3.clone())); + rollout_items.push(RolloutItem::ResponseItem(assistant3)); (rollout_items, live_history.for_prompt()) } diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index 4dc56f10d..250b91415 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -167,7 +167,7 @@ async fn run_compact_task_inner( let summary_text = format!("{SUMMARY_PREFIX}\n{summary_suffix}"); let user_messages = collect_user_messages(history_items); - let initial_context = sess.build_initial_context(turn_context.as_ref()); + let initial_context = sess.build_initial_context(turn_context.as_ref()).await; let mut new_history = build_compacted_history(initial_context, &user_messages, &summary_text); let ghost_snapshots: Vec = history_items .iter() diff --git a/codex-rs/core/src/rollout/truncation.rs b/codex-rs/core/src/rollout/truncation.rs index 1f70be46f..0f72cfc50 100644 --- a/codex-rs/core/src/rollout/truncation.rs +++ b/codex-rs/core/src/rollout/truncation.rs @@ -189,7 +189,7 @@ mod tests { #[tokio::test] async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() { let (session, turn_context) = make_session_and_context().await; - let mut items = session.build_initial_context(&turn_context); + let mut items = session.build_initial_context(&turn_context).await; items.push(user_msg("feature request")); items.push(assistant_msg("ack")); items.push(user_msg("second question")); diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 8d02828d5..f533e608f 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -467,7 +467,7 @@ mod tests { #[tokio::test] async fn ignores_session_prefix_messages_when_truncating() { let (session, turn_context) = make_session_and_context().await; - let mut items = 
session.build_initial_context(&turn_context); + let mut items = session.build_initial_context(&turn_context).await; items.push(user_msg("feature request")); items.push(assistant_msg("ack")); items.push(user_msg("second question")); diff --git a/codex-rs/core/tests/suite/collaboration_instructions.rs b/codex-rs/core/tests/suite/collaboration_instructions.rs new file mode 100644 index 000000000..c0df7e5ec --- /dev/null +++ b/codex-rs/core/tests/suite/collaboration_instructions.rs @@ -0,0 +1,500 @@ +use anyhow::Result; +use codex_core::protocol::COLLABORATION_MODE_CLOSE_TAG; +use codex_core::protocol::COLLABORATION_MODE_OPEN_TAG; +use codex_core::protocol::EventMsg; +use codex_core::protocol::Op; +use codex_protocol::config_types::CollaborationMode; +use codex_protocol::config_types::Settings; +use codex_protocol::user_input::UserInput; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_sse_once; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use pretty_assertions::assert_eq; +use serde_json::Value; + +fn sse_completed(id: &str) -> String { + sse(vec![ev_response_created(id), ev_completed(id)]) +} + +fn collab_mode_with_instructions(instructions: Option<&str>) -> CollaborationMode { + CollaborationMode::Custom(Settings { + model: "gpt-5.1".to_string(), + reasoning_effort: None, + developer_instructions: instructions.map(str::to_string), + }) +} + +fn developer_texts(input: &[Value]) -> Vec { + input + .iter() + .filter_map(|item| { + let role = item.get("role")?.as_str()?; + if role != "developer" { + return None; + } + let text = item + .get("content")? + .as_array()? + .first()? + .get("text")? 
+ .as_str()?; + Some(text.to_string()) + }) + .collect() +} + +fn collab_xml(text: &str) -> String { + format!("{COLLABORATION_MODE_OPEN_TAG}{text}{COLLABORATION_MODE_CLOSE_TAG}") +} + +fn count_exact(texts: &[String], target: &str) -> usize { + texts.iter().filter(|text| text.as_str() == target).count() +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn no_collaboration_instructions_by_default() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let req = mount_sse_once(&server, sse_completed("resp-1")).await; + + let test = test_codex().build(&server).await?; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req.single_request().input(); + let dev_texts = developer_texts(&input); + assert_eq!(dev_texts.len(), 1); + assert!(dev_texts[0].contains("`approval_policy`")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn user_input_includes_collaboration_instructions_after_override() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let req = mount_sse_once(&server, sse_completed("resp-1")).await; + + let test = test_codex().build(&server).await?; + + let collab_text = "collab instructions"; + let collaboration_mode = collab_mode_with_instructions(Some(collab_text)); + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collaboration_mode), + }) + .await?; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, 
|ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req.single_request().input(); + let dev_texts = developer_texts(&input); + let collab_text = collab_xml(collab_text); + assert_eq!(count_exact(&dev_texts, &collab_text), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn collaboration_instructions_added_on_user_turn() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let req = mount_sse_once(&server, sse_completed("resp-1")).await; + + let test = test_codex().build(&server).await?; + let collab_text = "turn instructions"; + let collaboration_mode = collab_mode_with_instructions(Some(collab_text)); + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + cwd: test.config.cwd.clone(), + approval_policy: test.config.approval_policy.value(), + sandbox_policy: test.config.sandbox_policy.get().clone(), + model: test.session_configured.model.clone(), + effort: None, + summary: test.config.model_reasoning_summary, + collaboration_mode: Some(collaboration_mode), + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req.single_request().input(); + let dev_texts = developer_texts(&input); + let collab_text = collab_xml(collab_text); + assert_eq!(count_exact(&dev_texts, &collab_text), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn override_then_user_turn_uses_updated_collaboration_instructions() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let req = mount_sse_once(&server, sse_completed("resp-1")).await; + + let test = test_codex().build(&server).await?; + let collab_text = "override instructions"; + let collaboration_mode = collab_mode_with_instructions(Some(collab_text)); + + test.codex + 
.submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collaboration_mode), + }) + .await?; + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + cwd: test.config.cwd.clone(), + approval_policy: test.config.approval_policy.value(), + sandbox_policy: test.config.sandbox_policy.get().clone(), + model: test.session_configured.model.clone(), + effort: None, + summary: test.config.model_reasoning_summary, + collaboration_mode: None, + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req.single_request().input(); + let dev_texts = developer_texts(&input); + let collab_text = collab_xml(collab_text); + assert_eq!(count_exact(&dev_texts, &collab_text), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn user_turn_overrides_collaboration_instructions_after_override() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let req = mount_sse_once(&server, sse_completed("resp-1")).await; + + let test = test_codex().build(&server).await?; + let base_text = "base instructions"; + let base_mode = collab_mode_with_instructions(Some(base_text)); + let turn_text = "turn override"; + let turn_mode = collab_mode_with_instructions(Some(turn_text)); + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(base_mode), + }) + .await?; + + test.codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + cwd: test.config.cwd.clone(), + approval_policy: test.config.approval_policy.value(), + sandbox_policy: 
test.config.sandbox_policy.get().clone(), + model: test.session_configured.model.clone(), + effort: None, + summary: test.config.model_reasoning_summary, + collaboration_mode: Some(turn_mode), + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req.single_request().input(); + let dev_texts = developer_texts(&input); + let base_text = collab_xml(base_text); + let turn_text = collab_xml(turn_text); + assert_eq!(count_exact(&dev_texts, &base_text), 1); + assert_eq!(count_exact(&dev_texts, &turn_text), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn collaboration_mode_update_emits_new_instruction_message() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let _req1 = mount_sse_once(&server, sse_completed("resp-1")).await; + let req2 = mount_sse_once(&server, sse_completed("resp-2")).await; + + let test = test_codex().build(&server).await?; + let first_text = "first instructions"; + let second_text = "second instructions"; + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collab_mode_with_instructions(Some(first_text))), + }) + .await?; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello 1".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collab_mode_with_instructions(Some(second_text))), + }) + .await?; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello 2".into(), + 
text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req2.single_request().input(); + let dev_texts = developer_texts(&input); + let first_text = collab_xml(first_text); + let second_text = collab_xml(second_text); + assert_eq!(count_exact(&dev_texts, &first_text), 1); + assert_eq!(count_exact(&dev_texts, &second_text), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn collaboration_mode_update_noop_does_not_append() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let _req1 = mount_sse_once(&server, sse_completed("resp-1")).await; + let req2 = mount_sse_once(&server, sse_completed("resp-2")).await; + + let test = test_codex().build(&server).await?; + let collab_text = "same instructions"; + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collab_mode_with_instructions(Some(collab_text))), + }) + .await?; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello 1".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collab_mode_with_instructions(Some(collab_text))), + }) + .await?; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello 2".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = 
req2.single_request().input(); + let dev_texts = developer_texts(&input); + let collab_text = collab_xml(collab_text); + assert_eq!(count_exact(&dev_texts, &collab_text), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn resume_replays_collaboration_instructions() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let _req1 = mount_sse_once(&server, sse_completed("resp-1")).await; + let req2 = mount_sse_once(&server, sse_completed("resp-2")).await; + + let mut builder = test_codex(); + let initial = builder.build(&server).await?; + let rollout_path = initial.session_configured.rollout_path.clone(); + let home = initial.home.clone(); + + let collab_text = "resume instructions"; + initial + .codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collab_mode_with_instructions(Some(collab_text))), + }) + .await?; + + initial + .codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&initial.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let resumed = builder.resume(&server, home, rollout_path).await?; + resumed + .codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "after resume".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&resumed.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req2.single_request().input(); + let dev_texts = developer_texts(&input); + let collab_text = collab_xml(collab_text); + assert_eq!(count_exact(&dev_texts, &collab_text), 1); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn empty_collaboration_instructions_are_ignored() -> Result<()> { + 
skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let req = mount_sse_once(&server, sse_completed("resp-1")).await; + + let test = test_codex().build(&server).await?; + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collab_mode_with_instructions(Some(""))), + }) + .await?; + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "hello".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + let input = req.single_request().input(); + let dev_texts = developer_texts(&input); + assert_eq!(dev_texts.len(), 1); + let collab_text = collab_xml(""); + assert_eq!(count_exact(&dev_texts, &collab_text), 0); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 0d22e8c39..b8066db38 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -24,6 +24,7 @@ mod cli_stream; mod client; mod client_websockets; mod codex_delegate; +mod collaboration_instructions; mod compact; mod compact_remote; mod compact_resume_fork; diff --git a/codex-rs/core/tests/suite/override_updates.rs b/codex-rs/core/tests/suite/override_updates.rs new file mode 100644 index 000000000..ddacd97ab --- /dev/null +++ b/codex-rs/core/tests/suite/override_updates.rs @@ -0,0 +1,216 @@ +use anyhow::Result; +use codex_core::config::Constrained; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::COLLABORATION_MODE_CLOSE_TAG; +use codex_core::protocol::COLLABORATION_MODE_OPEN_TAG; +use codex_core::protocol::EventMsg; +use codex_core::protocol::Op; +use codex_core::protocol::RolloutItem; +use codex_core::protocol::RolloutLine; +use codex_core::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG; +use 
codex_protocol::config_types::CollaborationMode; +use codex_protocol::config_types::Settings; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use pretty_assertions::assert_eq; +use std::collections::HashSet; +use std::path::Path; +use std::time::Duration; +use tempfile::TempDir; + +fn collab_mode_with_instructions(instructions: Option<&str>) -> CollaborationMode { + CollaborationMode::Custom(Settings { + model: "gpt-5.1".to_string(), + reasoning_effort: None, + developer_instructions: instructions.map(str::to_string), + }) +} + +fn collab_xml(text: &str) -> String { + format!("{COLLABORATION_MODE_OPEN_TAG}{text}{COLLABORATION_MODE_CLOSE_TAG}") +} + +async fn read_rollout_text(path: &Path) -> anyhow::Result { + for _ in 0..50 { + if path.exists() + && let Ok(text) = std::fs::read_to_string(path) + && !text.trim().is_empty() + { + return Ok(text); + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + Ok(std::fs::read_to_string(path)?) +} + +fn rollout_developer_texts(text: &str) -> Vec { + let mut texts = Vec::new(); + for line in text.lines() { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let rollout: RolloutLine = match serde_json::from_str(trimmed) { + Ok(rollout) => rollout, + Err(_) => continue, + }; + if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. 
}) = + rollout.item + && role == "developer" + { + for item in content { + if let ContentItem::InputText { text } = item { + texts.push(text); + } + } + } + } + texts +} + +fn rollout_environment_texts(text: &str) -> Vec { + let mut texts = Vec::new(); + for line in text.lines() { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let rollout: RolloutLine = match serde_json::from_str(trimmed) { + Ok(rollout) => rollout, + Err(_) => continue, + }; + if let RolloutItem::ResponseItem(ResponseItem::Message { role, content, .. }) = + rollout.item + && role == "user" + { + for item in content { + if let ContentItem::InputText { text } = item + && text.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG) + { + texts.push(text); + } + } + } + } + texts +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn override_turn_context_records_permissions_update() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.approval_policy = Constrained::allow_any(AskForApproval::OnRequest); + }); + let test = builder.build(&server).await?; + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: Some(AskForApproval::Never), + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: None, + }) + .await?; + + test.codex.submit(Op::Shutdown).await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await; + + let rollout_path = test.codex.rollout_path(); + let rollout_text = read_rollout_text(&rollout_path).await?; + let developer_texts = rollout_developer_texts(&rollout_text); + let approval_texts: Vec<&String> = developer_texts + .iter() + .filter(|text| text.contains("`approval_policy`")) + .collect(); + assert!( + approval_texts + .iter() + .any(|text| text.contains("`approval_policy` is `never`")), + "expected updated approval policy instructions in rollout" + ); + 
let unique: HashSet<&String> = approval_texts.iter().copied().collect(); + assert_eq!(unique.len(), 2); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn override_turn_context_records_environment_update() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let test = test_codex().build(&server).await?; + let new_cwd = TempDir::new()?; + + test.codex + .submit(Op::OverrideTurnContext { + cwd: Some(new_cwd.path().to_path_buf()), + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: None, + }) + .await?; + + test.codex.submit(Op::Shutdown).await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await; + + let rollout_path = test.codex.rollout_path(); + let rollout_text = read_rollout_text(&rollout_path).await?; + let env_texts = rollout_environment_texts(&rollout_text); + let new_cwd_text = new_cwd.path().display().to_string(); + assert!( + env_texts.iter().any(|text| text.contains(&new_cwd_text)), + "expected environment update with new cwd in rollout" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn override_turn_context_records_collaboration_update() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let test = test_codex().build(&server).await?; + let collab_text = "override collaboration instructions"; + let collaboration_mode = collab_mode_with_instructions(Some(collab_text)); + + test.codex + .submit(Op::OverrideTurnContext { + cwd: None, + approval_policy: None, + sandbox_policy: None, + model: None, + effort: None, + summary: None, + collaboration_mode: Some(collaboration_mode), + }) + .await?; + + test.codex.submit(Op::Shutdown).await?; + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await; + + let rollout_path = test.codex.rollout_path(); + let rollout_text = 
read_rollout_text(&rollout_path).await?; + let developer_texts = rollout_developer_texts(&rollout_text); + let collab_text = collab_xml(collab_text); + let collab_count = developer_texts + .iter() + .filter(|text| text.as_str() == collab_text.as_str()) + .count(); + assert_eq!(collab_count, 1); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/permissions_messages.rs b/codex-rs/core/tests/suite/permissions_messages.rs index 3e54fa29a..af61e54a5 100644 --- a/codex-rs/core/tests/suite/permissions_messages.rs +++ b/codex-rs/core/tests/suite/permissions_messages.rs @@ -132,7 +132,7 @@ async fn permissions_message_added_on_override_change() -> Result<()> { let permissions_2 = permissions_texts(input2); assert_eq!(permissions_1.len(), 1); - assert_eq!(permissions_2.len(), 2); + assert_eq!(permissions_2.len(), 3); let unique = permissions_2.into_iter().collect::>(); assert_eq!(unique.len(), 2); @@ -257,7 +257,7 @@ async fn resume_replays_permissions_messages() -> Result<()> { let body3 = req3.single_request().body_json(); let input = body3["input"].as_array().expect("input array"); let permissions = permissions_texts(input); - assert_eq!(permissions.len(), 3); + assert_eq!(permissions.len(), 4); let unique = permissions.into_iter().collect::>(); assert_eq!(unique.len(), 2); @@ -321,7 +321,7 @@ async fn resume_and_fork_append_permissions_messages() -> Result<()> { let body2 = req2.single_request().body_json(); let input2 = body2["input"].as_array().expect("input array"); let permissions_base = permissions_texts(input2); - assert_eq!(permissions_base.len(), 2); + assert_eq!(permissions_base.len(), 3); builder = builder.with_config(|config| { config.approval_policy = Constrained::allow_any(AskForApproval::UnlessTrusted); diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs index 093932512..4e87a4364 100644 --- a/codex-rs/core/tests/suite/prompt_caching.rs +++ b/codex-rs/core/tests/suite/prompt_caching.rs @@ -379,15 +379,18 
@@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() -> an "content": [ { "type": "input_text", "text": "hello 2" } ] }); let expected_permissions_msg = body1["input"][0].clone(); - // After overriding the turn context, emit a new permissions message. let body1_input = body1["input"].as_array().expect("input array"); + // After overriding the turn context, emit two updated permissions messages. let expected_permissions_msg_2 = body2["input"][body1_input.len()].clone(); + let expected_permissions_msg_3 = body2["input"][body1_input.len() + 1].clone(); assert_ne!( expected_permissions_msg_2, expected_permissions_msg, "expected updated permissions message after override" ); - let mut expected_body2 = body1["input"].as_array().expect("input array").to_vec(); + assert_eq!(expected_permissions_msg_2, expected_permissions_msg_3); + let mut expected_body2 = body1_input.to_vec(); expected_body2.push(expected_permissions_msg_2); + expected_body2.push(expected_permissions_msg_3); expected_body2.push(expected_user_message_2); assert_eq!(body2["input"], serde_json::Value::Array(expected_body2)); diff --git a/codex-rs/protocol/src/models.rs b/codex-rs/protocol/src/models.rs index d2c9ddd56..4d54ed08f 100644 --- a/codex-rs/protocol/src/models.rs +++ b/codex-rs/protocol/src/models.rs @@ -10,8 +10,11 @@ use serde::Serialize; use serde::ser::Serializer; use ts_rs::TS; +use crate::config_types::CollaborationMode; use crate::config_types::SandboxMode; use crate::protocol::AskForApproval; +use crate::protocol::COLLABORATION_MODE_CLOSE_TAG; +use crate::protocol::COLLABORATION_MODE_OPEN_TAG; use crate::protocol::NetworkAccess; use crate::protocol::SandboxPolicy; use crate::protocol::WritableRoot; @@ -230,6 +233,25 @@ impl DeveloperInstructions { ) } + /// Returns developer instructions from a collaboration mode if they exist and are non-empty. 
+ pub fn from_collaboration_mode(collaboration_mode: &CollaborationMode) -> Option { + let settings = match collaboration_mode { + CollaborationMode::Plan(settings) + | CollaborationMode::Collaborate(settings) + | CollaborationMode::Execute(settings) + | CollaborationMode::Custom(settings) => settings, + }; + settings + .developer_instructions + .as_ref() + .filter(|instructions| !instructions.is_empty()) + .map(|instructions| { + DeveloperInstructions::new(format!( + "{COLLABORATION_MODE_OPEN_TAG}{instructions}{COLLABORATION_MODE_CLOSE_TAG}" + )) + }) + } + fn from_permissions_with_network( sandbox_mode: SandboxMode, network_access: NetworkAccess, diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index d5ead442d..1142f182c 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -51,6 +51,8 @@ pub const USER_INSTRUCTIONS_OPEN_TAG: &str = ""; pub const USER_INSTRUCTIONS_CLOSE_TAG: &str = ""; pub const ENVIRONMENT_CONTEXT_OPEN_TAG: &str = ""; pub const ENVIRONMENT_CONTEXT_CLOSE_TAG: &str = ""; +pub const COLLABORATION_MODE_OPEN_TAG: &str = ""; +pub const COLLABORATION_MODE_CLOSE_TAG: &str = ""; pub const USER_MESSAGE_BEGIN: &str = "## My request for Codex:"; /// Submission Queue Entry - requests from user