diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 41cc7771b..bd32dfe02 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -6862,3958 +6862,5 @@ pub(crate) use tests::make_session_and_context_with_rx; pub(crate) use tests::make_session_configuration_for_tests; #[cfg(test)] -mod tests { - use super::*; - use crate::CodexAuth; - use crate::config::ConfigBuilder; - use crate::config::test_config; - use crate::config_loader::ConfigLayerStack; - use crate::config_loader::ConfigLayerStackOrdering; - use crate::config_loader::NetworkConstraints; - use crate::config_loader::RequirementSource; - use crate::config_loader::Sourced; - use crate::exec::ExecToolCallOutput; - use crate::function_tool::FunctionCallError; - use crate::mcp_connection_manager::ToolInfo; - use crate::models_manager::model_info; - use crate::shell::default_user_shell; - use crate::tools::format_exec_output_str; - - use codex_protocol::ThreadId; - use codex_protocol::models::FunctionCallOutputBody; - use codex_protocol::models::FunctionCallOutputPayload; - use tracing::Span; - - use crate::protocol::CompactedItem; - use crate::protocol::CreditsSnapshot; - use crate::protocol::InitialHistory; - use crate::protocol::NetworkApprovalProtocol; - use crate::protocol::RateLimitSnapshot; - use crate::protocol::RateLimitWindow; - use crate::protocol::ResumedHistory; - use crate::protocol::TokenCountEvent; - use crate::protocol::TokenUsage; - use crate::protocol::TokenUsageInfo; - use crate::protocol::TurnCompleteEvent; - use crate::protocol::UserMessageEvent; - use crate::rollout::policy::EventPersistenceMode; - use crate::rollout::recorder::RolloutRecorder; - use crate::rollout::recorder::RolloutRecorderParams; - use crate::state::TaskKind; - use crate::tasks::SessionTask; - use crate::tasks::SessionTaskContext; - use crate::tools::ToolRouter; - use crate::tools::context::ToolInvocation; - use crate::tools::context::ToolOutput; - use crate::tools::context::ToolPayload; - use crate::tools::handlers::ShellHandler; - use crate::tools::handlers::UnifiedExecHandler; - use crate::tools::registry::ToolHandler; - use crate::tools::router::ToolCallSource; - use crate::turn_diff_tracker::TurnDiffTracker; - use codex_app_server_protocol::AppInfo; - use codex_otel::TelemetryAuthMode; - use codex_protocol::models::BaseInstructions; - use codex_protocol::models::ContentItem; - use codex_protocol::models::ResponseInputItem; - use codex_protocol::models::ResponseItem; - use codex_protocol::openai_models::ModelsResponse; - use codex_protocol::protocol::ConversationAudioParams; - use codex_protocol::protocol::RealtimeAudioFrame; - use codex_protocol::protocol::Submission; - use codex_protocol::protocol::W3cTraceContext; - use opentelemetry::trace::TraceContextExt; - use opentelemetry::trace::TraceId; - use opentelemetry::trace::TracerProvider as _; - use opentelemetry_sdk::trace::SdkTracerProvider; - use std::path::Path; - use std::time::Duration; - use tokio::time::sleep; - use tracing_opentelemetry::OpenTelemetrySpanExt; - use tracing_subscriber::prelude::*; - - use codex_protocol::mcp::CallToolResult as McpCallToolResult; - use pretty_assertions::assert_eq; - use rmcp::model::JsonObject; - use rmcp::model::Tool; - use serde::Deserialize; - use serde_json::json; - use std::path::PathBuf; - use std::sync::Arc; - use std::sync::Once; - use std::time::Duration as StdDuration; - - struct InstructionsTestCase { - slug: &'static str, - expects_apply_patch_instructions: bool, - } - - fn user_message(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - fn assistant_message(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - fn skill_message(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - fn developer_input_texts(items: &[ResponseItem]) -> Vec<&str> { - items - .iter() - .filter_map(|item| match item { - ResponseItem::Message { role, content, .. } if role == "developer" => { - Some(content.as_slice()) - } - _ => None, - }) - .flat_map(|content| content.iter()) - .filter_map(|item| match item { - ContentItem::InputText { text } => Some(text.as_str()), - _ => None, - }) - .collect() - } - - fn make_connector(id: &str, name: &str) -> AppInfo { - AppInfo { - id: id.to_string(), - name: name.to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: None, - is_accessible: true, - is_enabled: true, - plugin_display_names: Vec::new(), - } - } - - #[test] - fn assistant_message_stream_parsers_can_be_seeded_from_output_item_added_text() { - let mut parsers = AssistantMessageStreamParsers::new(false); - let item_id = "msg-1"; - - let seeded = parsers.seed_item_text(item_id, "hello doc"); - let parsed = parsers.parse_delta(item_id, "1 world"); - let tail = parsers.finish_item(item_id); - - assert_eq!(seeded.visible_text, "hello "); - assert_eq!(seeded.citations, Vec::::new()); - assert_eq!(parsed.visible_text, " world"); - assert_eq!(parsed.citations, vec!["doc1".to_string()]); - assert_eq!(tail.visible_text, ""); - assert_eq!(tail.citations, Vec::::new()); - } - - #[test] - fn assistant_message_stream_parsers_seed_buffered_prefix_stays_out_of_finish_tail() { - let mut parsers = AssistantMessageStreamParsers::new(false); - let item_id = "msg-1"; - - let seeded = parsers.seed_item_text(item_id, "hello doc world"); - let tail = parsers.finish_item(item_id); - - assert_eq!(seeded.visible_text, "hello "); - assert_eq!(seeded.citations, Vec::::new()); - assert_eq!(parsed.visible_text, " world"); - assert_eq!(parsed.citations, vec!["doc".to_string()]); - assert_eq!(tail.visible_text, ""); - assert_eq!(tail.citations, Vec::::new()); - } - - #[test] - fn assistant_message_stream_parsers_seed_plan_parser_across_added_and_delta_boundaries() { - let mut parsers = AssistantMessageStreamParsers::new(true); - let item_id = "msg-1"; - - let seeded = parsers.seed_item_text(item_id, "Intro\n\n- step\n\nOutro"); - let tail = parsers.finish_item(item_id); - - assert_eq!(seeded.visible_text, "Intro\n"); - assert_eq!( - seeded.plan_segments, - vec![ProposedPlanSegment::Normal("Intro\n".to_string())] - ); - assert_eq!(parsed.visible_text, "Outro"); - assert_eq!( - parsed.plan_segments, - vec![ - ProposedPlanSegment::ProposedPlanStart, - ProposedPlanSegment::ProposedPlanDelta("- step\n".to_string()), - ProposedPlanSegment::ProposedPlanEnd, - ProposedPlanSegment::Normal("Outro".to_string()), - ] - ); - assert_eq!(tail.visible_text, ""); - assert!(tail.plan_segments.is_empty()); - } - - fn make_mcp_tool( - server_name: &str, - tool_name: &str, - connector_id: Option<&str>, - connector_name: Option<&str>, - ) -> ToolInfo { - ToolInfo { - server_name: server_name.to_string(), - tool_name: tool_name.to_string(), - tool: Tool { - name: tool_name.to_string().into(), - title: None, - description: Some(format!("Test tool: {tool_name}").into()), - input_schema: Arc::new(JsonObject::default()), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: connector_id.map(str::to_string), - connector_name: connector_name.map(str::to_string), - plugin_display_names: Vec::new(), - } - } - - fn function_call_rollout_item(name: &str, call_id: &str) -> RolloutItem { - RolloutItem::ResponseItem(ResponseItem::FunctionCall { - id: None, - name: name.to_string(), - arguments: "{}".to_string(), - call_id: call_id.to_string(), - }) - } - - fn function_call_output_rollout_item(call_id: &str, output: &str) -> RolloutItem { - RolloutItem::ResponseItem(ResponseItem::FunctionCallOutput { - call_id: call_id.to_string(), - output: FunctionCallOutputPayload::from_text(output.to_string()), - }) - } - - #[test] - fn validated_network_policy_amendment_host_allows_normalized_match() { - let amendment = NetworkPolicyAmendment { - host: "ExAmPlE.Com.:443".to_string(), - action: NetworkPolicyRuleAction::Allow, - }; - let context = NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Https, - }; - - let host = Session::validated_network_policy_amendment_host(&amendment, &context) - .expect("normalized hosts should match"); - - assert_eq!(host, "example.com"); - } - - #[test] - fn validated_network_policy_amendment_host_rejects_mismatch() { - let amendment = NetworkPolicyAmendment { - host: "evil.example.com".to_string(), - action: NetworkPolicyRuleAction::Deny, - }; - let context = NetworkApprovalContext { - host: "api.example.com".to_string(), - protocol: NetworkApprovalProtocol::Https, - }; - - let err = Session::validated_network_policy_amendment_host(&amendment, &context) - .expect_err("mismatched hosts should be rejected"); - - let message = err.to_string(); - assert!(message.contains("does not match approved host")); - } - - #[tokio::test] - async fn get_base_instructions_no_user_content() { - let prompt_with_apply_patch_instructions = - include_str!("../prompt_with_apply_patch_instructions.md"); - let models_response: ModelsResponse = - serde_json::from_str(include_str!("../models.json")).expect("valid models.json"); - let model_info_for_slug = |slug: &str, config: &Config| { - let model = models_response - .models - .iter() - .find(|candidate| candidate.slug == slug) - .cloned() - .unwrap_or_else(|| panic!("model slug {slug} is missing from models.json")); - model_info::with_config_overrides(model, config) - }; - let test_cases = vec![ - InstructionsTestCase { - slug: "gpt-5", - expects_apply_patch_instructions: false, - }, - InstructionsTestCase { - slug: "gpt-5.1", - expects_apply_patch_instructions: false, - }, - InstructionsTestCase { - slug: "gpt-5.1-codex", - expects_apply_patch_instructions: false, - }, - InstructionsTestCase { - slug: "gpt-5.1-codex-max", - expects_apply_patch_instructions: false, - }, - ]; - - let (session, _turn_context) = make_session_and_context().await; - let config = test_config(); - - for test_case in test_cases { - let model_info = model_info_for_slug(test_case.slug, &config); - if test_case.expects_apply_patch_instructions { - assert_eq!( - model_info.base_instructions.as_str(), - prompt_with_apply_patch_instructions - ); - } - - { - let mut state = session.state.lock().await; - state.session_configuration.base_instructions = - model_info.base_instructions.clone(); - } - - let base_instructions = session.get_base_instructions().await; - assert_eq!(base_instructions.text, model_info.base_instructions); - } - } - - #[tokio::test] - async fn reload_user_config_layer_updates_effective_apps_config() { - let (session, _turn_context) = make_session_and_context().await; - let codex_home = session.codex_home().await; - std::fs::create_dir_all(&codex_home).expect("create codex home"); - let config_toml_path = codex_home.join(CONFIG_TOML_FILE); - std::fs::write( - &config_toml_path, - "[apps.calendar]\nenabled = false\ndestructive_enabled = false\n", - ) - .expect("write user config"); - - session.reload_user_config_layer().await; - - let config = session.get_config().await; - let apps_toml = config - .config_layer_stack - .effective_config() - .as_table() - .and_then(|table| table.get("apps")) - .cloned() - .expect("apps table"); - let apps = crate::config::types::AppsConfigToml::deserialize(apps_toml) - .expect("deserialize apps config"); - let app = apps - .apps - .get("calendar") - .expect("calendar app config exists"); - - assert!(!app.enabled); - assert_eq!(app.destructive_enabled, Some(false)); - } - - #[test] - fn filter_connectors_for_input_skips_duplicate_slug_mentions() { - let connectors = vec![ - make_connector("one", "Foo Bar"), - make_connector("two", "Foo-Bar"), - ]; - let input = vec![user_message("use $foo-bar")]; - let explicitly_enabled_connectors = HashSet::new(); - let skill_name_counts_lower = HashMap::new(); - - let selected = filter_connectors_for_input( - &connectors, - &input, - &explicitly_enabled_connectors, - &skill_name_counts_lower, - ); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn filter_connectors_for_input_skips_when_skill_name_conflicts() { - let connectors = vec![make_connector("one", "Todoist")]; - let input = vec![user_message("use $todoist")]; - let explicitly_enabled_connectors = HashSet::new(); - let skill_name_counts_lower = HashMap::from([("todoist".to_string(), 1)]); - - let selected = filter_connectors_for_input( - &connectors, - &input, - &explicitly_enabled_connectors, - &skill_name_counts_lower, - ); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn filter_connectors_for_input_skips_disabled_connectors() { - let mut connector = make_connector("calendar", "Calendar"); - connector.is_enabled = false; - let input = vec![user_message("use $calendar")]; - let explicitly_enabled_connectors = HashSet::new(); - let selected = filter_connectors_for_input( - &[connector], - &input, - &explicitly_enabled_connectors, - &HashMap::new(), - ); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_app_ids_from_skill_items_includes_linked_mentions() { - let connectors = vec![make_connector("calendar", "Calendar")]; - let skill_items = vec![skill_message( - "\ndemo\n/tmp/skills/demo/SKILL.md\nuse [$calendar](app://calendar)\n", - )]; - - let connector_ids = - collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); - - assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); - } - - #[test] - fn collect_explicit_app_ids_from_skill_items_resolves_unambiguous_plain_mentions() { - let connectors = vec![make_connector("calendar", "Calendar")]; - let skill_items = vec![skill_message( - "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", - )]; - - let connector_ids = - collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); - - assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); - } - - #[test] - fn collect_explicit_app_ids_from_skill_items_skips_plain_mentions_with_skill_conflicts() { - let connectors = vec![make_connector("calendar", "Calendar")]; - let skill_items = vec![skill_message( - "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", - )]; - let skill_name_counts_lower = HashMap::from([("calendar".to_string(), 1)]); - - let connector_ids = collect_explicit_app_ids_from_skill_items( - &skill_items, - &connectors, - &skill_name_counts_lower, - ); - - assert_eq!(connector_ids, HashSet::::new()); - } - - #[test] - fn non_app_mcp_tools_remain_visible_without_search_selection() { - let mcp_tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_create_event".to_string(), - make_mcp_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - Some("calendar"), - Some("Calendar"), - ), - ), - ( - "mcp__rmcp__echo".to_string(), - make_mcp_tool("rmcp", "echo", None, None), - ), - ]); - - let mut selected_mcp_tools = mcp_tools - .iter() - .filter(|(_, tool)| tool.server_name != CODEX_APPS_MCP_SERVER_NAME) - .map(|(name, tool)| (name.clone(), tool.clone())) - .collect::>(); - - let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); - let explicitly_enabled_connectors = HashSet::new(); - let connectors = filter_connectors_for_input( - &connectors, - &[user_message("run echo")], - &explicitly_enabled_connectors, - &HashMap::new(), - ); - let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); - selected_mcp_tools.extend(apps_mcp_tools); - - let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); - tool_names.sort(); - assert_eq!(tool_names, vec!["mcp__rmcp__echo".to_string()]); - } - - #[test] - fn search_tool_selection_keeps_codex_apps_tools_without_mentions() { - let selected_tool_names = vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__rmcp__echo".to_string(), - ]; - let mcp_tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_create_event".to_string(), - make_mcp_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - Some("calendar"), - Some("Calendar"), - ), - ), - ( - "mcp__rmcp__echo".to_string(), - make_mcp_tool("rmcp", "echo", None, None), - ), - ]); - - let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); - let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); - let explicitly_enabled_connectors = HashSet::new(); - let connectors = filter_connectors_for_input( - &connectors, - &[user_message("run the selected tools")], - &explicitly_enabled_connectors, - &HashMap::new(), - ); - let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); - selected_mcp_tools.extend(apps_mcp_tools); - - let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); - tool_names.sort(); - assert_eq!( - tool_names, - vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__rmcp__echo".to_string(), - ] - ); - } - - #[test] - fn apps_mentions_add_codex_apps_tools_to_search_selected_set() { - let selected_tool_names = vec!["mcp__rmcp__echo".to_string()]; - let mcp_tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_create_event".to_string(), - make_mcp_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - Some("calendar"), - Some("Calendar"), - ), - ), - ( - "mcp__rmcp__echo".to_string(), - make_mcp_tool("rmcp", "echo", None, None), - ), - ]); - - let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); - let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); - let explicitly_enabled_connectors = HashSet::new(); - let connectors = filter_connectors_for_input( - &connectors, - &[user_message("use $calendar and then echo the response")], - &explicitly_enabled_connectors, - &HashMap::new(), - ); - let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); - selected_mcp_tools.extend(apps_mcp_tools); - - let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); - tool_names.sort(); - assert_eq!( - tool_names, - vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__rmcp__echo".to_string(), - ] - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_reads_search_tool_output() { - let rollout_items = vec![ - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), - function_call_output_rollout_item( - "search-1", - &json!({ - "active_selected_tools": [ - "mcp__codex_apps__calendar_create_event", - "mcp__codex_apps__calendar_list_events", - ], - }) - .to_string(), - ), - ]; - - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!( - selected, - Some(vec![ - "mcp__codex_apps__calendar_create_event".to_string(), - "mcp__codex_apps__calendar_list_events".to_string(), - ]) - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_latest_valid_payload_wins() { - let rollout_items = vec![ - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), - function_call_output_rollout_item( - "search-1", - &json!({ - "active_selected_tools": ["mcp__codex_apps__calendar_create_event"], - }) - .to_string(), - ), - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-2"), - function_call_output_rollout_item( - "search-2", - &json!({ - "active_selected_tools": ["mcp__codex_apps__calendar_delete_event"], - }) - .to_string(), - ), - ]; - - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!( - selected, - Some(vec!["mcp__codex_apps__calendar_delete_event".to_string(),]) - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_ignores_non_search_and_malformed_payloads() { - let rollout_items = vec![ - function_call_rollout_item("shell", "shell-1"), - function_call_output_rollout_item( - "shell-1", - &json!({ - "active_selected_tools": ["mcp__codex_apps__should_be_ignored"], - }) - .to_string(), - ), - function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), - function_call_output_rollout_item("search-1", "{not-json"), - function_call_output_rollout_item( - "unknown-search-call", - &json!({ - "active_selected_tools": ["mcp__codex_apps__also_ignored"], - }) - .to_string(), - ), - function_call_output_rollout_item( - "search-1", - &json!({ - "active_selected_tools": ["mcp__codex_apps__calendar_list_events"], - }) - .to_string(), - ), - ]; - - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!( - selected, - Some(vec!["mcp__codex_apps__calendar_list_events".to_string(),]) - ); - } - - #[test] - fn extract_mcp_tool_selection_from_rollout_returns_none_without_valid_search_output() { - let rollout_items = vec![function_call_rollout_item( - SEARCH_TOOL_BM25_TOOL_NAME, - "search-1", - )]; - let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); - assert_eq!(selected, None); - } - - #[tokio::test] - async fn reconstruct_history_matches_live_compactions() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; - - let reconstruction_turn = session.new_default_turn().await; - let reconstructed = session - .reconstruct_history_from_rollout(reconstruction_turn.as_ref(), &rollout_items) - .await; - - assert_eq!(expected, reconstructed.history); - } - - #[tokio::test] - async fn reconstruct_history_uses_replacement_history_verbatim() { - let (session, turn_context) = make_session_and_context().await; - let summary_item = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }; - let replacement_history = vec![ - summary_item.clone(), - ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "stale developer instructions".to_string(), - }], - end_turn: None, - phase: None, - }, - ]; - let rollout_items = vec![RolloutItem::Compacted(CompactedItem { - message: String::new(), - replacement_history: Some(replacement_history.clone()), - })]; - - let reconstructed = session - .reconstruct_history_from_rollout(&turn_context, &rollout_items) - .await; - - assert_eq!(reconstructed.history, replacement_history); - } - - #[tokio::test] - async fn record_initial_history_reconstructs_resumed_transcript() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; - - session - .record_initial_history(InitialHistory::Resumed(ResumedHistory { - conversation_id: ThreadId::default(), - history: rollout_items, - rollout_path: PathBuf::from("/tmp/resume.jsonl"), - })) - .await; - - let history = session.state.lock().await.clone_history(); - assert_eq!(expected, history.raw_items()); - } - - #[tokio::test] - async fn resumed_history_injects_initial_context_on_first_context_update_only() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; - - session - .record_initial_history(InitialHistory::Resumed(ResumedHistory { - conversation_id: ThreadId::default(), - history: rollout_items, - rollout_path: PathBuf::from("/tmp/resume.jsonl"), - })) - .await; - - let history_before_seed = session.state.lock().await.clone_history(); - assert_eq!(expected, history_before_seed.raw_items()); - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - expected.extend(session.build_initial_context(&turn_context).await); - let history_after_seed = session.clone_history().await; - assert_eq!(expected, history_after_seed.raw_items()); - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - let history_after_second_seed = session.clone_history().await; - assert_eq!( - history_after_seed.raw_items(), - history_after_second_seed.raw_items() - ); - } - - #[tokio::test] - async fn record_initial_history_seeds_token_info_from_rollout() { - let (session, turn_context) = make_session_and_context().await; - let (mut rollout_items, _expected) = sample_rollout(&session, &turn_context).await; - - let info1 = TokenUsageInfo { - total_token_usage: TokenUsage { - input_tokens: 10, - cached_input_tokens: 0, - output_tokens: 20, - reasoning_output_tokens: 0, - total_tokens: 30, - }, - last_token_usage: TokenUsage { - input_tokens: 3, - cached_input_tokens: 0, - output_tokens: 4, - reasoning_output_tokens: 0, - total_tokens: 7, - }, - model_context_window: Some(1_000), - }; - let info2 = TokenUsageInfo { - total_token_usage: TokenUsage { - input_tokens: 100, - cached_input_tokens: 50, - output_tokens: 200, - reasoning_output_tokens: 25, - total_tokens: 375, - }, - last_token_usage: TokenUsage { - input_tokens: 10, - cached_input_tokens: 0, - output_tokens: 20, - reasoning_output_tokens: 5, - total_tokens: 35, - }, - model_context_window: Some(2_000), - }; - - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: Some(info1), - rate_limits: None, - }, - ))); - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: None, - rate_limits: None, - }, - ))); - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: Some(info2.clone()), - rate_limits: None, - }, - ))); - rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( - TokenCountEvent { - info: None, - rate_limits: None, - }, - ))); - - session - .record_initial_history(InitialHistory::Resumed(ResumedHistory { - conversation_id: ThreadId::default(), - history: rollout_items, - rollout_path: PathBuf::from("/tmp/resume.jsonl"), - })) - .await; - - let actual = session.state.lock().await.token_info(); - assert_eq!(actual, Some(info2)); - } - - #[tokio::test] - async fn recompute_token_usage_uses_session_base_instructions() { - let (session, turn_context) = make_session_and_context().await; - - let override_instructions = "SESSION_OVERRIDE_INSTRUCTIONS_ONLY".repeat(120); - { - let mut state = session.state.lock().await; - state.session_configuration.base_instructions = override_instructions.clone(); - } - - let item = user_message("hello"); - session - .record_into_history(std::slice::from_ref(&item), &turn_context) - .await; - - let history = session.clone_history().await; - let session_base_instructions = BaseInstructions { - text: override_instructions, - }; - let expected_tokens = history - .estimate_token_count_with_base_instructions(&session_base_instructions) - .expect("estimate with session base instructions"); - let model_estimated_tokens = history - .estimate_token_count(&turn_context) - .expect("estimate with model instructions"); - assert_ne!(expected_tokens, model_estimated_tokens); - - session.recompute_token_usage(&turn_context).await; - - let actual_tokens = session - .state - .lock() - .await - .token_info() - .expect("token info") - .last_token_usage - .total_tokens; - assert_eq!(actual_tokens, expected_tokens.max(0)); - } - - #[tokio::test] - async fn recompute_token_usage_updates_model_context_window() { - let (session, mut turn_context) = make_session_and_context().await; - - { - let mut state = session.state.lock().await; - state.set_token_info(Some(TokenUsageInfo { - total_token_usage: TokenUsage::default(), - last_token_usage: TokenUsage::default(), - model_context_window: Some(258_400), - })); - } - - turn_context.model_info.context_window = Some(128_000); - turn_context.model_info.effective_context_window_percent = 100; - - session.recompute_token_usage(&turn_context).await; - - let actual = session.state.lock().await.token_info().expect("token info"); - assert_eq!(actual.model_context_window, Some(128_000)); - } - - #[tokio::test] - async fn record_initial_history_reconstructs_forked_transcript() { - let (session, turn_context) = make_session_and_context().await; - let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; - - session - .record_initial_history(InitialHistory::Forked(rollout_items)) - .await; - - let reconstruction_turn = session.new_default_turn().await; - expected.extend( - session - .build_initial_context(reconstruction_turn.as_ref()) - .await, - ); - let history = session.state.lock().await.clone_history(); - assert_eq!(expected, history.raw_items()); - } - - #[tokio::test] - async fn record_initial_history_forked_hydrates_previous_turn_settings() { - let (session, turn_context) = make_session_and_context().await; - let previous_model = "forked-rollout-model"; - let previous_context_item = TurnContextItem { - turn_id: Some(turn_context.sub_id.clone()), - trace_id: turn_context.trace_id.clone(), - cwd: turn_context.cwd.clone(), - current_date: turn_context.current_date.clone(), - timezone: turn_context.timezone.clone(), - approval_policy: turn_context.approval_policy.value(), - sandbox_policy: turn_context.sandbox_policy.get().clone(), - network: None, - model: previous_model.to_string(), - personality: turn_context.personality, - collaboration_mode: Some(turn_context.collaboration_mode.clone()), - realtime_active: Some(turn_context.realtime_active), - effort: turn_context.reasoning_effort, - summary: turn_context.reasoning_summary, - user_instructions: None, - developer_instructions: None, - final_output_json_schema: None, - truncation_policy: Some(turn_context.truncation_policy.into()), - }; - let turn_id = previous_context_item - .turn_id - .clone() - .expect("turn context should have turn_id"); - let rollout_items = vec![ - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: turn_id.clone(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage( - codex_protocol::protocol::UserMessageEvent { - message: "forked seed".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - )), - RolloutItem::TurnContext(previous_context_item), - RolloutItem::EventMsg(EventMsg::TurnComplete( - codex_protocol::protocol::TurnCompleteEvent { - turn_id, - last_agent_message: None, - }, - )), - ]; - - session - .record_initial_history(InitialHistory::Forked(rollout_items)) - .await; - - assert_eq!( - session.previous_turn_settings().await, - Some(PreviousTurnSettings { - model: previous_model.to_string(), - realtime_active: Some(turn_context.realtime_active), - }) - ); - } - - #[tokio::test] - async fn thread_rollback_drops_last_turn_from_history() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let rollout_path = attach_rollout_recorder(&sess).await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - let turn_1 = vec![ - user_message("turn 1 user"), - assistant_message("turn 1 assistant"), - ]; - let turn_2 = vec![ - user_message("turn 2 user"), - assistant_message("turn 2 assistant"), - ]; - let mut full_history = Vec::new(); - full_history.extend(initial_context.clone()); - full_history.extend(turn_1.clone()); - full_history.extend(turn_2); - sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) - .await; - let rollout_items: Vec = full_history - .into_iter() - .map(RolloutItem::ResponseItem) - .collect(); - sess.persist_rollout_items(&rollout_items).await; - sess.set_previous_turn_settings(Some(PreviousTurnSettings { - model: "stale-model".to_string(), - realtime_active: Some(tc.realtime_active), - })) - .await; - { - let mut state = sess.state.lock().await; - state.set_reference_context_item(Some(tc.to_turn_context_item())); - } - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - - let rollback_event = wait_for_thread_rolled_back(&rx).await; - assert_eq!(rollback_event.num_turns, 1); - - let mut expected = Vec::new(); - expected.extend(initial_context); - expected.extend(turn_1); - - let history = sess.clone_history().await; - assert_eq!(expected, history.raw_items()); - assert_eq!(sess.previous_turn_settings().await, None); - assert!(sess.reference_context_item().await.is_none()); - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - assert!(resumed.history.iter().any(|item| { - matches!( - item, - RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) - if rollback.num_turns == 1 - ) - })); - } - - #[tokio::test] - async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - attach_rollout_recorder(&sess).await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - let turn_1 = vec![user_message("turn 1 user")]; - let mut full_history = Vec::new(); - full_history.extend(initial_context.clone()); - full_history.extend(turn_1); - sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) - .await; - let rollout_items: Vec = full_history - .into_iter() - .map(RolloutItem::ResponseItem) - .collect(); - sess.persist_rollout_items(&rollout_items).await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 99).await; - - let rollback_event = wait_for_thread_rolled_back(&rx).await; - assert_eq!(rollback_event.num_turns, 99); - - let history = sess.clone_history().await; - assert_eq!(initial_context, history.raw_items()); - } - - #[tokio::test] - async fn thread_rollback_fails_without_persisted_rollout_path() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - sess.record_into_history(&initial_context, tc.as_ref()) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - - let error_event = wait_for_thread_rollback_failed(&rx).await; - assert_eq!( - error_event.message, - "thread rollback requires a persisted rollout path" - ); - assert_eq!( - error_event.codex_error_info, - Some(CodexErrorInfo::ThreadRollbackFailed) - ); - assert_eq!(sess.clone_history().await.raw_items(), initial_context); - } - - #[tokio::test] - async fn thread_rollback_recomputes_previous_turn_settings_and_reference_context_from_replay() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - attach_rollout_recorder(&sess).await; - - let first_context_item = tc.to_turn_context_item(); - let first_turn_id = first_context_item - .turn_id - .clone() - .expect("turn context should have turn_id"); - let mut rolled_back_context_item = first_context_item.clone(); - rolled_back_context_item.turn_id = Some("rolled-back-turn".to_string()); - rolled_back_context_item.model = "rolled-back-model".to_string(); - let rolled_back_turn_id = rolled_back_context_item - .turn_id - .clone() - .expect("turn context should have turn_id"); - let turn_one_user = user_message("turn 1 user"); - let turn_one_assistant = assistant_message("turn 1 assistant"); - let turn_two_user = user_message("turn 2 user"); - let turn_two_assistant = assistant_message("turn 2 assistant"); - - sess.persist_rollout_items(&[ - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: first_turn_id.clone(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage( - codex_protocol::protocol::UserMessageEvent { - message: "turn 1 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - )), - RolloutItem::TurnContext(first_context_item.clone()), - RolloutItem::ResponseItem(turn_one_user.clone()), - RolloutItem::ResponseItem(turn_one_assistant.clone()), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: first_turn_id, - last_agent_message: None, - })), - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: rolled_back_turn_id.clone(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage( - codex_protocol::protocol::UserMessageEvent { - message: "turn 2 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - )), - RolloutItem::TurnContext(rolled_back_context_item), - RolloutItem::ResponseItem(turn_two_user), - RolloutItem::ResponseItem(turn_two_assistant), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: rolled_back_turn_id, - last_agent_message: None, - })), - ]) - .await; - sess.replace_history( - vec![assistant_message("stale history")], - Some(first_context_item.clone()), - ) - .await; - sess.set_previous_turn_settings(Some(PreviousTurnSettings { - model: "stale-model".to_string(), - realtime_active: None, - })) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - let rollback_event = wait_for_thread_rolled_back(&rx).await; - assert_eq!(rollback_event.num_turns, 1); - - assert_eq!( - sess.clone_history().await.raw_items(), - vec![turn_one_user, turn_one_assistant] - ); - assert_eq!( - sess.previous_turn_settings().await, - Some(PreviousTurnSettings { - model: tc.model_info.slug.clone(), - realtime_active: Some(tc.realtime_active), - }) - ); - assert_eq!( - serde_json::to_value(sess.reference_context_item().await) - .expect("serialize replay reference context item"), - serde_json::to_value(Some(first_context_item)) - .expect("serialize expected reference context item") - ); - } - - #[tokio::test] - async fn thread_rollback_persists_marker_and_replays_cumulatively() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let rollout_path = attach_rollout_recorder(&sess).await; - let turn_context_item = tc.to_turn_context_item(); - - sess.persist_rollout_items(&[ - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: "turn-1".to_string(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { - message: "turn 1 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - })), - RolloutItem::TurnContext(turn_context_item.clone()), - RolloutItem::ResponseItem(user_message("turn 1 user")), - RolloutItem::ResponseItem(assistant_message("turn 1 assistant")), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-1".to_string(), - last_agent_message: None, - })), - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: "turn-2".to_string(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { - message: "turn 2 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - })), - RolloutItem::TurnContext(turn_context_item.clone()), - RolloutItem::ResponseItem(user_message("turn 2 user")), - RolloutItem::ResponseItem(assistant_message("turn 2 assistant")), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-2".to_string(), - last_agent_message: None, - })), - RolloutItem::EventMsg(EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { - turn_id: "turn-3".to_string(), - model_context_window: Some(128_000), - collaboration_mode_kind: ModeKind::Default, - }, - )), - RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { - message: "turn 3 user".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - })), - RolloutItem::TurnContext(turn_context_item), - RolloutItem::ResponseItem(user_message("turn 3 user")), - RolloutItem::ResponseItem(assistant_message("turn 3 assistant")), - RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-3".to_string(), - last_agent_message: None, - })), - ]) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - let first_rollback = wait_for_thread_rolled_back(&rx).await; - assert_eq!(first_rollback.num_turns, 1); - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - let second_rollback = wait_for_thread_rolled_back(&rx).await; - assert_eq!(second_rollback.num_turns, 1); - - assert_eq!( - sess.clone_history().await.raw_items(), - vec![ - user_message("turn 1 user"), - assistant_message("turn 1 assistant") - ] - ); - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - let rollback_markers = resumed - .history - .iter() - .filter(|item| matches!(item, RolloutItem::EventMsg(EventMsg::ThreadRolledBack(_)))) - .count(); - assert_eq!(rollback_markers, 2); - } - - #[tokio::test] - async fn thread_rollback_fails_when_turn_in_progress() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - sess.record_into_history(&initial_context, tc.as_ref()) - .await; - - *sess.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; - - let error_event = wait_for_thread_rollback_failed(&rx).await; - assert_eq!( - error_event.codex_error_info, - Some(CodexErrorInfo::ThreadRollbackFailed) - ); - - let history = sess.clone_history().await; - assert_eq!(initial_context, history.raw_items()); - } - - #[tokio::test] - async fn thread_rollback_fails_when_num_turns_is_zero() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - - let initial_context = sess.build_initial_context(tc.as_ref()).await; - sess.record_into_history(&initial_context, tc.as_ref()) - .await; - - handlers::thread_rollback(&sess, "sub-1".to_string(), 0).await; - - let error_event = wait_for_thread_rollback_failed(&rx).await; - assert_eq!(error_event.message, "num_turns must be >= 1"); - assert_eq!( - error_event.codex_error_info, - Some(CodexErrorInfo::ThreadRollbackFailed) - ); - - let history = sess.clone_history().await; - assert_eq!(initial_context, history.raw_items()); - } - - #[tokio::test] - async fn set_rate_limits_retains_previous_credits() { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - - let mut state = SessionState::new(session_configuration); - let initial = RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 10.0, - window_minutes: Some(15), - resets_at: Some(1_700), - }), - secondary: None, - credits: Some(CreditsSnapshot { - has_credits: true, - unlimited: false, - balance: Some("10.00".to_string()), - }), - plan_type: Some(codex_protocol::account::PlanType::Plus), - }; - state.set_rate_limits(initial.clone()); - - let update = RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: Some("codex_other".to_string()), - primary: Some(RateLimitWindow { - used_percent: 40.0, - window_minutes: Some(30), - resets_at: Some(1_800), - }), - secondary: Some(RateLimitWindow { - used_percent: 5.0, - window_minutes: Some(60), - resets_at: Some(1_900), - }), - credits: None, - plan_type: None, - }; - state.set_rate_limits(update.clone()); - - assert_eq!( - state.latest_rate_limits, - Some(RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: Some("codex_other".to_string()), - primary: update.primary.clone(), - secondary: update.secondary, - credits: initial.credits, - plan_type: initial.plan_type, - }) - ); - } - - #[tokio::test] - async fn set_rate_limits_updates_plan_type_when_present() { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - - let mut state = SessionState::new(session_configuration); - let initial = RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 15.0, - window_minutes: Some(20), - resets_at: Some(1_600), - }), - secondary: Some(RateLimitWindow { - used_percent: 5.0, - window_minutes: Some(45), - resets_at: Some(1_650), - }), - credits: Some(CreditsSnapshot { - has_credits: true, - unlimited: false, - balance: Some("15.00".to_string()), - }), - plan_type: Some(codex_protocol::account::PlanType::Plus), - }; - state.set_rate_limits(initial.clone()); - - let update = RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 35.0, - window_minutes: Some(25), - resets_at: Some(1_700), - }), - secondary: None, - credits: None, - plan_type: Some(codex_protocol::account::PlanType::Pro), - }; - state.set_rate_limits(update.clone()); - - assert_eq!( - state.latest_rate_limits, - Some(RateLimitSnapshot { - limit_id: Some("codex".to_string()), - limit_name: None, - primary: update.primary, - secondary: update.secondary, - credits: initial.credits, - plan_type: update.plan_type, - }) - ); - } - - #[test] - fn prefers_structured_content_when_present() { - let ctr = McpCallToolResult { - // Content present but should be ignored because structured_content is set. - content: vec![text_block("ignored")], - is_error: None, - structured_content: Some(json!({ - "ok": true, - "value": 42 - })), - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&json!({ - "ok": true, - "value": 42 - })) - .unwrap(), - ), - success: Some(true), - }; - - assert_eq!(expected, got); - } - - #[tokio::test] - async fn includes_timed_out_message() { - let exec = ExecToolCallOutput { - exit_code: 0, - stdout: StreamOutput::new(String::new()), - stderr: StreamOutput::new(String::new()), - aggregated_output: StreamOutput::new("Command output".to_string()), - duration: StdDuration::from_secs(1), - timed_out: true, - }; - let (_, turn_context) = make_session_and_context().await; - - let out = format_exec_output_str(&exec, turn_context.truncation_policy); - - assert_eq!( - out, - "command timed out after 1000 milliseconds\nCommand output" - ); - } - - #[tokio::test] - async fn turn_context_with_model_updates_model_fields() { - let (session, mut turn_context) = make_session_and_context().await; - turn_context.reasoning_effort = Some(ReasoningEffortConfig::Minimal); - let updated = turn_context - .with_model("gpt-5.1".to_string(), &session.services.models_manager) - .await; - let expected_model_info = session - .services - .models_manager - .get_model_info("gpt-5.1", updated.config.as_ref()) - .await; - - assert_eq!(updated.config.model.as_deref(), Some("gpt-5.1")); - assert_eq!(updated.collaboration_mode.model(), "gpt-5.1"); - assert_eq!(updated.model_info, expected_model_info); - assert_eq!( - updated.reasoning_effort, - Some(ReasoningEffortConfig::Medium) - ); - assert_eq!( - updated.collaboration_mode.reasoning_effort(), - Some(ReasoningEffortConfig::Medium) - ); - assert_eq!( - updated.config.model_reasoning_effort, - Some(ReasoningEffortConfig::Medium) - ); - assert_eq!( - updated.truncation_policy, - expected_model_info.truncation_policy.into() - ); - assert!(!Arc::ptr_eq( - &updated.tool_call_gate, - &turn_context.tool_call_gate - )); - } - - #[test] - fn falls_back_to_content_when_structured_is_null() { - let ctr = McpCallToolResult { - content: vec![text_block("hello"), text_block("world")], - is_error: None, - structured_content: Some(serde_json::Value::Null), - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&vec![text_block("hello"), text_block("world")]).unwrap(), - ), - success: Some(true), - }; - - assert_eq!(expected, got); - } - - #[test] - fn success_flag_reflects_is_error_true() { - let ctr = McpCallToolResult { - content: vec![text_block("unused")], - is_error: Some(true), - structured_content: Some(json!({ "message": "bad" })), - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&json!({ "message": "bad" })).unwrap(), - ), - success: Some(false), - }; - - assert_eq!(expected, got); - } - - #[test] - fn success_flag_true_with_no_error_and_content_used() { - let ctr = McpCallToolResult { - content: vec![text_block("alpha")], - is_error: Some(false), - structured_content: None, - meta: None, - }; - - let got = FunctionCallOutputPayload::from(&ctr); - let expected = FunctionCallOutputPayload { - body: FunctionCallOutputBody::Text( - serde_json::to_string(&vec![text_block("alpha")]).unwrap(), - ), - success: Some(true), - }; - - assert_eq!(expected, got); - } - - async fn wait_for_thread_rolled_back( - rx: &async_channel::Receiver, - ) -> crate::protocol::ThreadRolledBackEvent { - let deadline = StdDuration::from_secs(2); - let start = std::time::Instant::now(); - loop { - let remaining = deadline.saturating_sub(start.elapsed()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::ThreadRolledBack(payload) => return payload, - _ => continue, - } - } - } - - async fn wait_for_thread_rollback_failed(rx: &async_channel::Receiver) -> ErrorEvent { - let deadline = StdDuration::from_secs(2); - let start = std::time::Instant::now(); - loop { - let remaining = deadline.saturating_sub(start.elapsed()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::Error(payload) - if payload.codex_error_info == Some(CodexErrorInfo::ThreadRollbackFailed) => - { - return payload; - } - _ => continue, - } - } - } - - async fn attach_rollout_recorder(session: &Arc) -> PathBuf { - let config = session.get_config().await; - let recorder = RolloutRecorder::new( - config.as_ref(), - RolloutRecorderParams::new( - ThreadId::default(), - None, - SessionSource::Exec, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - None, - None, - ) - .await - .expect("create rollout recorder"); - let rollout_path = recorder.rollout_path().to_path_buf(); - { - let mut rollout = session.services.rollout.lock().await; - *rollout = Some(recorder); - } - session.ensure_rollout_materialized().await; - session.flush_rollout().await; - rollout_path - } - - fn text_block(s: &str) -> serde_json::Value { - json!({ - "type": "text", - "text": s, - }) - } - - fn init_test_tracing() { - static INIT: Once = Once::new(); - INIT.call_once(|| { - let provider = SdkTracerProvider::builder().build(); - let tracer = provider.tracer("codex-core-tests"); - let subscriber = tracing_subscriber::registry() - .with(tracing_opentelemetry::layer().with_tracer(tracer)); - tracing::subscriber::set_global_default(subscriber) - .expect("global tracing subscriber should only be installed once"); - }); - } - - async fn build_test_config(codex_home: &Path) -> Config { - ConfigBuilder::default() - .codex_home(codex_home.to_path_buf()) - .build() - .await - .expect("load default test config") - } - - fn otel_manager( - conversation_id: ThreadId, - config: &Config, - model_info: &ModelInfo, - session_source: SessionSource, - ) -> OtelManager { - OtelManager::new( - conversation_id, - ModelsManager::get_model_offline_for_tests(config.model.as_deref()).as_str(), - model_info.slug.as_str(), - None, - Some("test@test.com".to_string()), - Some(TelemetryAuthMode::Chatgpt), - "test_originator".to_string(), - false, - "test".to_string(), - session_source, - ) - } - - pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - - SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - } - } - - #[tokio::test] - async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { - let codex_home = tempfile::tempdir().expect("create temp dir"); - let mut config = build_test_config(codex_home.path()).await; - config - .features - .enable(Feature::ShellZshFork) - .expect("test config should allow shell_zsh_fork"); - config.zsh_path = None; - let config = Arc::new(config); - - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( - config.codex_home.clone(), - auth_manager.clone(), - None, - CollaborationModesConfig::default(), - )); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort: config.model_reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - - let (tx_event, _rx_event) = async_channel::unbounded(); - let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); - let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); - let skills_manager = Arc::new(SkillsManager::new( - config.codex_home.clone(), - Arc::clone(&plugins_manager), - )); - let result = Session::new( - session_configuration, - Arc::clone(&config), - auth_manager, - models_manager, - ExecPolicyManager::default(), - tx_event, - agent_status_tx, - InitialHistory::New, - SessionSource::Exec, - skills_manager, - plugins_manager, - mcp_manager, - Arc::new(FileWatcher::noop()), - AgentControl::default(), - ) - .await; - - let err = match result { - Ok(_) => panic!("expected startup to fail"), - Err(err) => err, - }; - let msg = format!("{err:#}"); - assert!(msg.contains("zsh fork feature enabled, but `zsh_path` is not configured")); - } - - // todo: use online model info - pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { - let (tx_event, _rx_event) = async_channel::unbounded(); - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let conversation_id = ThreadId::default(); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( - config.codex_home.clone(), - auth_manager.clone(), - None, - CollaborationModesConfig::default(), - )); - let agent_control = AgentControl::default(); - let exec_policy = ExecPolicyManager::default(); - let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools: Vec::new(), - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - let per_turn_config = Session::build_per_turn_config(&session_configuration); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - session_configuration.collaboration_mode.model(), - &per_turn_config, - ); - let otel_manager = otel_manager( - conversation_id, - config.as_ref(), - &model_info, - session_configuration.session_source.clone(), - ); - - let state = SessionState::new(session_configuration.clone()); - let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); - let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); - let skills_manager = Arc::new(SkillsManager::new( - config.codex_home.clone(), - Arc::clone(&plugins_manager), - )); - let network_approval = Arc::new(NetworkApprovalService::default()); - - let file_watcher = Arc::new(FileWatcher::noop()); - let services = SessionServices { - mcp_connection_manager: Arc::new(RwLock::new( - McpConnectionManager::new_mcp_connection_manager_for_tests( - &config.permissions.approval_policy, - ), - )), - mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), - unified_exec_manager: UnifiedExecProcessManager::new( - config.background_terminal_max_timeout, - ), - shell_zsh_path: None, - main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), - analytics_events_client: AnalyticsEventsClient::new( - Arc::clone(&config), - Arc::clone(&auth_manager), - ), - hooks: Hooks::new(HooksConfig { - legacy_notify_argv: config.notify.clone(), - }), - rollout: Mutex::new(None), - user_shell: Arc::new(default_user_shell()), - shell_snapshot_tx: watch::channel(None).0, - show_raw_agent_reasoning: config.show_raw_agent_reasoning, - exec_policy, - auth_manager: auth_manager.clone(), - otel_manager: otel_manager.clone(), - models_manager: Arc::clone(&models_manager), - tool_approvals: Mutex::new(ApprovalStore::default()), - execve_session_approvals: RwLock::new(HashMap::new()), - skills_manager, - plugins_manager, - mcp_manager, - file_watcher, - agent_control, - network_proxy: None, - network_approval: Arc::clone(&network_approval), - state_db: None, - model_client: ModelClient::new( - Some(auth_manager.clone()), - conversation_id, - session_configuration.provider.clone(), - session_configuration.session_source.clone(), - config.model_verbosity, - ws_version_from_features(config.as_ref()), - config.features.enabled(Feature::EnableRequestCompression), - config.features.enabled(Feature::RuntimeMetrics), - Session::build_model_client_beta_features_header(config.as_ref()), - ), - }; - let js_repl = Arc::new(JsReplHandle::with_node_path( - config.js_repl_node_path.clone(), - config.js_repl_node_module_dirs.clone(), - )); - - let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); - let turn_context = Session::make_turn_context( - Some(Arc::clone(&auth_manager)), - &otel_manager, - session_configuration.provider.clone(), - &session_configuration, - per_turn_config, - model_info, - None, - "turn_id".to_string(), - Arc::clone(&js_repl), - skills_outcome, - ); - - let session = Session { - conversation_id, - tx_event, - agent_status: agent_status_tx, - state: Mutex::new(state), - features: config.features.clone(), - pending_mcp_server_refresh_config: Mutex::new(None), - conversation: Arc::new(RealtimeConversationManager::new()), - active_turn: Mutex::new(None), - services, - js_repl, - next_internal_sub_id: AtomicU64::new(0), - }; - - (session, turn_context) - } - - #[tokio::test] - async fn submit_with_id_captures_current_span_trace_context() { - let (session, _turn_context) = make_session_and_context().await; - let (tx_sub, rx_sub) = async_channel::bounded(1); - let (_tx_event, rx_event) = async_channel::unbounded(); - let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); - let codex = Codex { - tx_sub, - rx_event, - agent_status, - session: Arc::new(session), - }; - - init_test_tracing(); - - let request_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), - tracestate: Some("vendor=value".into()), - }; - let request_span = info_span!("app_server.request"); - assert!(set_parent_from_w3c_trace_context( - &request_span, - &request_parent - )); - - let expected_trace = async { - let expected_trace = - current_span_w3c_trace_context().expect("current span should have trace context"); - codex - .submit_with_id(Submission { - id: "sub-1".into(), - op: Op::Interrupt, - trace: None, - }) - .await - .expect("submit should succeed"); - expected_trace - } - .instrument(request_span) - .await; - - let submitted = rx_sub.recv().await.expect("submission"); - assert_eq!(submitted.trace, Some(expected_trace)); - } - - #[tokio::test] - async fn new_default_turn_captures_current_span_trace_id() { - let (session, _turn_context) = make_session_and_context().await; - - init_test_tracing(); - - let request_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), - tracestate: Some("vendor=value".into()), - }; - let request_span = info_span!("app_server.request"); - assert!(set_parent_from_w3c_trace_context( - &request_span, - &request_parent - )); - - let turn_context_item = async { - let expected_trace_id = Span::current() - .context() - .span() - .span_context() - .trace_id() - .to_string(); - let turn_context = session.new_default_turn().await; - let turn_context_item = turn_context.to_turn_context_item(); - assert_eq!(turn_context_item.trace_id, Some(expected_trace_id)); - turn_context_item - } - .instrument(request_span) - .await; - - assert_eq!( - turn_context_item.trace_id.as_deref(), - Some("00000000000000000000000000000011") - ); - } - - #[test] - fn submission_dispatch_span_prefers_submission_trace_context() { - init_test_tracing(); - - let ambient_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000033-0000000000000044-01".into()), - tracestate: None, - }; - let ambient_span = info_span!("ambient"); - assert!(set_parent_from_w3c_trace_context( - &ambient_span, - &ambient_parent - )); - - let submission_trace = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000055-0000000000000066-01".into()), - tracestate: Some("vendor=value".into()), - }; - let dispatch_span = ambient_span.in_scope(|| { - submission_dispatch_span(&Submission { - id: "sub-1".into(), - op: Op::Interrupt, - trace: Some(submission_trace), - }) - }); - - let trace_id = dispatch_span.context().span().span_context().trace_id(); - assert_eq!( - trace_id, - TraceId::from_hex("00000000000000000000000000000055").expect("trace id") - ); - } - - #[test] - fn submission_dispatch_span_uses_debug_for_realtime_audio() { - init_test_tracing(); - - let dispatch_span = submission_dispatch_span(&Submission { - id: "sub-1".into(), - op: Op::RealtimeConversationAudio(ConversationAudioParams { - frame: RealtimeAudioFrame { - data: "ZmFrZQ==".into(), - sample_rate: 16_000, - num_channels: 1, - samples_per_channel: Some(160), - }, - }), - trace: None, - }); - - assert_eq!( - dispatch_span.metadata().expect("span metadata").level(), - &tracing::Level::DEBUG - ); - } - - #[tokio::test] - async fn spawn_task_turn_span_inherits_dispatch_trace_context() { - struct TraceCaptureTask { - captured_trace: Arc>>, - } - - #[async_trait::async_trait] - impl SessionTask for TraceCaptureTask { - fn kind(&self) -> TaskKind { - TaskKind::Regular - } - - fn span_name(&self) -> &'static str { - "session_task.trace_capture" - } - - async fn run( - self: Arc, - _session: Arc, - _ctx: Arc, - _input: Vec, - _cancellation_token: CancellationToken, - ) -> Option { - let mut trace = self - .captured_trace - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *trace = current_span_w3c_trace_context(); - None - } - } - - init_test_tracing(); - - let request_parent = W3cTraceContext { - traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), - tracestate: Some("vendor=value".into()), - }; - let request_span = tracing::info_span!("app_server.request"); - assert!(set_parent_from_w3c_trace_context( - &request_span, - &request_parent - )); - - let submission_trace = async { - current_span_w3c_trace_context().expect("request span should have trace context") - } - .instrument(request_span) - .await; - - let dispatch_span = submission_dispatch_span(&Submission { - id: "sub-1".into(), - op: Op::Interrupt, - trace: Some(submission_trace.clone()), - }); - let dispatch_span_id = dispatch_span.context().span().span_context().span_id(); - - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let captured_trace = Arc::new(std::sync::Mutex::new(None)); - - async { - sess.spawn_task( - Arc::clone(&tc), - vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }], - TraceCaptureTask { - captured_trace: Arc::clone(&captured_trace), - }, - ) - .await; - } - .instrument(dispatch_span) - .await; - - let evt = tokio::time::timeout(StdDuration::from_secs(2), rx.recv()) - .await - .expect("timeout waiting for turn completion") - .expect("event"); - assert!(matches!(evt.msg, EventMsg::TurnComplete(_))); - - let task_trace = captured_trace - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone() - .expect("turn task should capture the current span trace context"); - let submission_context = - codex_otel::context_from_w3c_trace_context(&submission_trace).expect("submission"); - let task_context = - codex_otel::context_from_w3c_trace_context(&task_trace).expect("task trace"); - - assert_eq!( - task_context.span().span_context().trace_id(), - submission_context.span().span_context().trace_id() - ); - assert_ne!( - task_context.span().span_context().span_id(), - dispatch_span_id - ); - } - - pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( - dynamic_tools: Vec, - ) -> ( - Arc, - Arc, - async_channel::Receiver, - ) { - let (tx_event, rx_event) = async_channel::unbounded(); - let codex_home = tempfile::tempdir().expect("create temp dir"); - let config = build_test_config(codex_home.path()).await; - let config = Arc::new(config); - let conversation_id = ThreadId::default(); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( - config.codex_home.clone(), - auth_manager.clone(), - None, - CollaborationModesConfig::default(), - )); - let agent_control = AgentControl::default(); - let exec_policy = ExecPolicyManager::default(); - let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = - ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); - let reasoning_effort = config.model_reasoning_effort; - let collaboration_mode = CollaborationMode { - mode: ModeKind::Default, - settings: Settings { - model, - reasoning_effort, - developer_instructions: None, - }, - }; - let session_configuration = SessionConfiguration { - provider: config.model_provider.clone(), - collaboration_mode, - model_reasoning_summary: config.model_reasoning_summary, - developer_instructions: config.developer_instructions.clone(), - user_instructions: config.user_instructions.clone(), - service_tier: None, - personality: config.personality, - base_instructions: config - .base_instructions - .clone() - .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), - compact_prompt: config.compact_prompt.clone(), - approval_policy: config.permissions.approval_policy.clone(), - sandbox_policy: config.permissions.sandbox_policy.clone(), - windows_sandbox_level: WindowsSandboxLevel::from_config(&config), - cwd: config.cwd.clone(), - codex_home: config.codex_home.clone(), - thread_name: None, - original_config_do_not_use: Arc::clone(&config), - metrics_service_name: None, - app_server_client_name: None, - session_source: SessionSource::Exec, - dynamic_tools, - persist_extended_history: false, - inherited_shell_snapshot: None, - }; - let per_turn_config = Session::build_per_turn_config(&session_configuration); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - session_configuration.collaboration_mode.model(), - &per_turn_config, - ); - let otel_manager = otel_manager( - conversation_id, - config.as_ref(), - &model_info, - session_configuration.session_source.clone(), - ); - - let state = SessionState::new(session_configuration.clone()); - let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); - let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); - let skills_manager = Arc::new(SkillsManager::new( - config.codex_home.clone(), - Arc::clone(&plugins_manager), - )); - let network_approval = Arc::new(NetworkApprovalService::default()); - - let file_watcher = Arc::new(FileWatcher::noop()); - let services = SessionServices { - mcp_connection_manager: Arc::new(RwLock::new( - McpConnectionManager::new_mcp_connection_manager_for_tests( - &config.permissions.approval_policy, - ), - )), - mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), - unified_exec_manager: UnifiedExecProcessManager::new( - config.background_terminal_max_timeout, - ), - shell_zsh_path: None, - main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), - analytics_events_client: AnalyticsEventsClient::new( - Arc::clone(&config), - Arc::clone(&auth_manager), - ), - hooks: Hooks::new(HooksConfig { - legacy_notify_argv: config.notify.clone(), - }), - rollout: Mutex::new(None), - user_shell: Arc::new(default_user_shell()), - shell_snapshot_tx: watch::channel(None).0, - show_raw_agent_reasoning: config.show_raw_agent_reasoning, - exec_policy, - auth_manager: Arc::clone(&auth_manager), - otel_manager: otel_manager.clone(), - models_manager: Arc::clone(&models_manager), - tool_approvals: Mutex::new(ApprovalStore::default()), - execve_session_approvals: RwLock::new(HashMap::new()), - skills_manager, - plugins_manager, - mcp_manager, - file_watcher, - agent_control, - network_proxy: None, - network_approval: Arc::clone(&network_approval), - state_db: None, - model_client: ModelClient::new( - Some(Arc::clone(&auth_manager)), - conversation_id, - session_configuration.provider.clone(), - session_configuration.session_source.clone(), - config.model_verbosity, - ws_version_from_features(config.as_ref()), - config.features.enabled(Feature::EnableRequestCompression), - config.features.enabled(Feature::RuntimeMetrics), - Session::build_model_client_beta_features_header(config.as_ref()), - ), - }; - let js_repl = Arc::new(JsReplHandle::with_node_path( - config.js_repl_node_path.clone(), - config.js_repl_node_module_dirs.clone(), - )); - - let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); - let turn_context = Arc::new(Session::make_turn_context( - Some(Arc::clone(&auth_manager)), - &otel_manager, - session_configuration.provider.clone(), - &session_configuration, - per_turn_config, - model_info, - None, - "turn_id".to_string(), - Arc::clone(&js_repl), - skills_outcome, - )); - - let session = Arc::new(Session { - conversation_id, - tx_event, - agent_status: agent_status_tx, - state: Mutex::new(state), - features: config.features.clone(), - pending_mcp_server_refresh_config: Mutex::new(None), - conversation: Arc::new(RealtimeConversationManager::new()), - active_turn: Mutex::new(None), - services, - js_repl, - next_internal_sub_id: AtomicU64::new(0), - }); - - (session, turn_context, rx_event) - } - - // Like make_session_and_context, but returns Arc and the event receiver - // so tests can assert on emitted events. - pub(crate) async fn make_session_and_context_with_rx() -> ( - Arc, - Arc, - async_channel::Receiver, - ) { - make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await - } - - #[tokio::test] - async fn refresh_mcp_servers_is_deferred_until_next_turn() { - let (session, turn_context) = make_session_and_context().await; - let old_token = session.mcp_startup_cancellation_token().await; - assert!(!old_token.is_cancelled()); - - let mcp_oauth_credentials_store_mode = - serde_json::to_value(OAuthCredentialsStoreMode::Auto).expect("serialize store mode"); - let refresh_config = McpServerRefreshConfig { - mcp_servers: json!({}), - mcp_oauth_credentials_store_mode, - }; - { - let mut guard = session.pending_mcp_server_refresh_config.lock().await; - *guard = Some(refresh_config); - } - - assert!(!old_token.is_cancelled()); - assert!( - session - .pending_mcp_server_refresh_config - .lock() - .await - .is_some() - ); - - session - .refresh_mcp_servers_if_requested(&turn_context) - .await; - - assert!(old_token.is_cancelled()); - assert!( - session - .pending_mcp_server_refresh_config - .lock() - .await - .is_none() - ); - let new_token = session.mcp_startup_cancellation_token().await; - assert!(!new_token.is_cancelled()); - } - - #[tokio::test] - async fn record_model_warning_appends_user_message() { - let (mut session, turn_context) = make_session_and_context().await; - let features = crate::features::Features::with_defaults().into(); - session.features = features; - - session - .record_model_warning("too many unified exec processes", &turn_context) - .await; - - let history = session.clone_history().await; - let history_items = history.raw_items(); - let last = history_items.last().expect("warning recorded"); - - match last { - ResponseItem::Message { role, content, .. } => { - assert_eq!(role, "user"); - assert_eq!( - content, - &vec![ContentItem::InputText { - text: "Warning: too many unified exec processes".to_string(), - }] - ); - } - other => panic!("expected user message, got {other:?}"), - } - } - - #[tokio::test] - async fn spawn_task_does_not_update_previous_turn_settings_for_non_run_turn_tasks() { - let (sess, tc, _rx) = make_session_and_context_with_rx().await; - sess.set_previous_turn_settings(None).await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: true, - }, - ) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - assert_eq!(sess.previous_turn_settings().await, None); - } - - #[tokio::test] - async fn build_settings_update_items_emits_environment_item_for_network_changes() { - let (session, previous_context) = make_session_and_context().await; - let previous_context = Arc::new(previous_context); - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - - let mut config = (*current_context.config).clone(); - let mut requirements = config.config_layer_stack.requirements().clone(); - requirements.network = Some(Sourced::new( - NetworkConstraints { - allowed_domains: Some(vec!["api.example.com".to_string()]), - denied_domains: Some(vec!["blocked.example.com".to_string()]), - ..Default::default() - }, - RequirementSource::CloudRequirements, - )); - let layers = config - .config_layer_stack - .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) - .into_iter() - .cloned() - .collect(); - config.config_layer_stack = ConfigLayerStack::new( - layers, - requirements, - config.config_layer_stack.requirements_toml().clone(), - ) - .expect("rebuild config layer stack with network requirements"); - current_context.config = Arc::new(config); - - let reference_context_item = previous_context.to_turn_context_item(); - let update_items = session - .build_settings_update_items(Some(&reference_context_item), ¤t_context) - .await; - - let environment_update = update_items - .iter() - .find_map(|item| match item { - ResponseItem::Message { role, content, .. } if role == "user" => { - let [ContentItem::InputText { text }] = content.as_slice() else { - return None; - }; - text.contains("").then_some(text) - } - _ => None, - }) - .expect("environment update item should be emitted"); - assert!(environment_update.contains("")); - assert!(environment_update.contains("api.example.com")); - assert!(environment_update.contains("blocked.example.com")); - } - - #[tokio::test] - async fn build_settings_update_items_emits_environment_item_for_time_changes() { - let (session, previous_context) = make_session_and_context().await; - let previous_context = Arc::new(previous_context); - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.current_date = Some("2026-02-27".to_string()); - current_context.timezone = Some("Europe/Berlin".to_string()); - - let reference_context_item = previous_context.to_turn_context_item(); - let update_items = session - .build_settings_update_items(Some(&reference_context_item), ¤t_context) - .await; - - let environment_update = update_items - .iter() - .find_map(|item| match item { - ResponseItem::Message { role, content, .. } if role == "user" => { - let [ContentItem::InputText { text }] = content.as_slice() else { - return None; - }; - text.contains("").then_some(text) - } - _ => None, - }) - .expect("environment update item should be emitted"); - assert!(environment_update.contains("2026-02-27")); - assert!(environment_update.contains("Europe/Berlin")); - } - - #[tokio::test] - async fn build_settings_update_items_emits_realtime_start_when_session_becomes_live() { - let (session, previous_context) = make_session_and_context().await; - let previous_context = Arc::new(previous_context); - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.realtime_active = true; - - let update_items = session - .build_settings_update_items( - Some(&previous_context.to_turn_context_item()), - ¤t_context, - ) - .await; - - let developer_texts = developer_input_texts(&update_items); - assert!( - developer_texts - .iter() - .any(|text| text.contains("")), - "expected a realtime start update, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_settings_update_items_emits_realtime_end_when_session_stops_being_live() { - let (session, mut previous_context) = make_session_and_context().await; - previous_context.realtime_active = true; - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.realtime_active = false; - - let update_items = session - .build_settings_update_items( - Some(&previous_context.to_turn_context_item()), - ¤t_context, - ) - .await; - - let developer_texts = developer_input_texts(&update_items); - assert!( - developer_texts - .iter() - .any(|text| text.contains("Reason: inactive")), - "expected a realtime end update, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_settings_update_items_uses_previous_turn_settings_for_realtime_end() { - let (session, previous_context) = make_session_and_context().await; - let mut previous_context_item = previous_context.to_turn_context_item(); - previous_context_item.realtime_active = None; - let previous_turn_settings = PreviousTurnSettings { - model: previous_context.model_info.slug.clone(), - realtime_active: Some(true), - }; - let mut current_context = previous_context - .with_model( - previous_context.model_info.slug.clone(), - &session.services.models_manager, - ) - .await; - current_context.realtime_active = false; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let update_items = session - .build_settings_update_items(Some(&previous_context_item), ¤t_context) - .await; - - let developer_texts = developer_input_texts(&update_items); - assert!( - developer_texts - .iter() - .any(|text| text.contains("Reason: inactive")), - "expected a realtime end update from previous turn settings, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_initial_context_uses_previous_realtime_state() { - let (session, mut turn_context) = make_session_and_context().await; - turn_context.realtime_active = true; - - let initial_context = session.build_initial_context(&turn_context).await; - let developer_texts = developer_input_texts(&initial_context); - assert!( - developer_texts - .iter() - .any(|text| text.contains("")), - "expected initial context to describe active realtime state, got {developer_texts:?}" - ); - - let previous_context_item = turn_context.to_turn_context_item(); - { - let mut state = session.state.lock().await; - state.set_reference_context_item(Some(previous_context_item)); - } - let resumed_context = session.build_initial_context(&turn_context).await; - let resumed_developer_texts = developer_input_texts(&resumed_context); - assert!( - !resumed_developer_texts - .iter() - .any(|text| text.contains("")), - "did not expect a duplicate realtime update, got {resumed_developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_initial_context_uses_previous_turn_settings_for_realtime_end() { - let (session, turn_context) = make_session_and_context().await; - let previous_turn_settings = PreviousTurnSettings { - model: turn_context.model_info.slug.clone(), - realtime_active: Some(true), - }; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let initial_context = session.build_initial_context(&turn_context).await; - let developer_texts = developer_input_texts(&initial_context); - assert!( - developer_texts - .iter() - .any(|text| text.contains("Reason: inactive")), - "expected initial context to describe an ended realtime session, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn build_initial_context_restates_realtime_start_when_reference_context_is_missing() { - let (session, mut turn_context) = make_session_and_context().await; - turn_context.realtime_active = true; - let previous_turn_settings = PreviousTurnSettings { - model: turn_context.model_info.slug.clone(), - realtime_active: Some(true), - }; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let initial_context = session.build_initial_context(&turn_context).await; - let developer_texts = developer_input_texts(&initial_context); - assert!( - developer_texts - .iter() - .any(|text| text.contains("")), - "expected initial context to restate active realtime when the reference context is missing, got {developer_texts:?}" - ); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_injects_full_context_when_baseline_missing() - { - let (session, turn_context) = make_session_and_context().await; - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - let history = session.clone_history().await; - let initial_context = session.build_initial_context(&turn_context).await; - assert_eq!(history.raw_items().to_vec(), initial_context); - - let current_context = session.reference_context_item().await; - assert_eq!( - serde_json::to_value(current_context).expect("serialize current context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected context item") - ); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_reinjects_full_context_after_clear() - { - let (session, turn_context) = make_session_and_context().await; - let compacted_summary = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: format!("{}\nsummary", crate::compact::SUMMARY_PREFIX), - }], - end_turn: None, - phase: None, - }; - session - .record_into_history(std::slice::from_ref(&compacted_summary), &turn_context) - .await; - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - { - let mut state = session.state.lock().await; - state.set_reference_context_item(None); - } - session - .replace_history(vec![compacted_summary.clone()], None) - .await; - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - - let history = session.clone_history().await; - let mut expected_history = vec![compacted_summary]; - expected_history.extend(session.build_initial_context(&turn_context).await); - assert_eq!(history.raw_items().to_vec(), expected_history); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_persists_baseline_without_emitting_diffs() - { - let (session, previous_context) = make_session_and_context().await; - let next_model = if previous_context.model_info.slug == "gpt-5.1" { - "gpt-5" - } else { - "gpt-5.1" - }; - let turn_context = previous_context - .with_model(next_model.to_string(), &session.services.models_manager) - .await; - let previous_context_item = previous_context.to_turn_context_item(); - { - let mut state = session.state.lock().await; - state.set_reference_context_item(Some(previous_context_item.clone())); - } - let config = session.get_config().await; - let recorder = RolloutRecorder::new( - config.as_ref(), - RolloutRecorderParams::new( - ThreadId::default(), - None, - SessionSource::Exec, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - None, - None, - ) - .await - .expect("create rollout recorder"); - let rollout_path = recorder.rollout_path().to_path_buf(); - { - let mut rollout = session.services.rollout.lock().await; - *rollout = Some(recorder); - } - - let update_items = session - .build_settings_update_items(Some(&previous_context_item), &turn_context) - .await; - assert_eq!(update_items, Vec::new()); - - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - - assert_eq!( - session.clone_history().await.raw_items().to_vec(), - Vec::new() - ); - assert_eq!( - serde_json::to_value(session.reference_context_item().await) - .expect("serialize current context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected context item") - ); - session.ensure_rollout_materialized().await; - session.flush_rollout().await; - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - let persisted_turn_context = resumed.history.iter().find_map(|item| match item { - RolloutItem::TurnContext(ctx) => Some(ctx.clone()), - _ => None, - }); - assert_eq!( - serde_json::to_value(persisted_turn_context) - .expect("serialize persisted turn context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected turn context item") - ); - } - - #[tokio::test] - async fn build_initial_context_prepends_model_switch_message() { - let (session, turn_context) = make_session_and_context().await; - let previous_turn_settings = PreviousTurnSettings { - model: "previous-regular-model".to_string(), - realtime_active: None, - }; - - session - .set_previous_turn_settings(Some(previous_turn_settings)) - .await; - let initial_context = session.build_initial_context(&turn_context).await; - - let ResponseItem::Message { role, content, .. } = &initial_context[0] else { - panic!("expected developer message"); - }; - assert_eq!(role, "developer"); - let [ContentItem::InputText { text }, ..] = content.as_slice() else { - panic!("expected developer text"); - }; - assert!(text.contains("")); - } - - #[tokio::test] - async fn record_context_updates_and_set_reference_context_item_persists_full_reinjection_to_rollout() - { - let (session, previous_context) = make_session_and_context().await; - let next_model = if previous_context.model_info.slug == "gpt-5.1" { - "gpt-5" - } else { - "gpt-5.1" - }; - let turn_context = previous_context - .with_model(next_model.to_string(), &session.services.models_manager) - .await; - let config = session.get_config().await; - let recorder = RolloutRecorder::new( - config.as_ref(), - RolloutRecorderParams::new( - ThreadId::default(), - None, - SessionSource::Exec, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - None, - None, - ) - .await - .expect("create rollout recorder"); - let rollout_path = recorder.rollout_path().to_path_buf(); - { - let mut rollout = session.services.rollout.lock().await; - *rollout = Some(recorder); - } - - session - .persist_rollout_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( - UserMessageEvent { - message: "seed rollout".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - ))]) - .await; - { - let mut state = session.state.lock().await; - state.set_reference_context_item(None); - } - - session - .set_previous_turn_settings(Some(PreviousTurnSettings { - model: previous_context.model_info.slug.clone(), - realtime_active: Some(previous_context.realtime_active), - })) - .await; - session - .record_context_updates_and_set_reference_context_item(&turn_context) - .await; - session.ensure_rollout_materialized().await; - session.flush_rollout().await; - - let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) - .await - .expect("read rollout history") - else { - panic!("expected resumed rollout history"); - }; - let persisted_turn_context = resumed.history.iter().find_map(|item| match item { - RolloutItem::TurnContext(ctx) => Some(ctx.clone()), - _ => None, - }); - - assert_eq!( - serde_json::to_value(persisted_turn_context) - .expect("serialize persisted turn context item"), - serde_json::to_value(Some(turn_context.to_turn_context_item())) - .expect("serialize expected turn context item") - ); - } - - #[tokio::test] - async fn run_user_shell_command_does_not_set_reference_context_item() { - let (session, _turn_context, rx) = make_session_and_context_with_rx().await; - { - let mut state = session.state.lock().await; - state.set_reference_context_item(None); - } - - handlers::run_user_shell_command(&session, "sub-id".to_string(), "echo shell".to_string()) - .await; - - let deadline = StdDuration::from_secs(15); - let start = std::time::Instant::now(); - loop { - let remaining = deadline.saturating_sub(start.elapsed()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - if matches!(evt.msg, EventMsg::TurnComplete(_)) { - break; - } - } - - assert!( - session.reference_context_item().await.is_none(), - "standalone shell tasks should not mutate previous context" - ); - } - - #[derive(Clone, Copy)] - struct NeverEndingTask { - kind: TaskKind, - listen_to_cancellation_token: bool, - } - - #[async_trait::async_trait] - impl SessionTask for NeverEndingTask { - fn kind(&self) -> TaskKind { - self.kind - } - - fn span_name(&self) -> &'static str { - "session_task.never_ending" - } - - async fn run( - self: Arc, - _session: Arc, - _ctx: Arc, - _input: Vec, - cancellation_token: CancellationToken, - ) -> Option { - if self.listen_to_cancellation_token { - cancellation_token.cancelled().await; - return None; - } - loop { - sleep(Duration::from_secs(60)).await; - } - } - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - #[test_log::test] - async fn abort_regular_task_emits_turn_aborted_only() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - - // Interrupts persist a model-visible `` marker into history, but there is no - // separate client-visible event for that marker (only `EventMsg::TurnAborted`). - let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), - other => panic!("unexpected event: {other:?}"), - } - // No extra events should be emitted after an abort. - assert!(rx.try_recv().is_err()); - } - - #[tokio::test] - async fn abort_gracefully_emits_turn_aborted_only() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: true, - }, - ) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - - // Even if tasks handle cancellation gracefully, interrupts still result in `TurnAborted` - // being the only client-visible signal. - let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - match evt.msg { - EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), - other => panic!("unexpected event: {other:?}"), - } - // No extra events should be emitted after an abort. - assert!(rx.try_recv().is_err()); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - while rx.try_recv().is_ok() {} - - sess.inject_response_items(vec![ResponseInputItem::Message { - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "late pending input".to_string(), - }], - }]) - .await - .expect("inject pending input into active turn"); - - sess.on_task_finished(Arc::clone(&tc), None).await; - - let history = sess.clone_history().await; - let expected = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "late pending input".to_string(), - }], - end_turn: None, - phase: None, - }; - assert!( - history.raw_items().iter().any(|item| item == &expected), - "expected pending input to be persisted into history on turn completion" - ); - - let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected raw response item event") - .expect("channel open"); - assert!(matches!(first.msg, EventMsg::RawResponseItem(_))); - - let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected item started event") - .expect("channel open"); - assert!(matches!( - second.msg, - EventMsg::ItemStarted(ItemStartedEvent { - item: TurnItem::UserMessage(UserMessageItem { content, .. }), - .. - }) if content == vec![UserInput::Text { - text: "late pending input".to_string(), - text_elements: Vec::new(), - }] - )); - - let third = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected item completed event") - .expect("channel open"); - assert!(matches!( - third.msg, - EventMsg::ItemCompleted(ItemCompletedEvent { - item: TurnItem::UserMessage(UserMessageItem { content, .. }), - .. - }) if content == vec![UserInput::Text { - text: "late pending input".to_string(), - text_elements: Vec::new(), - }] - )); - - let fourth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected legacy user message event") - .expect("channel open"); - assert!(matches!( - fourth.msg, - EventMsg::UserMessage(UserMessageEvent { - message, - images, - text_elements, - local_images, - }) if message == "late pending input" - && images == Some(Vec::new()) - && text_elements.is_empty() - && local_images.is_empty() - )); - - let fifth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) - .await - .expect("expected turn complete event") - .expect("channel open"); - assert!(matches!( - fifth.msg, - EventMsg::TurnComplete(TurnCompleteEvent { - turn_id, - last_agent_message: None, - }) if turn_id == tc.sub_id - )); - } - - #[tokio::test] - async fn steer_input_requires_active_turn() { - let (sess, _tc, _rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "steer".to_string(), - text_elements: Vec::new(), - }]; - - let err = sess - .steer_input(input, None) - .await - .expect_err("steering without active turn should fail"); - - assert!(matches!(err, SteerInputError::NoActiveTurn(_))); - } - - #[tokio::test] - async fn steer_input_enforces_expected_turn_id() { - let (sess, tc, _rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - let steer_input = vec![UserInput::Text { - text: "steer".to_string(), - text_elements: Vec::new(), - }]; - let err = sess - .steer_input(steer_input, Some("different-turn-id")) - .await - .expect_err("mismatched expected turn id should fail"); - - match err { - SteerInputError::ExpectedTurnMismatch { expected, actual } => { - assert_eq!( - (expected, actual), - ("different-turn-id".to_string(), tc.sub_id.clone()) - ); - } - other => panic!("unexpected error: {other:?}"), - } - } - - #[tokio::test] - async fn steer_input_returns_active_turn_id() { - let (sess, tc, _rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task( - Arc::clone(&tc), - input, - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: false, - }, - ) - .await; - - let steer_input = vec![UserInput::Text { - text: "steer".to_string(), - text_elements: Vec::new(), - }]; - let turn_id = sess - .steer_input(steer_input, Some(&tc.sub_id)) - .await - .expect("steering with matching expected turn id should succeed"); - - assert_eq!(turn_id, tc.sub_id); - assert!(sess.has_pending_input().await); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn abort_review_task_emits_exited_then_aborted_and_records_history() { - let (sess, tc, rx) = make_session_and_context_with_rx().await; - let input = vec![UserInput::Text { - text: "start review".to_string(), - text_elements: Vec::new(), - }]; - sess.spawn_task(Arc::clone(&tc), input, ReviewTask::new()) - .await; - - sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - - // Aborting a review task should exit review mode before surfacing the abort to the client. - // We scan for these events (rather than relying on fixed ordering) since unrelated events - // may interleave. - let mut exited_review_mode_idx = None; - let mut turn_aborted_idx = None; - let mut idx = 0usize; - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(3); - while tokio::time::Instant::now() < deadline { - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - let evt = tokio::time::timeout(remaining, rx.recv()) - .await - .expect("timeout waiting for event") - .expect("event"); - let event_idx = idx; - idx = idx.saturating_add(1); - match evt.msg { - EventMsg::ExitedReviewMode(ev) => { - assert!(ev.review_output.is_none()); - exited_review_mode_idx = Some(event_idx); - } - EventMsg::TurnAborted(ev) => { - assert_eq!(TurnAbortReason::Interrupted, ev.reason); - turn_aborted_idx = Some(event_idx); - break; - } - _ => {} - } - } - assert!( - exited_review_mode_idx.is_some(), - "expected ExitedReviewMode after abort" - ); - assert!( - turn_aborted_idx.is_some(), - "expected TurnAborted after abort" - ); - assert!( - exited_review_mode_idx.unwrap() < turn_aborted_idx.unwrap(), - "expected ExitedReviewMode before TurnAborted" - ); - - let history = sess.clone_history().await; - // The `` marker is silent in the event stream, so verify it is still - // recorded in history for the model. - assert!( - history.raw_items().iter().any(|item| { - let ResponseItem::Message { role, content, .. } = item else { - return false; - }; - if role != "user" { - return false; - } - content.iter().any(|content_item| { - let ContentItem::InputText { text } = content_item else { - return false; - }; - text.contains(crate::contextual_user_message::TURN_ABORTED_OPEN_TAG) - }) - }), - "expected a model-visible turn aborted marker in history after interrupt" - ); - } - - #[tokio::test] - async fn fatal_tool_error_stops_turn_and_reports_error() { - let (session, turn_context, _rx) = make_session_and_context_with_rx().await; - let tools = { - session - .services - .mcp_connection_manager - .read() - .await - .list_all_tools() - .await - }; - let app_tools = Some(tools.clone()); - let router = ToolRouter::from_config( - &turn_context.tools_config, - Some( - tools - .into_iter() - .map(|(name, tool)| (name, tool.tool)) - .collect(), - ), - app_tools, - turn_context.dynamic_tools.as_slice(), - ); - let item = ResponseItem::CustomToolCall { - id: None, - status: None, - call_id: "call-1".to_string(), - name: "shell".to_string(), - input: "{}".to_string(), - }; - - let call = ToolRouter::build_tool_call(session.as_ref(), item.clone()) - .await - .expect("build tool call") - .expect("tool call present"); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - let err = router - .dispatch_tool_call( - Arc::clone(&session), - Arc::clone(&turn_context), - tracker, - call, - ToolCallSource::Direct, - ) - .await - .expect_err("expected fatal error"); - - match err { - FunctionCallError::Fatal(message) => { - assert_eq!(message, "tool shell invoked with incompatible payload"); - } - other => panic!("expected FunctionCallError::Fatal, got {other:?}"), - } - } - - async fn sample_rollout( - session: &Session, - _turn_context: &TurnContext, - ) -> (Vec, Vec) { - let mut rollout_items = Vec::new(); - let mut live_history = ContextManager::new(); - - // Use the same turn_context source as record_initial_history so model_info (and thus - // personality_spec) matches reconstruction. - let reconstruction_turn = session.new_default_turn().await; - let mut initial_context = session - .build_initial_context(reconstruction_turn.as_ref()) - .await; - // Ensure personality_spec is present when Personality is enabled, so expected matches - // what reconstruction produces (build_initial_context may omit it when baked into model). - if !initial_context.iter().any(|m| { - matches!(m, ResponseItem::Message { role, content, .. } - if role == "developer" - && content.iter().any(|c| { - matches!(c, ContentItem::InputText { text } if text.contains("")) - })) - }) - && let Some(p) = reconstruction_turn.personality - && session.features.enabled(Feature::Personality) - && let Some(personality_message) = reconstruction_turn - .model_info - .model_messages - .as_ref() - .and_then(|m| m.get_personality_message(Some(p)).filter(|s| !s.is_empty())) - { - let msg = - DeveloperInstructions::personality_spec_message(personality_message).into(); - let insert_at = initial_context - .iter() - .position(|m| matches!(m, ResponseItem::Message { role, .. } if role == "developer")) - .map(|i| i + 1) - .unwrap_or(0); - initial_context.insert(insert_at, msg); - } - for item in &initial_context { - rollout_items.push(RolloutItem::ResponseItem(item.clone())); - } - live_history.record_items( - initial_context.iter(), - reconstruction_turn.truncation_policy, - ); - - let user1 = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "first user".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&user1), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(user1.clone())); - - let assistant1 = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "assistant reply one".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&assistant1), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(assistant1.clone())); - - let summary1 = "summary one"; - let snapshot1 = live_history - .clone() - .for_prompt(&reconstruction_turn.model_info.input_modalities); - let user_messages1 = collect_user_messages(&snapshot1); - let rebuilt1 = compact::build_compacted_history(Vec::new(), &user_messages1, summary1); - live_history.replace(rebuilt1); - rollout_items.push(RolloutItem::Compacted(CompactedItem { - message: summary1.to_string(), - replacement_history: None, - })); - - let user2 = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "second user".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&user2), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(user2.clone())); - - let assistant2 = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "assistant reply two".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&assistant2), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(assistant2.clone())); - - let summary2 = "summary two"; - let snapshot2 = live_history - .clone() - .for_prompt(&reconstruction_turn.model_info.input_modalities); - let user_messages2 = collect_user_messages(&snapshot2); - let rebuilt2 = compact::build_compacted_history(Vec::new(), &user_messages2, summary2); - live_history.replace(rebuilt2); - rollout_items.push(RolloutItem::Compacted(CompactedItem { - message: summary2.to_string(), - replacement_history: None, - })); - - let user3 = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "third user".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&user3), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(user3)); - - let assistant3 = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "assistant reply three".to_string(), - }], - end_turn: None, - phase: None, - }; - live_history.record_items( - std::iter::once(&assistant3), - reconstruction_turn.truncation_policy, - ); - rollout_items.push(RolloutItem::ResponseItem(assistant3)); - - ( - rollout_items, - live_history.for_prompt(&reconstruction_turn.model_info.input_modalities), - ) - } - - #[tokio::test] - async fn rejects_escalated_permissions_when_policy_not_on_request() { - use crate::exec::ExecParams; - use crate::protocol::AskForApproval; - use crate::protocol::SandboxPolicy; - use crate::sandboxing::SandboxPermissions; - use crate::turn_diff_tracker::TurnDiffTracker; - use std::collections::HashMap; - - let (session, mut turn_context_raw) = make_session_and_context().await; - // Ensure policy is NOT OnRequest so the early rejection path triggers - turn_context_raw - .approval_policy - .set(AskForApproval::OnFailure) - .expect("test setup should allow updating approval policy"); - let session = Arc::new(session); - let mut turn_context = Arc::new(turn_context_raw); - - let timeout_ms = 1000; - let sandbox_permissions = SandboxPermissions::RequireEscalated; - let params = ExecParams { - command: if cfg!(windows) { - vec![ - "cmd.exe".to_string(), - "/C".to_string(), - "echo hi".to_string(), - ] - } else { - vec![ - "/bin/sh".to_string(), - "-c".to_string(), - "echo hi".to_string(), - ] - }, - cwd: turn_context.cwd.clone(), - expiration: timeout_ms.into(), - env: HashMap::new(), - network: None, - sandbox_permissions, - windows_sandbox_level: turn_context.windows_sandbox_level, - justification: Some("test".to_string()), - arg0: None, - }; - - let params2 = ExecParams { - sandbox_permissions: SandboxPermissions::UseDefault, - command: params.command.clone(), - cwd: params.cwd.clone(), - expiration: timeout_ms.into(), - env: HashMap::new(), - network: None, - windows_sandbox_level: turn_context.windows_sandbox_level, - justification: params.justification.clone(), - arg0: None, - }; - - let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - - let tool_name = "shell"; - let call_id = "test-call".to_string(); - - let handler = ShellHandler; - let resp = handler - .handle(ToolInvocation { - session: Arc::clone(&session), - turn: Arc::clone(&turn_context), - tracker: Arc::clone(&turn_diff_tracker), - call_id, - tool_name: tool_name.to_string(), - payload: ToolPayload::Function { - arguments: serde_json::json!({ - "command": params.command.clone(), - "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), - "timeout_ms": params.expiration.timeout_ms(), - "sandbox_permissions": params.sandbox_permissions, - "justification": params.justification.clone(), - }) - .to_string(), - }, - }) - .await; - - let Err(FunctionCallError::RespondToModel(output)) = resp else { - panic!("expected error result"); - }; - - let expected = format!( - "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}", - policy = turn_context.approval_policy.value() - ); - - pretty_assertions::assert_eq!(output, expected); - - // Now retry the same command WITHOUT escalated permissions; should succeed. - // Force DangerFullAccess to avoid platform sandbox dependencies in tests. - Arc::get_mut(&mut turn_context) - .expect("unique turn context Arc") - .sandbox_policy - .set(SandboxPolicy::DangerFullAccess) - .expect("test setup should allow updating sandbox policy"); - - let resp2 = handler - .handle(ToolInvocation { - session: Arc::clone(&session), - turn: Arc::clone(&turn_context), - tracker: Arc::clone(&turn_diff_tracker), - call_id: "test-call-2".to_string(), - tool_name: tool_name.to_string(), - payload: ToolPayload::Function { - arguments: serde_json::json!({ - "command": params2.command.clone(), - "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), - "timeout_ms": params2.expiration.timeout_ms(), - "sandbox_permissions": params2.sandbox_permissions, - "justification": params2.justification.clone(), - }) - .to_string(), - }, - }) - .await; - - let output = match resp2.expect("expected Ok result") { - ToolOutput::Function { - body: FunctionCallOutputBody::Text(content), - .. - } => content, - _ => panic!("unexpected tool output"), - }; - - #[derive(Deserialize, PartialEq, Eq, Debug)] - struct ResponseExecMetadata { - exit_code: i32, - } - - #[derive(Deserialize)] - struct ResponseExecOutput { - output: String, - metadata: ResponseExecMetadata, - } - - let exec_output: ResponseExecOutput = - serde_json::from_str(&output).expect("valid exec output json"); - - pretty_assertions::assert_eq!(exec_output.metadata, ResponseExecMetadata { exit_code: 0 }); - assert!(exec_output.output.contains("hi")); - } - #[tokio::test] - async fn unified_exec_rejects_escalated_permissions_when_policy_not_on_request() { - use crate::protocol::AskForApproval; - use crate::sandboxing::SandboxPermissions; - use crate::turn_diff_tracker::TurnDiffTracker; - - let (session, mut turn_context_raw) = make_session_and_context().await; - turn_context_raw - .approval_policy - .set(AskForApproval::OnFailure) - .expect("test setup should allow updating approval policy"); - let session = Arc::new(session); - let turn_context = Arc::new(turn_context_raw); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - - let handler = UnifiedExecHandler; - let resp = handler - .handle(ToolInvocation { - session: Arc::clone(&session), - turn: Arc::clone(&turn_context), - tracker: Arc::clone(&tracker), - call_id: "exec-call".to_string(), - tool_name: "exec_command".to_string(), - payload: ToolPayload::Function { - arguments: serde_json::json!({ - "cmd": "echo hi", - "sandbox_permissions": SandboxPermissions::RequireEscalated, - "justification": "need unsandboxed execution", - }) - .to_string(), - }, - }) - .await; - - let Err(FunctionCallError::RespondToModel(output)) = resp else { - panic!("expected error result"); - }; - - let expected = format!( - "approval policy is {policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {policy:?}", - policy = turn_context.approval_policy.value() - ); - - pretty_assertions::assert_eq!(output, expected); - } -} +#[path = "codex_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs new file mode 100644 index 000000000..cb44d1cad --- /dev/null +++ b/codex-rs/core/src/codex_tests.rs @@ -0,0 +1,3939 @@ +use super::*; +use crate::CodexAuth; +use crate::config::ConfigBuilder; +use crate::config::test_config; +use crate::config_loader::ConfigLayerStack; +use crate::config_loader::ConfigLayerStackOrdering; +use crate::config_loader::NetworkConstraints; +use crate::config_loader::RequirementSource; +use crate::config_loader::Sourced; +use crate::exec::ExecToolCallOutput; +use crate::function_tool::FunctionCallError; +use crate::mcp_connection_manager::ToolInfo; +use crate::models_manager::model_info; +use crate::shell::default_user_shell; +use crate::tools::format_exec_output_str; + +use codex_protocol::ThreadId; +use codex_protocol::models::FunctionCallOutputBody; +use codex_protocol::models::FunctionCallOutputPayload; +use tracing::Span; + +use crate::protocol::CompactedItem; +use crate::protocol::CreditsSnapshot; +use crate::protocol::InitialHistory; +use crate::protocol::NetworkApprovalProtocol; +use crate::protocol::RateLimitSnapshot; +use crate::protocol::RateLimitWindow; +use crate::protocol::ResumedHistory; +use crate::protocol::TokenCountEvent; +use crate::protocol::TokenUsage; +use crate::protocol::TokenUsageInfo; +use crate::protocol::TurnCompleteEvent; +use crate::protocol::UserMessageEvent; +use crate::rollout::policy::EventPersistenceMode; +use crate::rollout::recorder::RolloutRecorder; +use crate::rollout::recorder::RolloutRecorderParams; +use crate::state::TaskKind; +use crate::tasks::SessionTask; +use crate::tasks::SessionTaskContext; +use crate::tools::ToolRouter; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::ShellHandler; +use crate::tools::handlers::UnifiedExecHandler; +use crate::tools::registry::ToolHandler; +use crate::tools::router::ToolCallSource; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_app_server_protocol::AppInfo; +use codex_otel::TelemetryAuthMode; +use codex_protocol::models::BaseInstructions; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ResponseItem; +use codex_protocol::openai_models::ModelsResponse; +use codex_protocol::protocol::ConversationAudioParams; +use codex_protocol::protocol::RealtimeAudioFrame; +use codex_protocol::protocol::Submission; +use codex_protocol::protocol::W3cTraceContext; +use opentelemetry::trace::TraceContextExt; +use opentelemetry::trace::TraceId; +use opentelemetry::trace::TracerProvider as _; +use opentelemetry_sdk::trace::SdkTracerProvider; +use std::path::Path; +use std::time::Duration; +use tokio::time::sleep; +use tracing_opentelemetry::OpenTelemetrySpanExt; +use tracing_subscriber::prelude::*; + +use codex_protocol::mcp::CallToolResult as McpCallToolResult; +use pretty_assertions::assert_eq; +use rmcp::model::JsonObject; +use rmcp::model::Tool; +use serde::Deserialize; +use serde_json::json; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Once; +use std::time::Duration as StdDuration; + +struct InstructionsTestCase { + slug: &'static str, + expects_apply_patch_instructions: bool, +} + +fn user_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } +} + +fn assistant_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } +} + +fn skill_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } +} + +fn developer_input_texts(items: &[ResponseItem]) -> Vec<&str> { + items + .iter() + .filter_map(|item| match item { + ResponseItem::Message { role, content, .. } if role == "developer" => { + Some(content.as_slice()) + } + _ => None, + }) + .flat_map(|content| content.iter()) + .filter_map(|item| match item { + ContentItem::InputText { text } => Some(text.as_str()), + _ => None, + }) + .collect() +} + +fn make_connector(id: &str, name: &str) -> AppInfo { + AppInfo { + id: id.to_string(), + name: name.to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: None, + is_accessible: true, + is_enabled: true, + plugin_display_names: Vec::new(), + } +} + +#[test] +fn assistant_message_stream_parsers_can_be_seeded_from_output_item_added_text() { + let mut parsers = AssistantMessageStreamParsers::new(false); + let item_id = "msg-1"; + + let seeded = parsers.seed_item_text(item_id, "hello doc"); + let parsed = parsers.parse_delta(item_id, "1 world"); + let tail = parsers.finish_item(item_id); + + assert_eq!(seeded.visible_text, "hello "); + assert_eq!(seeded.citations, Vec::::new()); + assert_eq!(parsed.visible_text, " world"); + assert_eq!(parsed.citations, vec!["doc1".to_string()]); + assert_eq!(tail.visible_text, ""); + assert_eq!(tail.citations, Vec::::new()); +} + +#[test] +fn assistant_message_stream_parsers_seed_buffered_prefix_stays_out_of_finish_tail() { + let mut parsers = AssistantMessageStreamParsers::new(false); + let item_id = "msg-1"; + + let seeded = parsers.seed_item_text(item_id, "hello doc world"); + let tail = parsers.finish_item(item_id); + + assert_eq!(seeded.visible_text, "hello "); + assert_eq!(seeded.citations, Vec::::new()); + assert_eq!(parsed.visible_text, " world"); + assert_eq!(parsed.citations, vec!["doc".to_string()]); + assert_eq!(tail.visible_text, ""); + assert_eq!(tail.citations, Vec::::new()); +} + +#[test] +fn assistant_message_stream_parsers_seed_plan_parser_across_added_and_delta_boundaries() { + let mut parsers = AssistantMessageStreamParsers::new(true); + let item_id = "msg-1"; + + let seeded = parsers.seed_item_text(item_id, "Intro\n\n- step\n\nOutro"); + let tail = parsers.finish_item(item_id); + + assert_eq!(seeded.visible_text, "Intro\n"); + assert_eq!( + seeded.plan_segments, + vec![ProposedPlanSegment::Normal("Intro\n".to_string())] + ); + assert_eq!(parsed.visible_text, "Outro"); + assert_eq!( + parsed.plan_segments, + vec![ + ProposedPlanSegment::ProposedPlanStart, + ProposedPlanSegment::ProposedPlanDelta("- step\n".to_string()), + ProposedPlanSegment::ProposedPlanEnd, + ProposedPlanSegment::Normal("Outro".to_string()), + ] + ); + assert_eq!(tail.visible_text, ""); + assert!(tail.plan_segments.is_empty()); +} + +fn make_mcp_tool( + server_name: &str, + tool_name: &str, + connector_id: Option<&str>, + connector_name: Option<&str>, +) -> ToolInfo { + ToolInfo { + server_name: server_name.to_string(), + tool_name: tool_name.to_string(), + tool: Tool { + name: tool_name.to_string().into(), + title: None, + description: Some(format!("Test tool: {tool_name}").into()), + input_schema: Arc::new(JsonObject::default()), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: connector_id.map(str::to_string), + connector_name: connector_name.map(str::to_string), + plugin_display_names: Vec::new(), + } +} + +fn function_call_rollout_item(name: &str, call_id: &str) -> RolloutItem { + RolloutItem::ResponseItem(ResponseItem::FunctionCall { + id: None, + name: name.to_string(), + arguments: "{}".to_string(), + call_id: call_id.to_string(), + }) +} + +fn function_call_output_rollout_item(call_id: &str, output: &str) -> RolloutItem { + RolloutItem::ResponseItem(ResponseItem::FunctionCallOutput { + call_id: call_id.to_string(), + output: FunctionCallOutputPayload::from_text(output.to_string()), + }) +} + +#[test] +fn validated_network_policy_amendment_host_allows_normalized_match() { + let amendment = NetworkPolicyAmendment { + host: "ExAmPlE.Com.:443".to_string(), + action: NetworkPolicyRuleAction::Allow, + }; + let context = NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Https, + }; + + let host = Session::validated_network_policy_amendment_host(&amendment, &context) + .expect("normalized hosts should match"); + + assert_eq!(host, "example.com"); +} + +#[test] +fn validated_network_policy_amendment_host_rejects_mismatch() { + let amendment = NetworkPolicyAmendment { + host: "evil.example.com".to_string(), + action: NetworkPolicyRuleAction::Deny, + }; + let context = NetworkApprovalContext { + host: "api.example.com".to_string(), + protocol: NetworkApprovalProtocol::Https, + }; + + let err = Session::validated_network_policy_amendment_host(&amendment, &context) + .expect_err("mismatched hosts should be rejected"); + + let message = err.to_string(); + assert!(message.contains("does not match approved host")); +} + +#[tokio::test] +async fn get_base_instructions_no_user_content() { + let prompt_with_apply_patch_instructions = + include_str!("../prompt_with_apply_patch_instructions.md"); + let models_response: ModelsResponse = + serde_json::from_str(include_str!("../models.json")).expect("valid models.json"); + let model_info_for_slug = |slug: &str, config: &Config| { + let model = models_response + .models + .iter() + .find(|candidate| candidate.slug == slug) + .cloned() + .unwrap_or_else(|| panic!("model slug {slug} is missing from models.json")); + model_info::with_config_overrides(model, config) + }; + let test_cases = vec![ + InstructionsTestCase { + slug: "gpt-5", + expects_apply_patch_instructions: false, + }, + InstructionsTestCase { + slug: "gpt-5.1", + expects_apply_patch_instructions: false, + }, + InstructionsTestCase { + slug: "gpt-5.1-codex", + expects_apply_patch_instructions: false, + }, + InstructionsTestCase { + slug: "gpt-5.1-codex-max", + expects_apply_patch_instructions: false, + }, + ]; + + let (session, _turn_context) = make_session_and_context().await; + let config = test_config(); + + for test_case in test_cases { + let model_info = model_info_for_slug(test_case.slug, &config); + if test_case.expects_apply_patch_instructions { + assert_eq!( + model_info.base_instructions.as_str(), + prompt_with_apply_patch_instructions + ); + } + + { + let mut state = session.state.lock().await; + state.session_configuration.base_instructions = model_info.base_instructions.clone(); + } + + let base_instructions = session.get_base_instructions().await; + assert_eq!(base_instructions.text, model_info.base_instructions); + } +} + +#[tokio::test] +async fn reload_user_config_layer_updates_effective_apps_config() { + let (session, _turn_context) = make_session_and_context().await; + let codex_home = session.codex_home().await; + std::fs::create_dir_all(&codex_home).expect("create codex home"); + let config_toml_path = codex_home.join(CONFIG_TOML_FILE); + std::fs::write( + &config_toml_path, + "[apps.calendar]\nenabled = false\ndestructive_enabled = false\n", + ) + .expect("write user config"); + + session.reload_user_config_layer().await; + + let config = session.get_config().await; + let apps_toml = config + .config_layer_stack + .effective_config() + .as_table() + .and_then(|table| table.get("apps")) + .cloned() + .expect("apps table"); + let apps = crate::config::types::AppsConfigToml::deserialize(apps_toml) + .expect("deserialize apps config"); + let app = apps + .apps + .get("calendar") + .expect("calendar app config exists"); + + assert!(!app.enabled); + assert_eq!(app.destructive_enabled, Some(false)); +} + +#[test] +fn filter_connectors_for_input_skips_duplicate_slug_mentions() { + let connectors = vec![ + make_connector("one", "Foo Bar"), + make_connector("two", "Foo-Bar"), + ]; + let input = vec![user_message("use $foo-bar")]; + let explicitly_enabled_connectors = HashSet::new(); + let skill_name_counts_lower = HashMap::new(); + + let selected = filter_connectors_for_input( + &connectors, + &input, + &explicitly_enabled_connectors, + &skill_name_counts_lower, + ); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn filter_connectors_for_input_skips_when_skill_name_conflicts() { + let connectors = vec![make_connector("one", "Todoist")]; + let input = vec![user_message("use $todoist")]; + let explicitly_enabled_connectors = HashSet::new(); + let skill_name_counts_lower = HashMap::from([("todoist".to_string(), 1)]); + + let selected = filter_connectors_for_input( + &connectors, + &input, + &explicitly_enabled_connectors, + &skill_name_counts_lower, + ); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn filter_connectors_for_input_skips_disabled_connectors() { + let mut connector = make_connector("calendar", "Calendar"); + connector.is_enabled = false; + let input = vec![user_message("use $calendar")]; + let explicitly_enabled_connectors = HashSet::new(); + let selected = filter_connectors_for_input( + &[connector], + &input, + &explicitly_enabled_connectors, + &HashMap::new(), + ); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn collect_explicit_app_ids_from_skill_items_includes_linked_mentions() { + let connectors = vec![make_connector("calendar", "Calendar")]; + let skill_items = vec![skill_message( + "\ndemo\n/tmp/skills/demo/SKILL.md\nuse [$calendar](app://calendar)\n", + )]; + + let connector_ids = + collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); + + assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); +} + +#[test] +fn collect_explicit_app_ids_from_skill_items_resolves_unambiguous_plain_mentions() { + let connectors = vec![make_connector("calendar", "Calendar")]; + let skill_items = vec![skill_message( + "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", + )]; + + let connector_ids = + collect_explicit_app_ids_from_skill_items(&skill_items, &connectors, &HashMap::new()); + + assert_eq!(connector_ids, HashSet::from(["calendar".to_string()])); +} + +#[test] +fn collect_explicit_app_ids_from_skill_items_skips_plain_mentions_with_skill_conflicts() { + let connectors = vec![make_connector("calendar", "Calendar")]; + let skill_items = vec![skill_message( + "\ndemo\n/tmp/skills/demo/SKILL.md\nuse $calendar\n", + )]; + let skill_name_counts_lower = HashMap::from([("calendar".to_string(), 1)]); + + let connector_ids = collect_explicit_app_ids_from_skill_items( + &skill_items, + &connectors, + &skill_name_counts_lower, + ); + + assert_eq!(connector_ids, HashSet::::new()); +} + +#[test] +fn non_app_mcp_tools_remain_visible_without_search_selection() { + let mcp_tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_create_event".to_string(), + make_mcp_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + Some("calendar"), + Some("Calendar"), + ), + ), + ( + "mcp__rmcp__echo".to_string(), + make_mcp_tool("rmcp", "echo", None, None), + ), + ]); + + let mut selected_mcp_tools = mcp_tools + .iter() + .filter(|(_, tool)| tool.server_name != CODEX_APPS_MCP_SERVER_NAME) + .map(|(name, tool)| (name.clone(), tool.clone())) + .collect::>(); + + let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); + let explicitly_enabled_connectors = HashSet::new(); + let connectors = filter_connectors_for_input( + &connectors, + &[user_message("run echo")], + &explicitly_enabled_connectors, + &HashMap::new(), + ); + let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); + selected_mcp_tools.extend(apps_mcp_tools); + + let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); + tool_names.sort(); + assert_eq!(tool_names, vec!["mcp__rmcp__echo".to_string()]); +} + +#[test] +fn search_tool_selection_keeps_codex_apps_tools_without_mentions() { + let selected_tool_names = vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__rmcp__echo".to_string(), + ]; + let mcp_tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_create_event".to_string(), + make_mcp_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + Some("calendar"), + Some("Calendar"), + ), + ), + ( + "mcp__rmcp__echo".to_string(), + make_mcp_tool("rmcp", "echo", None, None), + ), + ]); + + let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); + let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); + let explicitly_enabled_connectors = HashSet::new(); + let connectors = filter_connectors_for_input( + &connectors, + &[user_message("run the selected tools")], + &explicitly_enabled_connectors, + &HashMap::new(), + ); + let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); + selected_mcp_tools.extend(apps_mcp_tools); + + let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); + tool_names.sort(); + assert_eq!( + tool_names, + vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__rmcp__echo".to_string(), + ] + ); +} + +#[test] +fn apps_mentions_add_codex_apps_tools_to_search_selected_set() { + let selected_tool_names = vec!["mcp__rmcp__echo".to_string()]; + let mcp_tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_create_event".to_string(), + make_mcp_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + Some("calendar"), + Some("Calendar"), + ), + ), + ( + "mcp__rmcp__echo".to_string(), + make_mcp_tool("rmcp", "echo", None, None), + ), + ]); + + let mut selected_mcp_tools = filter_mcp_tools_by_name(&mcp_tools, &selected_tool_names); + let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools); + let explicitly_enabled_connectors = HashSet::new(); + let connectors = filter_connectors_for_input( + &connectors, + &[user_message("use $calendar and then echo the response")], + &explicitly_enabled_connectors, + &HashMap::new(), + ); + let apps_mcp_tools = filter_codex_apps_mcp_tools_only(&mcp_tools, &connectors); + selected_mcp_tools.extend(apps_mcp_tools); + + let mut tool_names: Vec = selected_mcp_tools.into_keys().collect(); + tool_names.sort(); + assert_eq!( + tool_names, + vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__rmcp__echo".to_string(), + ] + ); +} + +#[test] +fn extract_mcp_tool_selection_from_rollout_reads_search_tool_output() { + let rollout_items = vec![ + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), + function_call_output_rollout_item( + "search-1", + &json!({ + "active_selected_tools": [ + "mcp__codex_apps__calendar_create_event", + "mcp__codex_apps__calendar_list_events", + ], + }) + .to_string(), + ), + ]; + + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!( + selected, + Some(vec![ + "mcp__codex_apps__calendar_create_event".to_string(), + "mcp__codex_apps__calendar_list_events".to_string(), + ]) + ); +} + +#[test] +fn extract_mcp_tool_selection_from_rollout_latest_valid_payload_wins() { + let rollout_items = vec![ + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), + function_call_output_rollout_item( + "search-1", + &json!({ + "active_selected_tools": ["mcp__codex_apps__calendar_create_event"], + }) + .to_string(), + ), + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-2"), + function_call_output_rollout_item( + "search-2", + &json!({ + "active_selected_tools": ["mcp__codex_apps__calendar_delete_event"], + }) + .to_string(), + ), + ]; + + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!( + selected, + Some(vec!["mcp__codex_apps__calendar_delete_event".to_string(),]) + ); +} + +#[test] +fn extract_mcp_tool_selection_from_rollout_ignores_non_search_and_malformed_payloads() { + let rollout_items = vec![ + function_call_rollout_item("shell", "shell-1"), + function_call_output_rollout_item( + "shell-1", + &json!({ + "active_selected_tools": ["mcp__codex_apps__should_be_ignored"], + }) + .to_string(), + ), + function_call_rollout_item(SEARCH_TOOL_BM25_TOOL_NAME, "search-1"), + function_call_output_rollout_item("search-1", "{not-json"), + function_call_output_rollout_item( + "unknown-search-call", + &json!({ + "active_selected_tools": ["mcp__codex_apps__also_ignored"], + }) + .to_string(), + ), + function_call_output_rollout_item( + "search-1", + &json!({ + "active_selected_tools": ["mcp__codex_apps__calendar_list_events"], + }) + .to_string(), + ), + ]; + + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!( + selected, + Some(vec!["mcp__codex_apps__calendar_list_events".to_string(),]) + ); +} + +#[test] +fn extract_mcp_tool_selection_from_rollout_returns_none_without_valid_search_output() { + let rollout_items = vec![function_call_rollout_item( + SEARCH_TOOL_BM25_TOOL_NAME, + "search-1", + )]; + let selected = Session::extract_mcp_tool_selection_from_rollout(&rollout_items); + assert_eq!(selected, None); +} + +#[tokio::test] +async fn reconstruct_history_matches_live_compactions() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; + + let reconstruction_turn = session.new_default_turn().await; + let reconstructed = session + .reconstruct_history_from_rollout(reconstruction_turn.as_ref(), &rollout_items) + .await; + + assert_eq!(expected, reconstructed.history); +} + +#[tokio::test] +async fn reconstruct_history_uses_replacement_history_verbatim() { + let (session, turn_context) = make_session_and_context().await; + let summary_item = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }; + let replacement_history = vec![ + summary_item.clone(), + ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "stale developer instructions".to_string(), + }], + end_turn: None, + phase: None, + }, + ]; + let rollout_items = vec![RolloutItem::Compacted(CompactedItem { + message: String::new(), + replacement_history: Some(replacement_history.clone()), + })]; + + let reconstructed = session + .reconstruct_history_from_rollout(&turn_context, &rollout_items) + .await; + + assert_eq!(reconstructed.history, replacement_history); +} + +#[tokio::test] +async fn record_initial_history_reconstructs_resumed_transcript() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, expected) = sample_rollout(&session, &turn_context).await; + + session + .record_initial_history(InitialHistory::Resumed(ResumedHistory { + conversation_id: ThreadId::default(), + history: rollout_items, + rollout_path: PathBuf::from("/tmp/resume.jsonl"), + })) + .await; + + let history = session.state.lock().await.clone_history(); + assert_eq!(expected, history.raw_items()); +} + +#[tokio::test] +async fn resumed_history_injects_initial_context_on_first_context_update_only() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; + + session + .record_initial_history(InitialHistory::Resumed(ResumedHistory { + conversation_id: ThreadId::default(), + history: rollout_items, + rollout_path: PathBuf::from("/tmp/resume.jsonl"), + })) + .await; + + let history_before_seed = session.state.lock().await.clone_history(); + assert_eq!(expected, history_before_seed.raw_items()); + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + expected.extend(session.build_initial_context(&turn_context).await); + let history_after_seed = session.clone_history().await; + assert_eq!(expected, history_after_seed.raw_items()); + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + let history_after_second_seed = session.clone_history().await; + assert_eq!( + history_after_seed.raw_items(), + history_after_second_seed.raw_items() + ); +} + +#[tokio::test] +async fn record_initial_history_seeds_token_info_from_rollout() { + let (session, turn_context) = make_session_and_context().await; + let (mut rollout_items, _expected) = sample_rollout(&session, &turn_context).await; + + let info1 = TokenUsageInfo { + total_token_usage: TokenUsage { + input_tokens: 10, + cached_input_tokens: 0, + output_tokens: 20, + reasoning_output_tokens: 0, + total_tokens: 30, + }, + last_token_usage: TokenUsage { + input_tokens: 3, + cached_input_tokens: 0, + output_tokens: 4, + reasoning_output_tokens: 0, + total_tokens: 7, + }, + model_context_window: Some(1_000), + }; + let info2 = TokenUsageInfo { + total_token_usage: TokenUsage { + input_tokens: 100, + cached_input_tokens: 50, + output_tokens: 200, + reasoning_output_tokens: 25, + total_tokens: 375, + }, + last_token_usage: TokenUsage { + input_tokens: 10, + cached_input_tokens: 0, + output_tokens: 20, + reasoning_output_tokens: 5, + total_tokens: 35, + }, + model_context_window: Some(2_000), + }; + + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: Some(info1), + rate_limits: None, + }, + ))); + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: None, + rate_limits: None, + }, + ))); + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: Some(info2.clone()), + rate_limits: None, + }, + ))); + rollout_items.push(RolloutItem::EventMsg(EventMsg::TokenCount( + TokenCountEvent { + info: None, + rate_limits: None, + }, + ))); + + session + .record_initial_history(InitialHistory::Resumed(ResumedHistory { + conversation_id: ThreadId::default(), + history: rollout_items, + rollout_path: PathBuf::from("/tmp/resume.jsonl"), + })) + .await; + + let actual = session.state.lock().await.token_info(); + assert_eq!(actual, Some(info2)); +} + +#[tokio::test] +async fn recompute_token_usage_uses_session_base_instructions() { + let (session, turn_context) = make_session_and_context().await; + + let override_instructions = "SESSION_OVERRIDE_INSTRUCTIONS_ONLY".repeat(120); + { + let mut state = session.state.lock().await; + state.session_configuration.base_instructions = override_instructions.clone(); + } + + let item = user_message("hello"); + session + .record_into_history(std::slice::from_ref(&item), &turn_context) + .await; + + let history = session.clone_history().await; + let session_base_instructions = BaseInstructions { + text: override_instructions, + }; + let expected_tokens = history + .estimate_token_count_with_base_instructions(&session_base_instructions) + .expect("estimate with session base instructions"); + let model_estimated_tokens = history + .estimate_token_count(&turn_context) + .expect("estimate with model instructions"); + assert_ne!(expected_tokens, model_estimated_tokens); + + session.recompute_token_usage(&turn_context).await; + + let actual_tokens = session + .state + .lock() + .await + .token_info() + .expect("token info") + .last_token_usage + .total_tokens; + assert_eq!(actual_tokens, expected_tokens.max(0)); +} + +#[tokio::test] +async fn recompute_token_usage_updates_model_context_window() { + let (session, mut turn_context) = make_session_and_context().await; + + { + let mut state = session.state.lock().await; + state.set_token_info(Some(TokenUsageInfo { + total_token_usage: TokenUsage::default(), + last_token_usage: TokenUsage::default(), + model_context_window: Some(258_400), + })); + } + + turn_context.model_info.context_window = Some(128_000); + turn_context.model_info.effective_context_window_percent = 100; + + session.recompute_token_usage(&turn_context).await; + + let actual = session.state.lock().await.token_info().expect("token info"); + assert_eq!(actual.model_context_window, Some(128_000)); +} + +#[tokio::test] +async fn record_initial_history_reconstructs_forked_transcript() { + let (session, turn_context) = make_session_and_context().await; + let (rollout_items, mut expected) = sample_rollout(&session, &turn_context).await; + + session + .record_initial_history(InitialHistory::Forked(rollout_items)) + .await; + + let reconstruction_turn = session.new_default_turn().await; + expected.extend( + session + .build_initial_context(reconstruction_turn.as_ref()) + .await, + ); + let history = session.state.lock().await.clone_history(); + assert_eq!(expected, history.raw_items()); +} + +#[tokio::test] +async fn record_initial_history_forked_hydrates_previous_turn_settings() { + let (session, turn_context) = make_session_and_context().await; + let previous_model = "forked-rollout-model"; + let previous_context_item = TurnContextItem { + turn_id: Some(turn_context.sub_id.clone()), + trace_id: turn_context.trace_id.clone(), + cwd: turn_context.cwd.clone(), + current_date: turn_context.current_date.clone(), + timezone: turn_context.timezone.clone(), + approval_policy: turn_context.approval_policy.value(), + sandbox_policy: turn_context.sandbox_policy.get().clone(), + network: None, + model: previous_model.to_string(), + personality: turn_context.personality, + collaboration_mode: Some(turn_context.collaboration_mode.clone()), + realtime_active: Some(turn_context.realtime_active), + effort: turn_context.reasoning_effort, + summary: turn_context.reasoning_summary, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: Some(turn_context.truncation_policy.into()), + }; + let turn_id = previous_context_item + .turn_id + .clone() + .expect("turn context should have turn_id"); + let rollout_items = vec![ + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage( + codex_protocol::protocol::UserMessageEvent { + message: "forked seed".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + )), + RolloutItem::TurnContext(previous_context_item), + RolloutItem::EventMsg(EventMsg::TurnComplete( + codex_protocol::protocol::TurnCompleteEvent { + turn_id, + last_agent_message: None, + }, + )), + ]; + + session + .record_initial_history(InitialHistory::Forked(rollout_items)) + .await; + + assert_eq!( + session.previous_turn_settings().await, + Some(PreviousTurnSettings { + model: previous_model.to_string(), + realtime_active: Some(turn_context.realtime_active), + }) + ); +} + +#[tokio::test] +async fn thread_rollback_drops_last_turn_from_history() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let rollout_path = attach_rollout_recorder(&sess).await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + let turn_1 = vec![ + user_message("turn 1 user"), + assistant_message("turn 1 assistant"), + ]; + let turn_2 = vec![ + user_message("turn 2 user"), + assistant_message("turn 2 assistant"), + ]; + let mut full_history = Vec::new(); + full_history.extend(initial_context.clone()); + full_history.extend(turn_1.clone()); + full_history.extend(turn_2); + sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) + .await; + let rollout_items: Vec = full_history + .into_iter() + .map(RolloutItem::ResponseItem) + .collect(); + sess.persist_rollout_items(&rollout_items).await; + sess.set_previous_turn_settings(Some(PreviousTurnSettings { + model: "stale-model".to_string(), + realtime_active: Some(tc.realtime_active), + })) + .await; + { + let mut state = sess.state.lock().await; + state.set_reference_context_item(Some(tc.to_turn_context_item())); + } + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 1); + + let mut expected = Vec::new(); + expected.extend(initial_context); + expected.extend(turn_1); + + let history = sess.clone_history().await; + assert_eq!(expected, history.raw_items()); + assert_eq!(sess.previous_turn_settings().await, None); + assert!(sess.reference_context_item().await.is_none()); + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + assert!(resumed.history.iter().any(|item| { + matches!( + item, + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) + if rollback.num_turns == 1 + ) + })); +} + +#[tokio::test] +async fn thread_rollback_clears_history_when_num_turns_exceeds_existing_turns() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + attach_rollout_recorder(&sess).await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + let turn_1 = vec![user_message("turn 1 user")]; + let mut full_history = Vec::new(); + full_history.extend(initial_context.clone()); + full_history.extend(turn_1); + sess.replace_history(full_history.clone(), Some(tc.to_turn_context_item())) + .await; + let rollout_items: Vec = full_history + .into_iter() + .map(RolloutItem::ResponseItem) + .collect(); + sess.persist_rollout_items(&rollout_items).await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 99).await; + + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 99); + + let history = sess.clone_history().await; + assert_eq!(initial_context, history.raw_items()); +} + +#[tokio::test] +async fn thread_rollback_fails_without_persisted_rollout_path() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!( + error_event.message, + "thread rollback requires a persisted rollout path" + ); + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + assert_eq!(sess.clone_history().await.raw_items(), initial_context); +} + +#[tokio::test] +async fn thread_rollback_recomputes_previous_turn_settings_and_reference_context_from_replay() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + attach_rollout_recorder(&sess).await; + + let first_context_item = tc.to_turn_context_item(); + let first_turn_id = first_context_item + .turn_id + .clone() + .expect("turn context should have turn_id"); + let mut rolled_back_context_item = first_context_item.clone(); + rolled_back_context_item.turn_id = Some("rolled-back-turn".to_string()); + rolled_back_context_item.model = "rolled-back-model".to_string(); + let rolled_back_turn_id = rolled_back_context_item + .turn_id + .clone() + .expect("turn context should have turn_id"); + let turn_one_user = user_message("turn 1 user"); + let turn_one_assistant = assistant_message("turn 1 assistant"); + let turn_two_user = user_message("turn 2 user"); + let turn_two_assistant = assistant_message("turn 2 assistant"); + + sess.persist_rollout_items(&[ + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: first_turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage( + codex_protocol::protocol::UserMessageEvent { + message: "turn 1 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + )), + RolloutItem::TurnContext(first_context_item.clone()), + RolloutItem::ResponseItem(turn_one_user.clone()), + RolloutItem::ResponseItem(turn_one_assistant.clone()), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: first_turn_id, + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: rolled_back_turn_id.clone(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage( + codex_protocol::protocol::UserMessageEvent { + message: "turn 2 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + )), + RolloutItem::TurnContext(rolled_back_context_item), + RolloutItem::ResponseItem(turn_two_user), + RolloutItem::ResponseItem(turn_two_assistant), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: rolled_back_turn_id, + last_agent_message: None, + })), + ]) + .await; + sess.replace_history( + vec![assistant_message("stale history")], + Some(first_context_item.clone()), + ) + .await; + sess.set_previous_turn_settings(Some(PreviousTurnSettings { + model: "stale-model".to_string(), + realtime_active: None, + })) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let rollback_event = wait_for_thread_rolled_back(&rx).await; + assert_eq!(rollback_event.num_turns, 1); + + assert_eq!( + sess.clone_history().await.raw_items(), + vec![turn_one_user, turn_one_assistant] + ); + assert_eq!( + sess.previous_turn_settings().await, + Some(PreviousTurnSettings { + model: tc.model_info.slug.clone(), + realtime_active: Some(tc.realtime_active), + }) + ); + assert_eq!( + serde_json::to_value(sess.reference_context_item().await) + .expect("serialize replay reference context item"), + serde_json::to_value(Some(first_context_item)) + .expect("serialize expected reference context item") + ); +} + +#[tokio::test] +async fn thread_rollback_persists_marker_and_replays_cumulatively() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let rollout_path = attach_rollout_recorder(&sess).await; + let turn_context_item = tc.to_turn_context_item(); + + sess.persist_rollout_items(&[ + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: "turn-1".to_string(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 1 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(turn_context_item.clone()), + RolloutItem::ResponseItem(user_message("turn 1 user")), + RolloutItem::ResponseItem(assistant_message("turn 1 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-1".to_string(), + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: "turn-2".to_string(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 2 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(turn_context_item.clone()), + RolloutItem::ResponseItem(user_message("turn 2 user")), + RolloutItem::ResponseItem(assistant_message("turn 2 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-2".to_string(), + last_agent_message: None, + })), + RolloutItem::EventMsg(EventMsg::TurnStarted( + codex_protocol::protocol::TurnStartedEvent { + turn_id: "turn-3".to_string(), + model_context_window: Some(128_000), + collaboration_mode_kind: ModeKind::Default, + }, + )), + RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "turn 3 user".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + RolloutItem::TurnContext(turn_context_item), + RolloutItem::ResponseItem(user_message("turn 3 user")), + RolloutItem::ResponseItem(assistant_message("turn 3 assistant")), + RolloutItem::EventMsg(EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-3".to_string(), + last_agent_message: None, + })), + ]) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let first_rollback = wait_for_thread_rolled_back(&rx).await; + assert_eq!(first_rollback.num_turns, 1); + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + let second_rollback = wait_for_thread_rolled_back(&rx).await; + assert_eq!(second_rollback.num_turns, 1); + + assert_eq!( + sess.clone_history().await.raw_items(), + vec![ + user_message("turn 1 user"), + assistant_message("turn 1 assistant") + ] + ); + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + let rollback_markers = resumed + .history + .iter() + .filter(|item| matches!(item, RolloutItem::EventMsg(EventMsg::ThreadRolledBack(_)))) + .count(); + assert_eq!(rollback_markers, 2); +} + +#[tokio::test] +async fn thread_rollback_fails_when_turn_in_progress() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + *sess.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + + let history = sess.clone_history().await; + assert_eq!(initial_context, history.raw_items()); +} + +#[tokio::test] +async fn thread_rollback_fails_when_num_turns_is_zero() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + + let initial_context = sess.build_initial_context(tc.as_ref()).await; + sess.record_into_history(&initial_context, tc.as_ref()) + .await; + + handlers::thread_rollback(&sess, "sub-1".to_string(), 0).await; + + let error_event = wait_for_thread_rollback_failed(&rx).await; + assert_eq!(error_event.message, "num_turns must be >= 1"); + assert_eq!( + error_event.codex_error_info, + Some(CodexErrorInfo::ThreadRollbackFailed) + ); + + let history = sess.clone_history().await; + assert_eq!(initial_context, history.raw_items()); +} + +#[tokio::test] +async fn set_rate_limits_retains_previous_credits() { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + + let mut state = SessionState::new(session_configuration); + let initial = RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 10.0, + window_minutes: Some(15), + resets_at: Some(1_700), + }), + secondary: None, + credits: Some(CreditsSnapshot { + has_credits: true, + unlimited: false, + balance: Some("10.00".to_string()), + }), + plan_type: Some(codex_protocol::account::PlanType::Plus), + }; + state.set_rate_limits(initial.clone()); + + let update = RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: Some("codex_other".to_string()), + primary: Some(RateLimitWindow { + used_percent: 40.0, + window_minutes: Some(30), + resets_at: Some(1_800), + }), + secondary: Some(RateLimitWindow { + used_percent: 5.0, + window_minutes: Some(60), + resets_at: Some(1_900), + }), + credits: None, + plan_type: None, + }; + state.set_rate_limits(update.clone()); + + assert_eq!( + state.latest_rate_limits, + Some(RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: Some("codex_other".to_string()), + primary: update.primary.clone(), + secondary: update.secondary, + credits: initial.credits, + plan_type: initial.plan_type, + }) + ); +} + +#[tokio::test] +async fn set_rate_limits_updates_plan_type_when_present() { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + + let mut state = SessionState::new(session_configuration); + let initial = RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 15.0, + window_minutes: Some(20), + resets_at: Some(1_600), + }), + secondary: Some(RateLimitWindow { + used_percent: 5.0, + window_minutes: Some(45), + resets_at: Some(1_650), + }), + credits: Some(CreditsSnapshot { + has_credits: true, + unlimited: false, + balance: Some("15.00".to_string()), + }), + plan_type: Some(codex_protocol::account::PlanType::Plus), + }; + state.set_rate_limits(initial.clone()); + + let update = RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 35.0, + window_minutes: Some(25), + resets_at: Some(1_700), + }), + secondary: None, + credits: None, + plan_type: Some(codex_protocol::account::PlanType::Pro), + }; + state.set_rate_limits(update.clone()); + + assert_eq!( + state.latest_rate_limits, + Some(RateLimitSnapshot { + limit_id: Some("codex".to_string()), + limit_name: None, + primary: update.primary, + secondary: update.secondary, + credits: initial.credits, + plan_type: update.plan_type, + }) + ); +} + +#[test] +fn prefers_structured_content_when_present() { + let ctr = McpCallToolResult { + // Content present but should be ignored because structured_content is set. + content: vec![text_block("ignored")], + is_error: None, + structured_content: Some(json!({ + "ok": true, + "value": 42 + })), + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&json!({ + "ok": true, + "value": 42 + })) + .unwrap(), + ), + success: Some(true), + }; + + assert_eq!(expected, got); +} + +#[tokio::test] +async fn includes_timed_out_message() { + let exec = ExecToolCallOutput { + exit_code: 0, + stdout: StreamOutput::new(String::new()), + stderr: StreamOutput::new(String::new()), + aggregated_output: StreamOutput::new("Command output".to_string()), + duration: StdDuration::from_secs(1), + timed_out: true, + }; + let (_, turn_context) = make_session_and_context().await; + + let out = format_exec_output_str(&exec, turn_context.truncation_policy); + + assert_eq!( + out, + "command timed out after 1000 milliseconds\nCommand output" + ); +} + +#[tokio::test] +async fn turn_context_with_model_updates_model_fields() { + let (session, mut turn_context) = make_session_and_context().await; + turn_context.reasoning_effort = Some(ReasoningEffortConfig::Minimal); + let updated = turn_context + .with_model("gpt-5.1".to_string(), &session.services.models_manager) + .await; + let expected_model_info = session + .services + .models_manager + .get_model_info("gpt-5.1", updated.config.as_ref()) + .await; + + assert_eq!(updated.config.model.as_deref(), Some("gpt-5.1")); + assert_eq!(updated.collaboration_mode.model(), "gpt-5.1"); + assert_eq!(updated.model_info, expected_model_info); + assert_eq!( + updated.reasoning_effort, + Some(ReasoningEffortConfig::Medium) + ); + assert_eq!( + updated.collaboration_mode.reasoning_effort(), + Some(ReasoningEffortConfig::Medium) + ); + assert_eq!( + updated.config.model_reasoning_effort, + Some(ReasoningEffortConfig::Medium) + ); + assert_eq!( + updated.truncation_policy, + expected_model_info.truncation_policy.into() + ); + assert!(!Arc::ptr_eq( + &updated.tool_call_gate, + &turn_context.tool_call_gate + )); +} + +#[test] +fn falls_back_to_content_when_structured_is_null() { + let ctr = McpCallToolResult { + content: vec![text_block("hello"), text_block("world")], + is_error: None, + structured_content: Some(serde_json::Value::Null), + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&vec![text_block("hello"), text_block("world")]).unwrap(), + ), + success: Some(true), + }; + + assert_eq!(expected, got); +} + +#[test] +fn success_flag_reflects_is_error_true() { + let ctr = McpCallToolResult { + content: vec![text_block("unused")], + is_error: Some(true), + structured_content: Some(json!({ "message": "bad" })), + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&json!({ "message": "bad" })).unwrap(), + ), + success: Some(false), + }; + + assert_eq!(expected, got); +} + +#[test] +fn success_flag_true_with_no_error_and_content_used() { + let ctr = McpCallToolResult { + content: vec![text_block("alpha")], + is_error: Some(false), + structured_content: None, + meta: None, + }; + + let got = FunctionCallOutputPayload::from(&ctr); + let expected = FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text( + serde_json::to_string(&vec![text_block("alpha")]).unwrap(), + ), + success: Some(true), + }; + + assert_eq!(expected, got); +} + +async fn wait_for_thread_rolled_back( + rx: &async_channel::Receiver, +) -> crate::protocol::ThreadRolledBackEvent { + let deadline = StdDuration::from_secs(2); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::ThreadRolledBack(payload) => return payload, + _ => continue, + } + } +} + +async fn wait_for_thread_rollback_failed(rx: &async_channel::Receiver) -> ErrorEvent { + let deadline = StdDuration::from_secs(2); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::Error(payload) + if payload.codex_error_info == Some(CodexErrorInfo::ThreadRollbackFailed) => + { + return payload; + } + _ => continue, + } + } +} + +async fn attach_rollout_recorder(session: &Arc) -> PathBuf { + let config = session.get_config().await; + let recorder = RolloutRecorder::new( + config.as_ref(), + RolloutRecorderParams::new( + ThreadId::default(), + None, + SessionSource::Exec, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + None, + None, + ) + .await + .expect("create rollout recorder"); + let rollout_path = recorder.rollout_path().to_path_buf(); + { + let mut rollout = session.services.rollout.lock().await; + *rollout = Some(recorder); + } + session.ensure_rollout_materialized().await; + session.flush_rollout().await; + rollout_path +} + +fn text_block(s: &str) -> serde_json::Value { + json!({ + "type": "text", + "text": s, + }) +} + +fn init_test_tracing() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + let provider = SdkTracerProvider::builder().build(); + let tracer = provider.tracer("codex-core-tests"); + let subscriber = + tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer)); + tracing::subscriber::set_global_default(subscriber) + .expect("global tracing subscriber should only be installed once"); + }); +} + +async fn build_test_config(codex_home: &Path) -> Config { + ConfigBuilder::default() + .codex_home(codex_home.to_path_buf()) + .build() + .await + .expect("load default test config") +} + +fn otel_manager( + conversation_id: ThreadId, + config: &Config, + model_info: &ModelInfo, + session_source: SessionSource, +) -> OtelManager { + OtelManager::new( + conversation_id, + ModelsManager::get_model_offline_for_tests(config.model.as_deref()).as_str(), + model_info.slug.as_str(), + None, + Some("test@test.com".to_string()), + Some(TelemetryAuthMode::Chatgpt), + "test_originator".to_string(), + false, + "test".to_string(), + session_source, + ) +} + +pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + + SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + } +} + +#[tokio::test] +async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { + let codex_home = tempfile::tempdir().expect("create temp dir"); + let mut config = build_test_config(codex_home.path()).await; + config + .features + .enable(Feature::ShellZshFork) + .expect("test config should allow shell_zsh_fork"); + config.zsh_path = None; + let config = Arc::new(config); + + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let models_manager = Arc::new(ModelsManager::new( + config.codex_home.clone(), + auth_manager.clone(), + None, + CollaborationModesConfig::default(), + )); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort: config.model_reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + + let (tx_event, _rx_event) = async_channel::unbounded(); + let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); + let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); + let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); + let skills_manager = Arc::new(SkillsManager::new( + config.codex_home.clone(), + Arc::clone(&plugins_manager), + )); + let result = Session::new( + session_configuration, + Arc::clone(&config), + auth_manager, + models_manager, + ExecPolicyManager::default(), + tx_event, + agent_status_tx, + InitialHistory::New, + SessionSource::Exec, + skills_manager, + plugins_manager, + mcp_manager, + Arc::new(FileWatcher::noop()), + AgentControl::default(), + ) + .await; + + let err = match result { + Ok(_) => panic!("expected startup to fail"), + Err(err) => err, + }; + let msg = format!("{err:#}"); + assert!(msg.contains("zsh fork feature enabled, but `zsh_path` is not configured")); +} + +// todo: use online model info +pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { + let (tx_event, _rx_event) = async_channel::unbounded(); + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let conversation_id = ThreadId::default(); + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let models_manager = Arc::new(ModelsManager::new( + config.codex_home.clone(), + auth_manager.clone(), + None, + CollaborationModesConfig::default(), + )); + let agent_control = AgentControl::default(); + let exec_policy = ExecPolicyManager::default(); + let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools: Vec::new(), + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + let per_turn_config = Session::build_per_turn_config(&session_configuration); + let model_info = ModelsManager::construct_model_info_offline_for_tests( + session_configuration.collaboration_mode.model(), + &per_turn_config, + ); + let otel_manager = otel_manager( + conversation_id, + config.as_ref(), + &model_info, + session_configuration.session_source.clone(), + ); + + let state = SessionState::new(session_configuration.clone()); + let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); + let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); + let skills_manager = Arc::new(SkillsManager::new( + config.codex_home.clone(), + Arc::clone(&plugins_manager), + )); + let network_approval = Arc::new(NetworkApprovalService::default()); + + let file_watcher = Arc::new(FileWatcher::noop()); + let services = SessionServices { + mcp_connection_manager: Arc::new(RwLock::new( + McpConnectionManager::new_mcp_connection_manager_for_tests( + &config.permissions.approval_policy, + ), + )), + mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), + unified_exec_manager: UnifiedExecProcessManager::new( + config.background_terminal_max_timeout, + ), + shell_zsh_path: None, + main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), + analytics_events_client: AnalyticsEventsClient::new( + Arc::clone(&config), + Arc::clone(&auth_manager), + ), + hooks: Hooks::new(HooksConfig { + legacy_notify_argv: config.notify.clone(), + }), + rollout: Mutex::new(None), + user_shell: Arc::new(default_user_shell()), + shell_snapshot_tx: watch::channel(None).0, + show_raw_agent_reasoning: config.show_raw_agent_reasoning, + exec_policy, + auth_manager: auth_manager.clone(), + otel_manager: otel_manager.clone(), + models_manager: Arc::clone(&models_manager), + tool_approvals: Mutex::new(ApprovalStore::default()), + execve_session_approvals: RwLock::new(HashMap::new()), + skills_manager, + plugins_manager, + mcp_manager, + file_watcher, + agent_control, + network_proxy: None, + network_approval: Arc::clone(&network_approval), + state_db: None, + model_client: ModelClient::new( + Some(auth_manager.clone()), + conversation_id, + session_configuration.provider.clone(), + session_configuration.session_source.clone(), + config.model_verbosity, + ws_version_from_features(config.as_ref()), + config.features.enabled(Feature::EnableRequestCompression), + config.features.enabled(Feature::RuntimeMetrics), + Session::build_model_client_beta_features_header(config.as_ref()), + ), + }; + let js_repl = Arc::new(JsReplHandle::with_node_path( + config.js_repl_node_path.clone(), + config.js_repl_node_module_dirs.clone(), + )); + + let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); + let turn_context = Session::make_turn_context( + Some(Arc::clone(&auth_manager)), + &otel_manager, + session_configuration.provider.clone(), + &session_configuration, + per_turn_config, + model_info, + None, + "turn_id".to_string(), + Arc::clone(&js_repl), + skills_outcome, + ); + + let session = Session { + conversation_id, + tx_event, + agent_status: agent_status_tx, + state: Mutex::new(state), + features: config.features.clone(), + pending_mcp_server_refresh_config: Mutex::new(None), + conversation: Arc::new(RealtimeConversationManager::new()), + active_turn: Mutex::new(None), + services, + js_repl, + next_internal_sub_id: AtomicU64::new(0), + }; + + (session, turn_context) +} + +#[tokio::test] +async fn submit_with_id_captures_current_span_trace_context() { + let (session, _turn_context) = make_session_and_context().await; + let (tx_sub, rx_sub) = async_channel::bounded(1); + let (_tx_event, rx_event) = async_channel::unbounded(); + let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); + let codex = Codex { + tx_sub, + rx_event, + agent_status, + session: Arc::new(session), + }; + + init_test_tracing(); + + let request_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), + tracestate: Some("vendor=value".into()), + }; + let request_span = info_span!("app_server.request"); + assert!(set_parent_from_w3c_trace_context( + &request_span, + &request_parent + )); + + let expected_trace = async { + let expected_trace = + current_span_w3c_trace_context().expect("current span should have trace context"); + codex + .submit_with_id(Submission { + id: "sub-1".into(), + op: Op::Interrupt, + trace: None, + }) + .await + .expect("submit should succeed"); + expected_trace + } + .instrument(request_span) + .await; + + let submitted = rx_sub.recv().await.expect("submission"); + assert_eq!(submitted.trace, Some(expected_trace)); +} + +#[tokio::test] +async fn new_default_turn_captures_current_span_trace_id() { + let (session, _turn_context) = make_session_and_context().await; + + init_test_tracing(); + + let request_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), + tracestate: Some("vendor=value".into()), + }; + let request_span = info_span!("app_server.request"); + assert!(set_parent_from_w3c_trace_context( + &request_span, + &request_parent + )); + + let turn_context_item = async { + let expected_trace_id = Span::current() + .context() + .span() + .span_context() + .trace_id() + .to_string(); + let turn_context = session.new_default_turn().await; + let turn_context_item = turn_context.to_turn_context_item(); + assert_eq!(turn_context_item.trace_id, Some(expected_trace_id)); + turn_context_item + } + .instrument(request_span) + .await; + + assert_eq!( + turn_context_item.trace_id.as_deref(), + Some("00000000000000000000000000000011") + ); +} + +#[test] +fn submission_dispatch_span_prefers_submission_trace_context() { + init_test_tracing(); + + let ambient_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000033-0000000000000044-01".into()), + tracestate: None, + }; + let ambient_span = info_span!("ambient"); + assert!(set_parent_from_w3c_trace_context( + &ambient_span, + &ambient_parent + )); + + let submission_trace = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000055-0000000000000066-01".into()), + tracestate: Some("vendor=value".into()), + }; + let dispatch_span = ambient_span.in_scope(|| { + submission_dispatch_span(&Submission { + id: "sub-1".into(), + op: Op::Interrupt, + trace: Some(submission_trace), + }) + }); + + let trace_id = dispatch_span.context().span().span_context().trace_id(); + assert_eq!( + trace_id, + TraceId::from_hex("00000000000000000000000000000055").expect("trace id") + ); +} + +#[test] +fn submission_dispatch_span_uses_debug_for_realtime_audio() { + init_test_tracing(); + + let dispatch_span = submission_dispatch_span(&Submission { + id: "sub-1".into(), + op: Op::RealtimeConversationAudio(ConversationAudioParams { + frame: RealtimeAudioFrame { + data: "ZmFrZQ==".into(), + sample_rate: 16_000, + num_channels: 1, + samples_per_channel: Some(160), + }, + }), + trace: None, + }); + + assert_eq!( + dispatch_span.metadata().expect("span metadata").level(), + &tracing::Level::DEBUG + ); +} + +#[tokio::test] +async fn spawn_task_turn_span_inherits_dispatch_trace_context() { + struct TraceCaptureTask { + captured_trace: Arc>>, + } + + #[async_trait::async_trait] + impl SessionTask for TraceCaptureTask { + fn kind(&self) -> TaskKind { + TaskKind::Regular + } + + fn span_name(&self) -> &'static str { + "session_task.trace_capture" + } + + async fn run( + self: Arc, + _session: Arc, + _ctx: Arc, + _input: Vec, + _cancellation_token: CancellationToken, + ) -> Option { + let mut trace = self + .captured_trace + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *trace = current_span_w3c_trace_context(); + None + } + } + + init_test_tracing(); + + let request_parent = W3cTraceContext { + traceparent: Some("00-00000000000000000000000000000011-0000000000000022-01".into()), + tracestate: Some("vendor=value".into()), + }; + let request_span = tracing::info_span!("app_server.request"); + assert!(set_parent_from_w3c_trace_context( + &request_span, + &request_parent + )); + + let submission_trace = + async { current_span_w3c_trace_context().expect("request span should have trace context") } + .instrument(request_span) + .await; + + let dispatch_span = submission_dispatch_span(&Submission { + id: "sub-1".into(), + op: Op::Interrupt, + trace: Some(submission_trace.clone()), + }); + let dispatch_span_id = dispatch_span.context().span().span_context().span_id(); + + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let captured_trace = Arc::new(std::sync::Mutex::new(None)); + + async { + sess.spawn_task( + Arc::clone(&tc), + vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }], + TraceCaptureTask { + captured_trace: Arc::clone(&captured_trace), + }, + ) + .await; + } + .instrument(dispatch_span) + .await; + + let evt = tokio::time::timeout(StdDuration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for turn completion") + .expect("event"); + assert!(matches!(evt.msg, EventMsg::TurnComplete(_))); + + let task_trace = captured_trace + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + .expect("turn task should capture the current span trace context"); + let submission_context = + codex_otel::context_from_w3c_trace_context(&submission_trace).expect("submission"); + let task_context = codex_otel::context_from_w3c_trace_context(&task_trace).expect("task trace"); + + assert_eq!( + task_context.span().span_context().trace_id(), + submission_context.span().span_context().trace_id() + ); + assert_ne!( + task_context.span().span_context().span_id(), + dispatch_span_id + ); +} + +pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( + dynamic_tools: Vec, +) -> ( + Arc, + Arc, + async_channel::Receiver, +) { + let (tx_event, rx_event) = async_channel::unbounded(); + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = build_test_config(codex_home.path()).await; + let config = Arc::new(config); + let conversation_id = ThreadId::default(); + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let models_manager = Arc::new(ModelsManager::new( + config.codex_home.clone(), + auth_manager.clone(), + None, + CollaborationModesConfig::default(), + )); + let agent_control = AgentControl::default(); + let exec_policy = ExecPolicyManager::default(); + let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); + let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); + let model_info = ModelsManager::construct_model_info_offline_for_tests(model.as_str(), &config); + let reasoning_effort = config.model_reasoning_effort; + let collaboration_mode = CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model, + reasoning_effort, + developer_instructions: None, + }, + }; + let session_configuration = SessionConfiguration { + provider: config.model_provider.clone(), + collaboration_mode, + model_reasoning_summary: config.model_reasoning_summary, + developer_instructions: config.developer_instructions.clone(), + user_instructions: config.user_instructions.clone(), + service_tier: None, + personality: config.personality, + base_instructions: config + .base_instructions + .clone() + .unwrap_or_else(|| model_info.get_model_instructions(config.personality)), + compact_prompt: config.compact_prompt.clone(), + approval_policy: config.permissions.approval_policy.clone(), + sandbox_policy: config.permissions.sandbox_policy.clone(), + windows_sandbox_level: WindowsSandboxLevel::from_config(&config), + cwd: config.cwd.clone(), + codex_home: config.codex_home.clone(), + thread_name: None, + original_config_do_not_use: Arc::clone(&config), + metrics_service_name: None, + app_server_client_name: None, + session_source: SessionSource::Exec, + dynamic_tools, + persist_extended_history: false, + inherited_shell_snapshot: None, + }; + let per_turn_config = Session::build_per_turn_config(&session_configuration); + let model_info = ModelsManager::construct_model_info_offline_for_tests( + session_configuration.collaboration_mode.model(), + &per_turn_config, + ); + let otel_manager = otel_manager( + conversation_id, + config.as_ref(), + &model_info, + session_configuration.session_source.clone(), + ); + + let state = SessionState::new(session_configuration.clone()); + let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.clone())); + let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); + let skills_manager = Arc::new(SkillsManager::new( + config.codex_home.clone(), + Arc::clone(&plugins_manager), + )); + let network_approval = Arc::new(NetworkApprovalService::default()); + + let file_watcher = Arc::new(FileWatcher::noop()); + let services = SessionServices { + mcp_connection_manager: Arc::new(RwLock::new( + McpConnectionManager::new_mcp_connection_manager_for_tests( + &config.permissions.approval_policy, + ), + )), + mcp_startup_cancellation_token: Mutex::new(CancellationToken::new()), + unified_exec_manager: UnifiedExecProcessManager::new( + config.background_terminal_max_timeout, + ), + shell_zsh_path: None, + main_execve_wrapper_exe: config.main_execve_wrapper_exe.clone(), + analytics_events_client: AnalyticsEventsClient::new( + Arc::clone(&config), + Arc::clone(&auth_manager), + ), + hooks: Hooks::new(HooksConfig { + legacy_notify_argv: config.notify.clone(), + }), + rollout: Mutex::new(None), + user_shell: Arc::new(default_user_shell()), + shell_snapshot_tx: watch::channel(None).0, + show_raw_agent_reasoning: config.show_raw_agent_reasoning, + exec_policy, + auth_manager: Arc::clone(&auth_manager), + otel_manager: otel_manager.clone(), + models_manager: Arc::clone(&models_manager), + tool_approvals: Mutex::new(ApprovalStore::default()), + execve_session_approvals: RwLock::new(HashMap::new()), + skills_manager, + plugins_manager, + mcp_manager, + file_watcher, + agent_control, + network_proxy: None, + network_approval: Arc::clone(&network_approval), + state_db: None, + model_client: ModelClient::new( + Some(Arc::clone(&auth_manager)), + conversation_id, + session_configuration.provider.clone(), + session_configuration.session_source.clone(), + config.model_verbosity, + ws_version_from_features(config.as_ref()), + config.features.enabled(Feature::EnableRequestCompression), + config.features.enabled(Feature::RuntimeMetrics), + Session::build_model_client_beta_features_header(config.as_ref()), + ), + }; + let js_repl = Arc::new(JsReplHandle::with_node_path( + config.js_repl_node_path.clone(), + config.js_repl_node_module_dirs.clone(), + )); + + let skills_outcome = Arc::new(services.skills_manager.skills_for_config(&per_turn_config)); + let turn_context = Arc::new(Session::make_turn_context( + Some(Arc::clone(&auth_manager)), + &otel_manager, + session_configuration.provider.clone(), + &session_configuration, + per_turn_config, + model_info, + None, + "turn_id".to_string(), + Arc::clone(&js_repl), + skills_outcome, + )); + + let session = Arc::new(Session { + conversation_id, + tx_event, + agent_status: agent_status_tx, + state: Mutex::new(state), + features: config.features.clone(), + pending_mcp_server_refresh_config: Mutex::new(None), + conversation: Arc::new(RealtimeConversationManager::new()), + active_turn: Mutex::new(None), + services, + js_repl, + next_internal_sub_id: AtomicU64::new(0), + }); + + (session, turn_context, rx_event) +} + +// Like make_session_and_context, but returns Arc and the event receiver +// so tests can assert on emitted events. +pub(crate) async fn make_session_and_context_with_rx() -> ( + Arc, + Arc, + async_channel::Receiver, +) { + make_session_and_context_with_dynamic_tools_and_rx(Vec::new()).await +} + +#[tokio::test] +async fn refresh_mcp_servers_is_deferred_until_next_turn() { + let (session, turn_context) = make_session_and_context().await; + let old_token = session.mcp_startup_cancellation_token().await; + assert!(!old_token.is_cancelled()); + + let mcp_oauth_credentials_store_mode = + serde_json::to_value(OAuthCredentialsStoreMode::Auto).expect("serialize store mode"); + let refresh_config = McpServerRefreshConfig { + mcp_servers: json!({}), + mcp_oauth_credentials_store_mode, + }; + { + let mut guard = session.pending_mcp_server_refresh_config.lock().await; + *guard = Some(refresh_config); + } + + assert!(!old_token.is_cancelled()); + assert!( + session + .pending_mcp_server_refresh_config + .lock() + .await + .is_some() + ); + + session + .refresh_mcp_servers_if_requested(&turn_context) + .await; + + assert!(old_token.is_cancelled()); + assert!( + session + .pending_mcp_server_refresh_config + .lock() + .await + .is_none() + ); + let new_token = session.mcp_startup_cancellation_token().await; + assert!(!new_token.is_cancelled()); +} + +#[tokio::test] +async fn record_model_warning_appends_user_message() { + let (mut session, turn_context) = make_session_and_context().await; + let features = crate::features::Features::with_defaults().into(); + session.features = features; + + session + .record_model_warning("too many unified exec processes", &turn_context) + .await; + + let history = session.clone_history().await; + let history_items = history.raw_items(); + let last = history_items.last().expect("warning recorded"); + + match last { + ResponseItem::Message { role, content, .. } => { + assert_eq!(role, "user"); + assert_eq!( + content, + &vec![ContentItem::InputText { + text: "Warning: too many unified exec processes".to_string(), + }] + ); + } + other => panic!("expected user message, got {other:?}"), + } +} + +#[tokio::test] +async fn spawn_task_does_not_update_previous_turn_settings_for_non_run_turn_tasks() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + sess.set_previous_turn_settings(None).await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + assert_eq!(sess.previous_turn_settings().await, None); +} + +#[tokio::test] +async fn build_settings_update_items_emits_environment_item_for_network_changes() { + let (session, previous_context) = make_session_and_context().await; + let previous_context = Arc::new(previous_context); + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + + let mut config = (*current_context.config).clone(); + let mut requirements = config.config_layer_stack.requirements().clone(); + requirements.network = Some(Sourced::new( + NetworkConstraints { + allowed_domains: Some(vec!["api.example.com".to_string()]), + denied_domains: Some(vec!["blocked.example.com".to_string()]), + ..Default::default() + }, + RequirementSource::CloudRequirements, + )); + let layers = config + .config_layer_stack + .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) + .into_iter() + .cloned() + .collect(); + config.config_layer_stack = ConfigLayerStack::new( + layers, + requirements, + config.config_layer_stack.requirements_toml().clone(), + ) + .expect("rebuild config layer stack with network requirements"); + current_context.config = Arc::new(config); + + let reference_context_item = previous_context.to_turn_context_item(); + let update_items = session + .build_settings_update_items(Some(&reference_context_item), ¤t_context) + .await; + + let environment_update = update_items + .iter() + .find_map(|item| match item { + ResponseItem::Message { role, content, .. } if role == "user" => { + let [ContentItem::InputText { text }] = content.as_slice() else { + return None; + }; + text.contains("").then_some(text) + } + _ => None, + }) + .expect("environment update item should be emitted"); + assert!(environment_update.contains("")); + assert!(environment_update.contains("api.example.com")); + assert!(environment_update.contains("blocked.example.com")); +} + +#[tokio::test] +async fn build_settings_update_items_emits_environment_item_for_time_changes() { + let (session, previous_context) = make_session_and_context().await; + let previous_context = Arc::new(previous_context); + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.current_date = Some("2026-02-27".to_string()); + current_context.timezone = Some("Europe/Berlin".to_string()); + + let reference_context_item = previous_context.to_turn_context_item(); + let update_items = session + .build_settings_update_items(Some(&reference_context_item), ¤t_context) + .await; + + let environment_update = update_items + .iter() + .find_map(|item| match item { + ResponseItem::Message { role, content, .. } if role == "user" => { + let [ContentItem::InputText { text }] = content.as_slice() else { + return None; + }; + text.contains("").then_some(text) + } + _ => None, + }) + .expect("environment update item should be emitted"); + assert!(environment_update.contains("2026-02-27")); + assert!(environment_update.contains("Europe/Berlin")); +} + +#[tokio::test] +async fn build_settings_update_items_emits_realtime_start_when_session_becomes_live() { + let (session, previous_context) = make_session_and_context().await; + let previous_context = Arc::new(previous_context); + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.realtime_active = true; + + let update_items = session + .build_settings_update_items( + Some(&previous_context.to_turn_context_item()), + ¤t_context, + ) + .await; + + let developer_texts = developer_input_texts(&update_items); + assert!( + developer_texts + .iter() + .any(|text| text.contains("")), + "expected a realtime start update, got {developer_texts:?}" + ); +} + +#[tokio::test] +async fn build_settings_update_items_emits_realtime_end_when_session_stops_being_live() { + let (session, mut previous_context) = make_session_and_context().await; + previous_context.realtime_active = true; + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.realtime_active = false; + + let update_items = session + .build_settings_update_items( + Some(&previous_context.to_turn_context_item()), + ¤t_context, + ) + .await; + + let developer_texts = developer_input_texts(&update_items); + assert!( + developer_texts + .iter() + .any(|text| text.contains("Reason: inactive")), + "expected a realtime end update, got {developer_texts:?}" + ); +} + +#[tokio::test] +async fn build_settings_update_items_uses_previous_turn_settings_for_realtime_end() { + let (session, previous_context) = make_session_and_context().await; + let mut previous_context_item = previous_context.to_turn_context_item(); + previous_context_item.realtime_active = None; + let previous_turn_settings = PreviousTurnSettings { + model: previous_context.model_info.slug.clone(), + realtime_active: Some(true), + }; + let mut current_context = previous_context + .with_model( + previous_context.model_info.slug.clone(), + &session.services.models_manager, + ) + .await; + current_context.realtime_active = false; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let update_items = session + .build_settings_update_items(Some(&previous_context_item), ¤t_context) + .await; + + let developer_texts = developer_input_texts(&update_items); + assert!( + developer_texts + .iter() + .any(|text| text.contains("Reason: inactive")), + "expected a realtime end update from previous turn settings, got {developer_texts:?}" + ); +} + +#[tokio::test] +async fn build_initial_context_uses_previous_realtime_state() { + let (session, mut turn_context) = make_session_and_context().await; + turn_context.realtime_active = true; + + let initial_context = session.build_initial_context(&turn_context).await; + let developer_texts = developer_input_texts(&initial_context); + assert!( + developer_texts + .iter() + .any(|text| text.contains("")), + "expected initial context to describe active realtime state, got {developer_texts:?}" + ); + + let previous_context_item = turn_context.to_turn_context_item(); + { + let mut state = session.state.lock().await; + state.set_reference_context_item(Some(previous_context_item)); + } + let resumed_context = session.build_initial_context(&turn_context).await; + let resumed_developer_texts = developer_input_texts(&resumed_context); + assert!( + !resumed_developer_texts + .iter() + .any(|text| text.contains("")), + "did not expect a duplicate realtime update, got {resumed_developer_texts:?}" + ); +} + +#[tokio::test] +async fn build_initial_context_uses_previous_turn_settings_for_realtime_end() { + let (session, turn_context) = make_session_and_context().await; + let previous_turn_settings = PreviousTurnSettings { + model: turn_context.model_info.slug.clone(), + realtime_active: Some(true), + }; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let initial_context = session.build_initial_context(&turn_context).await; + let developer_texts = developer_input_texts(&initial_context); + assert!( + developer_texts + .iter() + .any(|text| text.contains("Reason: inactive")), + "expected initial context to describe an ended realtime session, got {developer_texts:?}" + ); +} + +#[tokio::test] +async fn build_initial_context_restates_realtime_start_when_reference_context_is_missing() { + let (session, mut turn_context) = make_session_and_context().await; + turn_context.realtime_active = true; + let previous_turn_settings = PreviousTurnSettings { + model: turn_context.model_info.slug.clone(), + realtime_active: Some(true), + }; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let initial_context = session.build_initial_context(&turn_context).await; + let developer_texts = developer_input_texts(&initial_context); + assert!( + developer_texts + .iter() + .any(|text| text.contains("")), + "expected initial context to restate active realtime when the reference context is missing, got {developer_texts:?}" + ); +} + +#[tokio::test] +async fn record_context_updates_and_set_reference_context_item_injects_full_context_when_baseline_missing() + { + let (session, turn_context) = make_session_and_context().await; + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + let history = session.clone_history().await; + let initial_context = session.build_initial_context(&turn_context).await; + assert_eq!(history.raw_items().to_vec(), initial_context); + + let current_context = session.reference_context_item().await; + assert_eq!( + serde_json::to_value(current_context).expect("serialize current context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected context item") + ); +} + +#[tokio::test] +async fn record_context_updates_and_set_reference_context_item_reinjects_full_context_after_clear() +{ + let (session, turn_context) = make_session_and_context().await; + let compacted_summary = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{}\nsummary", crate::compact::SUMMARY_PREFIX), + }], + end_turn: None, + phase: None, + }; + session + .record_into_history(std::slice::from_ref(&compacted_summary), &turn_context) + .await; + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + { + let mut state = session.state.lock().await; + state.set_reference_context_item(None); + } + session + .replace_history(vec![compacted_summary.clone()], None) + .await; + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + + let history = session.clone_history().await; + let mut expected_history = vec![compacted_summary]; + expected_history.extend(session.build_initial_context(&turn_context).await); + assert_eq!(history.raw_items().to_vec(), expected_history); +} + +#[tokio::test] +async fn record_context_updates_and_set_reference_context_item_persists_baseline_without_emitting_diffs() + { + let (session, previous_context) = make_session_and_context().await; + let next_model = if previous_context.model_info.slug == "gpt-5.1" { + "gpt-5" + } else { + "gpt-5.1" + }; + let turn_context = previous_context + .with_model(next_model.to_string(), &session.services.models_manager) + .await; + let previous_context_item = previous_context.to_turn_context_item(); + { + let mut state = session.state.lock().await; + state.set_reference_context_item(Some(previous_context_item.clone())); + } + let config = session.get_config().await; + let recorder = RolloutRecorder::new( + config.as_ref(), + RolloutRecorderParams::new( + ThreadId::default(), + None, + SessionSource::Exec, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + None, + None, + ) + .await + .expect("create rollout recorder"); + let rollout_path = recorder.rollout_path().to_path_buf(); + { + let mut rollout = session.services.rollout.lock().await; + *rollout = Some(recorder); + } + + let update_items = session + .build_settings_update_items(Some(&previous_context_item), &turn_context) + .await; + assert_eq!(update_items, Vec::new()); + + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + + assert_eq!( + session.clone_history().await.raw_items().to_vec(), + Vec::new() + ); + assert_eq!( + serde_json::to_value(session.reference_context_item().await) + .expect("serialize current context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected context item") + ); + session.ensure_rollout_materialized().await; + session.flush_rollout().await; + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + let persisted_turn_context = resumed.history.iter().find_map(|item| match item { + RolloutItem::TurnContext(ctx) => Some(ctx.clone()), + _ => None, + }); + assert_eq!( + serde_json::to_value(persisted_turn_context) + .expect("serialize persisted turn context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected turn context item") + ); +} + +#[tokio::test] +async fn build_initial_context_prepends_model_switch_message() { + let (session, turn_context) = make_session_and_context().await; + let previous_turn_settings = PreviousTurnSettings { + model: "previous-regular-model".to_string(), + realtime_active: None, + }; + + session + .set_previous_turn_settings(Some(previous_turn_settings)) + .await; + let initial_context = session.build_initial_context(&turn_context).await; + + let ResponseItem::Message { role, content, .. } = &initial_context[0] else { + panic!("expected developer message"); + }; + assert_eq!(role, "developer"); + let [ContentItem::InputText { text }, ..] = content.as_slice() else { + panic!("expected developer text"); + }; + assert!(text.contains("")); +} + +#[tokio::test] +async fn record_context_updates_and_set_reference_context_item_persists_full_reinjection_to_rollout() + { + let (session, previous_context) = make_session_and_context().await; + let next_model = if previous_context.model_info.slug == "gpt-5.1" { + "gpt-5" + } else { + "gpt-5.1" + }; + let turn_context = previous_context + .with_model(next_model.to_string(), &session.services.models_manager) + .await; + let config = session.get_config().await; + let recorder = RolloutRecorder::new( + config.as_ref(), + RolloutRecorderParams::new( + ThreadId::default(), + None, + SessionSource::Exec, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + None, + None, + ) + .await + .expect("create rollout recorder"); + let rollout_path = recorder.rollout_path().to_path_buf(); + { + let mut rollout = session.services.rollout.lock().await; + *rollout = Some(recorder); + } + + session + .persist_rollout_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( + UserMessageEvent { + message: "seed rollout".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + ))]) + .await; + { + let mut state = session.state.lock().await; + state.set_reference_context_item(None); + } + + session + .set_previous_turn_settings(Some(PreviousTurnSettings { + model: previous_context.model_info.slug.clone(), + realtime_active: Some(previous_context.realtime_active), + })) + .await; + session + .record_context_updates_and_set_reference_context_item(&turn_context) + .await; + session.ensure_rollout_materialized().await; + session.flush_rollout().await; + + let InitialHistory::Resumed(resumed) = RolloutRecorder::get_rollout_history(&rollout_path) + .await + .expect("read rollout history") + else { + panic!("expected resumed rollout history"); + }; + let persisted_turn_context = resumed.history.iter().find_map(|item| match item { + RolloutItem::TurnContext(ctx) => Some(ctx.clone()), + _ => None, + }); + + assert_eq!( + serde_json::to_value(persisted_turn_context) + .expect("serialize persisted turn context item"), + serde_json::to_value(Some(turn_context.to_turn_context_item())) + .expect("serialize expected turn context item") + ); +} + +#[tokio::test] +async fn run_user_shell_command_does_not_set_reference_context_item() { + let (session, _turn_context, rx) = make_session_and_context_with_rx().await; + { + let mut state = session.state.lock().await; + state.set_reference_context_item(None); + } + + handlers::run_user_shell_command(&session, "sub-id".to_string(), "echo shell".to_string()) + .await; + + let deadline = StdDuration::from_secs(15); + let start = std::time::Instant::now(); + loop { + let remaining = deadline.saturating_sub(start.elapsed()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + if matches!(evt.msg, EventMsg::TurnComplete(_)) { + break; + } + } + + assert!( + session.reference_context_item().await.is_none(), + "standalone shell tasks should not mutate previous context" + ); +} + +#[derive(Clone, Copy)] +struct NeverEndingTask { + kind: TaskKind, + listen_to_cancellation_token: bool, +} + +#[async_trait::async_trait] +impl SessionTask for NeverEndingTask { + fn kind(&self) -> TaskKind { + self.kind + } + + fn span_name(&self) -> &'static str { + "session_task.never_ending" + } + + async fn run( + self: Arc, + _session: Arc, + _ctx: Arc, + _input: Vec, + cancellation_token: CancellationToken, + ) -> Option { + if self.listen_to_cancellation_token { + cancellation_token.cancelled().await; + return None; + } + loop { + sleep(Duration::from_secs(60)).await; + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[test_log::test] +async fn abort_regular_task_emits_turn_aborted_only() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + // Interrupts persist a model-visible `` marker into history, but there is no + // separate client-visible event for that marker (only `EventMsg::TurnAborted`). + let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected event: {other:?}"), + } + // No extra events should be emitted after an abort. + assert!(rx.try_recv().is_err()); +} + +#[tokio::test] +async fn abort_gracefully_emits_turn_aborted_only() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + // Even if tasks handle cancellation gracefully, interrupts still result in `TurnAborted` + // being the only client-visible signal. + let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected event: {other:?}"), + } + // No extra events should be emitted after an abort. + assert!(rx.try_recv().is_err()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + while rx.try_recv().is_ok() {} + + sess.inject_response_items(vec![ResponseInputItem::Message { + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "late pending input".to_string(), + }], + }]) + .await + .expect("inject pending input into active turn"); + + sess.on_task_finished(Arc::clone(&tc), None).await; + + let history = sess.clone_history().await; + let expected = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "late pending input".to_string(), + }], + end_turn: None, + phase: None, + }; + assert!( + history.raw_items().iter().any(|item| item == &expected), + "expected pending input to be persisted into history on turn completion" + ); + + let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected raw response item event") + .expect("channel open"); + assert!(matches!(first.msg, EventMsg::RawResponseItem(_))); + + let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected item started event") + .expect("channel open"); + assert!(matches!( + second.msg, + EventMsg::ItemStarted(ItemStartedEvent { + item: TurnItem::UserMessage(UserMessageItem { content, .. }), + .. + }) if content == vec![UserInput::Text { + text: "late pending input".to_string(), + text_elements: Vec::new(), + }] + )); + + let third = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected item completed event") + .expect("channel open"); + assert!(matches!( + third.msg, + EventMsg::ItemCompleted(ItemCompletedEvent { + item: TurnItem::UserMessage(UserMessageItem { content, .. }), + .. + }) if content == vec![UserInput::Text { + text: "late pending input".to_string(), + text_elements: Vec::new(), + }] + )); + + let fourth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected legacy user message event") + .expect("channel open"); + assert!(matches!( + fourth.msg, + EventMsg::UserMessage(UserMessageEvent { + message, + images, + text_elements, + local_images, + }) if message == "late pending input" + && images == Some(Vec::new()) + && text_elements.is_empty() + && local_images.is_empty() + )); + + let fifth = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected turn complete event") + .expect("channel open"); + assert!(matches!( + fifth.msg, + EventMsg::TurnComplete(TurnCompleteEvent { + turn_id, + last_agent_message: None, + }) if turn_id == tc.sub_id + )); +} + +#[tokio::test] +async fn steer_input_requires_active_turn() { + let (sess, _tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "steer".to_string(), + text_elements: Vec::new(), + }]; + + let err = sess + .steer_input(input, None) + .await + .expect_err("steering without active turn should fail"); + + assert!(matches!(err, SteerInputError::NoActiveTurn(_))); +} + +#[tokio::test] +async fn steer_input_enforces_expected_turn_id() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + let steer_input = vec![UserInput::Text { + text: "steer".to_string(), + text_elements: Vec::new(), + }]; + let err = sess + .steer_input(steer_input, Some("different-turn-id")) + .await + .expect_err("mismatched expected turn id should fail"); + + match err { + SteerInputError::ExpectedTurnMismatch { expected, actual } => { + assert_eq!( + (expected, actual), + ("different-turn-id".to_string(), tc.sub_id.clone()) + ); + } + other => panic!("unexpected error: {other:?}"), + } +} + +#[tokio::test] +async fn steer_input_returns_active_turn_id() { + let (sess, tc, _rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + let steer_input = vec![UserInput::Text { + text: "steer".to_string(), + text_elements: Vec::new(), + }]; + let turn_id = sess + .steer_input(steer_input, Some(&tc.sub_id)) + .await + .expect("steering with matching expected turn id should succeed"); + + assert_eq!(turn_id, tc.sub_id); + assert!(sess.has_pending_input().await); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn abort_review_task_emits_exited_then_aborted_and_records_history() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let input = vec![UserInput::Text { + text: "start review".to_string(), + text_elements: Vec::new(), + }]; + sess.spawn_task(Arc::clone(&tc), input, ReviewTask::new()) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + // Aborting a review task should exit review mode before surfacing the abort to the client. + // We scan for these events (rather than relying on fixed ordering) since unrelated events + // may interleave. + let mut exited_review_mode_idx = None; + let mut turn_aborted_idx = None; + let mut idx = 0usize; + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(3); + while tokio::time::Instant::now() < deadline { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + let evt = tokio::time::timeout(remaining, rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + let event_idx = idx; + idx = idx.saturating_add(1); + match evt.msg { + EventMsg::ExitedReviewMode(ev) => { + assert!(ev.review_output.is_none()); + exited_review_mode_idx = Some(event_idx); + } + EventMsg::TurnAborted(ev) => { + assert_eq!(TurnAbortReason::Interrupted, ev.reason); + turn_aborted_idx = Some(event_idx); + break; + } + _ => {} + } + } + assert!( + exited_review_mode_idx.is_some(), + "expected ExitedReviewMode after abort" + ); + assert!( + turn_aborted_idx.is_some(), + "expected TurnAborted after abort" + ); + assert!( + exited_review_mode_idx.unwrap() < turn_aborted_idx.unwrap(), + "expected ExitedReviewMode before TurnAborted" + ); + + let history = sess.clone_history().await; + // The `` marker is silent in the event stream, so verify it is still + // recorded in history for the model. + assert!( + history.raw_items().iter().any(|item| { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + if role != "user" { + return false; + } + content.iter().any(|content_item| { + let ContentItem::InputText { text } = content_item else { + return false; + }; + text.contains(crate::contextual_user_message::TURN_ABORTED_OPEN_TAG) + }) + }), + "expected a model-visible turn aborted marker in history after interrupt" + ); +} + +#[tokio::test] +async fn fatal_tool_error_stops_turn_and_reports_error() { + let (session, turn_context, _rx) = make_session_and_context_with_rx().await; + let tools = { + session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await + }; + let app_tools = Some(tools.clone()); + let router = ToolRouter::from_config( + &turn_context.tools_config, + Some( + tools + .into_iter() + .map(|(name, tool)| (name, tool.tool)) + .collect(), + ), + app_tools, + turn_context.dynamic_tools.as_slice(), + ); + let item = ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-1".to_string(), + name: "shell".to_string(), + input: "{}".to_string(), + }; + + let call = ToolRouter::build_tool_call(session.as_ref(), item.clone()) + .await + .expect("build tool call") + .expect("tool call present"); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let err = router + .dispatch_tool_call( + Arc::clone(&session), + Arc::clone(&turn_context), + tracker, + call, + ToolCallSource::Direct, + ) + .await + .expect_err("expected fatal error"); + + match err { + FunctionCallError::Fatal(message) => { + assert_eq!(message, "tool shell invoked with incompatible payload"); + } + other => panic!("expected FunctionCallError::Fatal, got {other:?}"), + } +} + +async fn sample_rollout( + session: &Session, + _turn_context: &TurnContext, +) -> (Vec, Vec) { + let mut rollout_items = Vec::new(); + let mut live_history = ContextManager::new(); + + // Use the same turn_context source as record_initial_history so model_info (and thus + // personality_spec) matches reconstruction. + let reconstruction_turn = session.new_default_turn().await; + let mut initial_context = session + .build_initial_context(reconstruction_turn.as_ref()) + .await; + // Ensure personality_spec is present when Personality is enabled, so expected matches + // what reconstruction produces (build_initial_context may omit it when baked into model). + if !initial_context.iter().any(|m| { + matches!(m, ResponseItem::Message { role, content, .. } + if role == "developer" + && content.iter().any(|c| { + matches!(c, ContentItem::InputText { text } if text.contains("")) + })) + }) && let Some(p) = reconstruction_turn.personality + && session.features.enabled(Feature::Personality) + && let Some(personality_message) = reconstruction_turn + .model_info + .model_messages + .as_ref() + .and_then(|m| m.get_personality_message(Some(p)).filter(|s| !s.is_empty())) + { + let msg = DeveloperInstructions::personality_spec_message(personality_message).into(); + let insert_at = initial_context + .iter() + .position(|m| matches!(m, ResponseItem::Message { role, .. } if role == "developer")) + .map(|i| i + 1) + .unwrap_or(0); + initial_context.insert(insert_at, msg); + } + for item in &initial_context { + rollout_items.push(RolloutItem::ResponseItem(item.clone())); + } + live_history.record_items( + initial_context.iter(), + reconstruction_turn.truncation_policy, + ); + + let user1 = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "first user".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&user1), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(user1.clone())); + + let assistant1 = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "assistant reply one".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&assistant1), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(assistant1.clone())); + + let summary1 = "summary one"; + let snapshot1 = live_history + .clone() + .for_prompt(&reconstruction_turn.model_info.input_modalities); + let user_messages1 = collect_user_messages(&snapshot1); + let rebuilt1 = compact::build_compacted_history(Vec::new(), &user_messages1, summary1); + live_history.replace(rebuilt1); + rollout_items.push(RolloutItem::Compacted(CompactedItem { + message: summary1.to_string(), + replacement_history: None, + })); + + let user2 = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "second user".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&user2), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(user2.clone())); + + let assistant2 = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "assistant reply two".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&assistant2), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(assistant2.clone())); + + let summary2 = "summary two"; + let snapshot2 = live_history + .clone() + .for_prompt(&reconstruction_turn.model_info.input_modalities); + let user_messages2 = collect_user_messages(&snapshot2); + let rebuilt2 = compact::build_compacted_history(Vec::new(), &user_messages2, summary2); + live_history.replace(rebuilt2); + rollout_items.push(RolloutItem::Compacted(CompactedItem { + message: summary2.to_string(), + replacement_history: None, + })); + + let user3 = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "third user".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&user3), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(user3)); + + let assistant3 = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "assistant reply three".to_string(), + }], + end_turn: None, + phase: None, + }; + live_history.record_items( + std::iter::once(&assistant3), + reconstruction_turn.truncation_policy, + ); + rollout_items.push(RolloutItem::ResponseItem(assistant3)); + + ( + rollout_items, + live_history.for_prompt(&reconstruction_turn.model_info.input_modalities), + ) +} + +#[tokio::test] +async fn rejects_escalated_permissions_when_policy_not_on_request() { + use crate::exec::ExecParams; + use crate::protocol::AskForApproval; + use crate::protocol::SandboxPolicy; + use crate::sandboxing::SandboxPermissions; + use crate::turn_diff_tracker::TurnDiffTracker; + use std::collections::HashMap; + + let (session, mut turn_context_raw) = make_session_and_context().await; + // Ensure policy is NOT OnRequest so the early rejection path triggers + turn_context_raw + .approval_policy + .set(AskForApproval::OnFailure) + .expect("test setup should allow updating approval policy"); + let session = Arc::new(session); + let mut turn_context = Arc::new(turn_context_raw); + + let timeout_ms = 1000; + let sandbox_permissions = SandboxPermissions::RequireEscalated; + let params = ExecParams { + command: if cfg!(windows) { + vec![ + "cmd.exe".to_string(), + "/C".to_string(), + "echo hi".to_string(), + ] + } else { + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "echo hi".to_string(), + ] + }, + cwd: turn_context.cwd.clone(), + expiration: timeout_ms.into(), + env: HashMap::new(), + network: None, + sandbox_permissions, + windows_sandbox_level: turn_context.windows_sandbox_level, + justification: Some("test".to_string()), + arg0: None, + }; + + let params2 = ExecParams { + sandbox_permissions: SandboxPermissions::UseDefault, + command: params.command.clone(), + cwd: params.cwd.clone(), + expiration: timeout_ms.into(), + env: HashMap::new(), + network: None, + windows_sandbox_level: turn_context.windows_sandbox_level, + justification: params.justification.clone(), + arg0: None, + }; + + let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + + let tool_name = "shell"; + let call_id = "test-call".to_string(); + + let handler = ShellHandler; + let resp = handler + .handle(ToolInvocation { + session: Arc::clone(&session), + turn: Arc::clone(&turn_context), + tracker: Arc::clone(&turn_diff_tracker), + call_id, + tool_name: tool_name.to_string(), + payload: ToolPayload::Function { + arguments: serde_json::json!({ + "command": params.command.clone(), + "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), + "timeout_ms": params.expiration.timeout_ms(), + "sandbox_permissions": params.sandbox_permissions, + "justification": params.justification.clone(), + }) + .to_string(), + }, + }) + .await; + + let Err(FunctionCallError::RespondToModel(output)) = resp else { + panic!("expected error result"); + }; + + let expected = format!( + "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}", + policy = turn_context.approval_policy.value() + ); + + pretty_assertions::assert_eq!(output, expected); + + // Now retry the same command WITHOUT escalated permissions; should succeed. + // Force DangerFullAccess to avoid platform sandbox dependencies in tests. + Arc::get_mut(&mut turn_context) + .expect("unique turn context Arc") + .sandbox_policy + .set(SandboxPolicy::DangerFullAccess) + .expect("test setup should allow updating sandbox policy"); + + let resp2 = handler + .handle(ToolInvocation { + session: Arc::clone(&session), + turn: Arc::clone(&turn_context), + tracker: Arc::clone(&turn_diff_tracker), + call_id: "test-call-2".to_string(), + tool_name: tool_name.to_string(), + payload: ToolPayload::Function { + arguments: serde_json::json!({ + "command": params2.command.clone(), + "workdir": Some(turn_context.cwd.to_string_lossy().to_string()), + "timeout_ms": params2.expiration.timeout_ms(), + "sandbox_permissions": params2.sandbox_permissions, + "justification": params2.justification.clone(), + }) + .to_string(), + }, + }) + .await; + + let output = match resp2.expect("expected Ok result") { + ToolOutput::Function { + body: FunctionCallOutputBody::Text(content), + .. + } => content, + _ => panic!("unexpected tool output"), + }; + + #[derive(Deserialize, PartialEq, Eq, Debug)] + struct ResponseExecMetadata { + exit_code: i32, + } + + #[derive(Deserialize)] + struct ResponseExecOutput { + output: String, + metadata: ResponseExecMetadata, + } + + let exec_output: ResponseExecOutput = + serde_json::from_str(&output).expect("valid exec output json"); + + pretty_assertions::assert_eq!(exec_output.metadata, ResponseExecMetadata { exit_code: 0 }); + assert!(exec_output.output.contains("hi")); +} +#[tokio::test] +async fn unified_exec_rejects_escalated_permissions_when_policy_not_on_request() { + use crate::protocol::AskForApproval; + use crate::sandboxing::SandboxPermissions; + use crate::turn_diff_tracker::TurnDiffTracker; + + let (session, mut turn_context_raw) = make_session_and_context().await; + turn_context_raw + .approval_policy + .set(AskForApproval::OnFailure) + .expect("test setup should allow updating approval policy"); + let session = Arc::new(session); + let turn_context = Arc::new(turn_context_raw); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + + let handler = UnifiedExecHandler; + let resp = handler + .handle(ToolInvocation { + session: Arc::clone(&session), + turn: Arc::clone(&turn_context), + tracker: Arc::clone(&tracker), + call_id: "exec-call".to_string(), + tool_name: "exec_command".to_string(), + payload: ToolPayload::Function { + arguments: serde_json::json!({ + "cmd": "echo hi", + "sandbox_permissions": SandboxPermissions::RequireEscalated, + "justification": "need unsandboxed execution", + }) + .to_string(), + }, + }) + .await; + + let Err(FunctionCallError::RespondToModel(output)) = resp else { + panic!("expected error result"); + }; + + let expected = format!( + "approval policy is {policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {policy:?}", + policy = turn_context.approval_policy.value() + ); + + pretty_assertions::assert_eq!(output, expected); +}