From 6cf0ed4e79468d2f3fd3855458bf05419d4fac64 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Thu, 5 Mar 2026 16:31:20 -0800 Subject: [PATCH] Refine realtime startup context formatting (#13560) ## Summary - group recent work by git repo when available, otherwise by directory - render recent work as bounded user asks with per-thread cwd context - exclude hidden files and directories from workspace trees --- .../tests/suite/v2/realtime_conversation.rs | 19 + codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/realtime_context.rs | 532 ++++++++++++++++++ codex-rs/core/src/realtime_conversation.rs | 9 + .../core/tests/suite/realtime_conversation.rs | 247 +++++++- 5 files changed, 794 insertions(+), 14 deletions(-) create mode 100644 codex-rs/core/src/realtime_context.rs diff --git a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs index 71150d712..d12578448 100644 --- a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs +++ b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs @@ -36,6 +36,7 @@ use tempfile::TempDir; use tokio::time::timeout; const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); +const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex."; #[tokio::test] async fn realtime_conversation_streams_v2_notifications() -> Result<()> { @@ -114,6 +115,18 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { assert_eq!(started.thread_id, thread_start.thread.id); assert!(started.session_id.is_some()); + let startup_context_request = realtime_server.wait_for_request(0, 0).await; + assert_eq!( + startup_context_request.body_json()["type"].as_str(), + Some("session.update") + ); + assert!( + startup_context_request.body_json()["session"]["instructions"] + .as_str() + .context("expected startup context instructions")? + .contains(STARTUP_CONTEXT_HEADER) + ); + let audio_append_request_id = mcp .send_thread_realtime_append_audio_request(ThreadRealtimeAppendAudioParams { thread_id: started.thread_id.clone(), @@ -183,6 +196,12 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { connection[0].body_json()["type"].as_str(), Some("session.update") ); + assert!( + connection[0].body_json()["session"]["instructions"] + .as_str() + .context("expected startup context instructions")? + .contains(STARTUP_CONTEXT_HEADER) + ); let mut request_types = [ connection[1].body_json()["type"] .as_str() diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 828bbe214..9b822a85a 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -13,6 +13,7 @@ pub mod auth; mod client; mod client_common; pub mod codex; +mod realtime_context; mod realtime_conversation; pub use codex::SteerInputError; mod codex_thread; diff --git a/codex-rs/core/src/realtime_context.rs b/codex-rs/core/src/realtime_context.rs new file mode 100644 index 000000000..e15adabc4 --- /dev/null +++ b/codex-rs/core/src/realtime_context.rs @@ -0,0 +1,532 @@ +use crate::codex::Session; +use crate::git_info::resolve_root_git_project_for_trust; +use crate::truncate::TruncationPolicy; +use crate::truncate::truncate_text; +use chrono::Utc; +use codex_state::SortKey; +use codex_state::ThreadMetadata; +use dirs::home_dir; +use std::cmp::Reverse; +use std::collections::HashMap; +use std::collections::HashSet; +use std::ffi::OsStr; +use std::fs::DirEntry; +use std::io; +use std::path::Path; +use std::path::PathBuf; +use tracing::debug; +use tracing::info; +use tracing::warn; + +const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex.\nThis is background context about recent work and machine/workspace layout. It may be incomplete or stale. Use it to inform responses, and do not repeat it back unless relevant."; +const RECENT_WORK_SECTION_TOKEN_BUDGET: usize = 2_200; +const WORKSPACE_SECTION_TOKEN_BUDGET: usize = 1_600; +const NOTES_SECTION_TOKEN_BUDGET: usize = 300; +const MAX_RECENT_THREADS: usize = 40; +const MAX_RECENT_WORK_GROUPS: usize = 8; +const MAX_CURRENT_CWD_ASKS: usize = 8; +const MAX_OTHER_CWD_ASKS: usize = 5; +const MAX_ASK_CHARS: usize = 240; +const TREE_MAX_DEPTH: usize = 2; +const DIR_ENTRY_LIMIT: usize = 20; +const APPROX_BYTES_PER_TOKEN: usize = 4; +const NOISY_DIR_NAMES: &[&str] = &[ + ".git", + ".next", + ".pytest_cache", + ".ruff_cache", + "__pycache__", + "build", + "dist", + "node_modules", + "out", + "target", +]; + +pub(crate) async fn build_realtime_startup_context( + sess: &Session, + budget_tokens: usize, +) -> Option { + let config = sess.get_config().await; + let cwd = config.cwd.clone(); + let recent_threads = load_recent_threads(sess).await; + let recent_work_section = build_recent_work_section(&cwd, &recent_threads); + let workspace_section = build_workspace_section(&cwd); + + if recent_work_section.is_none() && workspace_section.is_none() { + debug!("realtime startup context unavailable; skipping injection"); + return None; + } + + let mut parts = vec![STARTUP_CONTEXT_HEADER.to_string()]; + + let has_recent_work_section = recent_work_section.is_some(); + let has_workspace_section = workspace_section.is_some(); + + if let Some(section) = format_section( + "Recent Work", + recent_work_section, + RECENT_WORK_SECTION_TOKEN_BUDGET, + ) { + parts.push(section); + } + if let Some(section) = format_section( + "Machine / Workspace Map", + workspace_section, + WORKSPACE_SECTION_TOKEN_BUDGET, + ) { + parts.push(section); + } + if let Some(section) = format_section( + "Notes", + Some("Built at realtime startup from persisted thread metadata in the state DB and a bounded local workspace scan. This excludes repo memory instructions, AGENTS files, project-doc prompt blends, and memory summaries.".to_string()), + NOTES_SECTION_TOKEN_BUDGET, + ) { + parts.push(section); + } + + let context = truncate_text(&parts.join("\n\n"), TruncationPolicy::Tokens(budget_tokens)); + debug!( + approx_tokens = approx_token_count(&context), + bytes = context.len(), + has_recent_work_section, + has_workspace_section, + "built realtime startup context" + ); + info!("realtime startup context: {context}"); + Some(context) +} + +async fn load_recent_threads(sess: &Session) -> Vec { + let Some(state_db) = sess.services.state_db.as_ref() else { + return Vec::new(); + }; + + match state_db + .list_threads( + MAX_RECENT_THREADS, + None, + SortKey::UpdatedAt, + &[], + None, + false, + None, + ) + .await + { + Ok(page) => page.items, + Err(err) => { + warn!("failed to load realtime startup threads from state db: {err}"); + Vec::new() + } + } +} + +fn build_recent_work_section(cwd: &Path, recent_threads: &[ThreadMetadata]) -> Option { + let mut groups: HashMap> = HashMap::new(); + for entry in recent_threads { + let group = + resolve_root_git_project_for_trust(&entry.cwd).unwrap_or_else(|| entry.cwd.clone()); + groups.entry(group).or_default().push(entry); + } + + let current_group = + resolve_root_git_project_for_trust(cwd).unwrap_or_else(|| cwd.to_path_buf()); + let mut groups = groups.into_iter().collect::>(); + groups.sort_by(|(left_group, left_entries), (right_group, right_entries)| { + let left_latest = left_entries + .iter() + .map(|entry| entry.updated_at) + .max() + .unwrap_or_else(Utc::now); + let right_latest = right_entries + .iter() + .map(|entry| entry.updated_at) + .max() + .unwrap_or_else(Utc::now); + ( + *left_group != current_group, + Reverse(left_latest), + left_group.as_os_str(), + ) + .cmp(&( + *right_group != current_group, + Reverse(right_latest), + right_group.as_os_str(), + )) + }); + + let sections = groups + .into_iter() + .take(MAX_RECENT_WORK_GROUPS) + .filter_map(|(group, mut entries)| { + entries.sort_by_key(|entry| Reverse(entry.updated_at)); + format_thread_group(¤t_group, &group, entries) + }) + .collect::>(); + (!sections.is_empty()).then(|| sections.join("\n\n")) +} + +fn build_workspace_section(cwd: &Path) -> Option { + build_workspace_section_with_user_root(cwd, home_dir()) +} + +fn build_workspace_section_with_user_root( + cwd: &Path, + user_root: Option, +) -> Option { + let git_root = resolve_root_git_project_for_trust(cwd); + let cwd_tree = render_tree(cwd); + let git_root_tree = git_root + .as_ref() + .filter(|git_root| git_root.as_path() != cwd) + .and_then(|git_root| render_tree(git_root)); + let user_root_tree = user_root + .as_ref() + .filter(|user_root| user_root.as_path() != cwd) + .filter(|user_root| { + git_root + .as_ref() + .is_none_or(|git_root| git_root.as_path() != user_root.as_path()) + }) + .and_then(|user_root| render_tree(user_root)); + + if cwd_tree.is_none() && git_root.is_none() && user_root_tree.is_none() { + return None; + } + + let mut lines = vec![ + format!("Current working directory: {}", cwd.display()), + format!("Working directory name: {}", display_name(cwd)), + ]; + + if let Some(git_root) = &git_root { + lines.push(format!("Git root: {}", git_root.display())); + lines.push(format!("Git project: {}", display_name(git_root))); + } + if let Some(user_root) = &user_root { + lines.push(format!("User root: {}", user_root.display())); + } + + if let Some(tree) = cwd_tree { + lines.push(String::new()); + lines.push("Working directory tree:".to_string()); + lines.extend(tree); + } + + if let Some(tree) = git_root_tree { + lines.push(String::new()); + lines.push("Git root tree:".to_string()); + lines.extend(tree); + } + + if let Some(tree) = user_root_tree { + lines.push(String::new()); + lines.push("User root tree:".to_string()); + lines.extend(tree); + } + + Some(lines.join("\n")) +} + +fn render_tree(root: &Path) -> Option> { + if !root.is_dir() { + return None; + } + + let mut lines = Vec::new(); + collect_tree_lines(root, 0, &mut lines); + (!lines.is_empty()).then_some(lines) +} + +fn collect_tree_lines(dir: &Path, depth: usize, lines: &mut Vec) { + if depth >= TREE_MAX_DEPTH { + return; + } + + let entries = match read_sorted_entries(dir) { + Ok(entries) => entries, + Err(_) => return, + }; + let total_entries = entries.len(); + + for entry in entries.into_iter().take(DIR_ENTRY_LIMIT) { + let Ok(file_type) = entry.file_type() else { + continue; + }; + let name = file_name_string(&entry.path()); + let indent = " ".repeat(depth); + let suffix = if file_type.is_dir() { "/" } else { "" }; + lines.push(format!("{indent}- {name}{suffix}")); + if file_type.is_dir() { + collect_tree_lines(&entry.path(), depth + 1, lines); + } + } + + if total_entries > DIR_ENTRY_LIMIT { + lines.push(format!( + "{}- ... {} more entries", + " ".repeat(depth), + total_entries - DIR_ENTRY_LIMIT + )); + } +} + +fn read_sorted_entries(dir: &Path) -> io::Result> { + let mut entries = std::fs::read_dir(dir)? + .filter_map(Result::ok) + .filter(|entry| !is_noisy_name(&entry.file_name())) + .collect::>(); + entries.sort_by(|left, right| { + let left_is_dir = left + .file_type() + .map(|file_type| file_type.is_dir()) + .unwrap_or(false); + let right_is_dir = right + .file_type() + .map(|file_type| file_type.is_dir()) + .unwrap_or(false); + (!left_is_dir, file_name_string(&left.path())) + .cmp(&(!right_is_dir, file_name_string(&right.path()))) + }); + Ok(entries) +} + +fn is_noisy_name(name: &OsStr) -> bool { + let name = name.to_string_lossy(); + name.starts_with('.') || NOISY_DIR_NAMES.iter().any(|noisy| *noisy == name) +} + +fn format_section(title: &str, body: Option, budget_tokens: usize) -> Option { + let body = body?; + let body = body.trim(); + if body.is_empty() { + return None; + } + + Some(format!( + "## {title}\n{}", + truncate_text(body, TruncationPolicy::Tokens(budget_tokens)) + )) +} + +fn format_thread_group( + current_group: &Path, + group: &Path, + entries: Vec<&ThreadMetadata>, +) -> Option { + let latest = entries.first()?; + let group_label = if resolve_root_git_project_for_trust(latest.cwd.as_path()).is_some() { + format!("### Git repo: {}", group.display()) + } else { + format!("### Directory: {}", group.display()) + }; + let mut lines = vec![ + group_label, + format!("Recent sessions: {}", entries.len()), + format!("Latest activity: {}", latest.updated_at.to_rfc3339()), + ]; + + if let Some(git_branch) = latest + .git_branch + .as_deref() + .filter(|git_branch| !git_branch.is_empty()) + { + lines.push(format!("Latest branch: {git_branch}")); + } + + lines.push(String::new()); + lines.push("User asks:".to_string()); + + let mut seen = HashSet::new(); + let max_asks = if group == current_group { + MAX_CURRENT_CWD_ASKS + } else { + MAX_OTHER_CWD_ASKS + }; + + for entry in entries { + let Some(first_user_message) = entry.first_user_message.as_deref() else { + continue; + }; + let ask = first_user_message + .split_whitespace() + .collect::>() + .join(" "); + let dedupe_key = format!("{}:{ask}", entry.cwd.display()); + if ask.is_empty() || !seen.insert(dedupe_key) { + continue; + } + let ask = if ask.chars().count() > MAX_ASK_CHARS { + format!( + "{}...", + ask.chars() + .take(MAX_ASK_CHARS.saturating_sub(3)) + .collect::() + ) + } else { + ask + }; + lines.push(format!("- {}: {ask}", entry.cwd.display())); + if seen.len() == max_asks { + break; + } + } + + (lines.len() > 5).then(|| lines.join("\n")) +} + +fn display_name(path: &Path) -> String { + path.file_name() + .and_then(OsStr::to_str) + .map(str::to_owned) + .unwrap_or_else(|| path.display().to_string()) +} + +fn file_name_string(path: &Path) -> String { + path.file_name() + .and_then(OsStr::to_str) + .map(str::to_owned) + .unwrap_or_else(|| path.display().to_string()) +} + +fn approx_token_count(text: &str) -> usize { + text.len().div_ceil(APPROX_BYTES_PER_TOKEN) +} + +#[cfg(test)] +mod tests { + use super::build_recent_work_section; + use super::build_workspace_section; + use super::build_workspace_section_with_user_root; + use chrono::TimeZone; + use chrono::Utc; + use codex_protocol::ThreadId; + use codex_state::ThreadMetadata; + use pretty_assertions::assert_eq; + use std::fs; + use std::path::PathBuf; + use std::process::Command; + use tempfile::TempDir; + + fn thread_metadata(cwd: &str, title: &str, first_user_message: &str) -> ThreadMetadata { + ThreadMetadata { + id: ThreadId::new(), + rollout_path: PathBuf::from("/tmp/rollout.jsonl"), + created_at: Utc + .timestamp_opt(1_709_251_100, 0) + .single() + .expect("valid timestamp"), + updated_at: Utc + .timestamp_opt(1_709_251_200, 0) + .single() + .expect("valid timestamp"), + source: "cli".to_string(), + agent_nickname: None, + agent_role: None, + model_provider: "test-provider".to_string(), + cwd: PathBuf::from(cwd), + cli_version: "test".to_string(), + title: title.to_string(), + sandbox_policy: "workspace-write".to_string(), + approval_mode: "never".to_string(), + tokens_used: 0, + first_user_message: Some(first_user_message.to_string()), + archived_at: None, + git_sha: None, + git_branch: Some("main".to_string()), + git_origin_url: None, + } + } + + #[test] + fn workspace_section_requires_meaningful_structure() { + let cwd = TempDir::new().expect("tempdir"); + assert_eq!( + build_workspace_section_with_user_root(cwd.path(), None), + None + ); + } + + #[test] + fn workspace_section_includes_tree_when_entries_exist() { + let cwd = TempDir::new().expect("tempdir"); + fs::create_dir(cwd.path().join("docs")).expect("create docs dir"); + fs::write(cwd.path().join("README.md"), "hello").expect("write readme"); + + let section = build_workspace_section(cwd.path()).expect("workspace section"); + assert!(section.contains("Working directory tree:")); + assert!(section.contains("- docs/")); + assert!(section.contains("- README.md")); + } + + #[test] + fn workspace_section_includes_user_root_tree_when_distinct() { + let root = TempDir::new().expect("tempdir"); + let cwd = root.path().join("cwd"); + let git_root = root.path().join("git"); + let user_root = root.path().join("home"); + + fs::create_dir_all(cwd.join("docs")).expect("create cwd docs dir"); + fs::write(cwd.join("README.md"), "hello").expect("write cwd readme"); + fs::create_dir_all(git_root.join(".git")).expect("create git dir"); + fs::write(git_root.join("Cargo.toml"), "[workspace]").expect("write git root marker"); + fs::create_dir_all(user_root.join("code")).expect("create user root child"); + fs::write(user_root.join(".zshrc"), "export TEST=1").expect("write home file"); + + let section = build_workspace_section_with_user_root(cwd.as_path(), Some(user_root)) + .expect("workspace section"); + assert!(section.contains("User root tree:")); + assert!(section.contains("- code/")); + assert!(!section.contains("- .zshrc")); + } + + #[test] + fn recent_work_section_groups_threads_by_cwd() { + let root = TempDir::new().expect("tempdir"); + let repo = root.path().join("repo"); + let workspace_a = repo.join("workspace-a"); + let workspace_b = repo.join("workspace-b"); + let outside = root.path().join("outside"); + + fs::create_dir(&repo).expect("create repo dir"); + Command::new("git") + .env("GIT_CONFIG_GLOBAL", "/dev/null") + .env("GIT_CONFIG_NOSYSTEM", "1") + .args(["init"]) + .current_dir(&repo) + .output() + .expect("git init"); + fs::create_dir_all(&workspace_a).expect("create workspace a"); + fs::create_dir_all(&workspace_b).expect("create workspace b"); + fs::create_dir_all(&outside).expect("create outside dir"); + + let recent_threads = vec![ + thread_metadata( + workspace_a.to_string_lossy().as_ref(), + "Investigate realtime startup context", + "Log the startup context before sending it", + ), + thread_metadata( + workspace_b.to_string_lossy().as_ref(), + "Trim websocket startup payload", + "Remove memories from the realtime startup context", + ), + thread_metadata(outside.to_string_lossy().as_ref(), "", "Inspect flaky test"), + ]; + let current_cwd = workspace_a; + let repo = fs::canonicalize(repo).expect("canonicalize repo"); + + let section = build_recent_work_section(current_cwd.as_path(), &recent_threads) + .expect("recent work section"); + assert!(section.contains(&format!("### Git repo: {}", repo.display()))); + assert!(section.contains("Recent sessions: 2")); + assert!(section.contains("User asks:")); + assert!(section.contains(&format!( + "- {}: Log the startup context before sending it", + current_cwd.display() + ))); + assert!(section.contains(&format!("### Directory: {}", outside.display()))); + assert!(section.contains(&format!("- {}: Inspect flaky test", outside.display()))); + } +} diff --git a/codex-rs/core/src/realtime_conversation.rs b/codex-rs/core/src/realtime_conversation.rs index 656b590e5..4d8d6127d 100644 --- a/codex-rs/core/src/realtime_conversation.rs +++ b/codex-rs/core/src/realtime_conversation.rs @@ -5,6 +5,7 @@ use crate::codex::Session; use crate::default_client::default_headers; use crate::error::CodexErr; use crate::error::Result as CodexResult; +use crate::realtime_context::build_realtime_startup_context; use async_channel::Receiver; use async_channel::Sender; use async_channel::TrySendError; @@ -43,6 +44,7 @@ const AUDIO_IN_QUEUE_CAPACITY: usize = 256; const USER_TEXT_IN_QUEUE_CAPACITY: usize = 64; const HANDOFF_OUT_QUEUE_CAPACITY: usize = 64; const OUTPUT_EVENTS_QUEUE_CAPACITY: usize = 256; +const REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET: usize = 5_000; pub(crate) struct RealtimeConversationManager { state: Mutex>, @@ -282,6 +284,13 @@ pub(crate) async fn handle_start( .experimental_realtime_ws_backend_prompt .clone() .unwrap_or(params.prompt); + let prompt = + match build_realtime_startup_context(sess.as_ref(), REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET) + .await + { + Some(context) => format!("{prompt}\n\n{context}"), + None => prompt, + }; let model = config.experimental_realtime_ws_model.clone(); let requested_session_id = params diff --git a/codex-rs/core/tests/suite/realtime_conversation.rs b/codex-rs/core/tests/suite/realtime_conversation.rs index 1a3d87b55..71976e00c 100644 --- a/codex-rs/core/tests/suite/realtime_conversation.rs +++ b/codex-rs/core/tests/suite/realtime_conversation.rs @@ -1,6 +1,9 @@ +use anyhow::Context; use anyhow::Result; +use chrono::Utc; use codex_core::CodexAuth; use codex_core::auth::OPENAI_API_KEY_ENV_VAR; +use codex_protocol::ThreadId; use codex_protocol::protocol::CodexErrorInfo; use codex_protocol::protocol::ConversationAudioParams; use codex_protocol::protocol::ConversationStartParams; @@ -11,6 +14,7 @@ use codex_protocol::protocol::Op; use codex_protocol::protocol::RealtimeAudioFrame; use codex_protocol::protocol::RealtimeConversationRealtimeEvent; use codex_protocol::protocol::RealtimeEvent; +use codex_protocol::protocol::SessionSource; use codex_protocol::user_input::UserInput; use core_test_support::responses; use core_test_support::responses::start_mock_server; @@ -18,6 +22,7 @@ use core_test_support::responses::start_websocket_server; use core_test_support::skip_if_no_network; use core_test_support::streaming_sse::StreamingSseChunk; use core_test_support::streaming_sse::start_streaming_sse_server; +use core_test_support::test_codex::TestCodex; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use core_test_support::wait_for_event_match; @@ -25,9 +30,57 @@ use pretty_assertions::assert_eq; use serde_json::Value; use serde_json::json; use std::ffi::OsString; +use std::fs; use std::time::Duration; use tokio::sync::oneshot; +const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex."; +const MEMORY_PROMPT_PHRASE: &str = + "You have access to a memory folder with guidance from prior runs."; + +fn websocket_request_text( + request: &core_test_support::responses::WebSocketRequest, +) -> Option { + request.body_json()["item"]["content"][0]["text"] + .as_str() + .map(str::to_owned) +} + +fn websocket_request_instructions( + request: &core_test_support::responses::WebSocketRequest, +) -> Option { + request.body_json()["session"]["instructions"] + .as_str() + .map(str::to_owned) +} + +async fn seed_recent_thread( + test: &TestCodex, + title: &str, + first_user_message: &str, + slug: &str, +) -> Result<()> { + let db = test.codex.state_db().context("state db enabled")?; + let thread_id = ThreadId::new(); + let updated_at = Utc::now(); + let mut metadata_builder = codex_state::ThreadMetadataBuilder::new( + thread_id, + test.codex_home_path() + .join(format!("rollout-{thread_id}.jsonl")), + updated_at, + SessionSource::Cli, + ); + metadata_builder.cwd = test.workspace_path(format!("workspace-{slug}")); + metadata_builder.model_provider = Some("test-provider".to_string()); + metadata_builder.git_branch = Some(format!("branch-{slug}")); + let mut metadata = metadata_builder.build("test-provider"); + metadata.title = title.to_string(); + metadata.first_user_message = Some(first_user_message.to_string()); + db.upsert_thread(&metadata).await?; + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn conversation_start_audio_text_close_round_trip() -> Result<()> { skip_if_no_network!(Ok(())); @@ -122,10 +175,9 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> { connection[0].body_json()["type"].as_str(), Some("session.update") ); - assert_eq!( - connection[0].body_json()["session"]["instructions"].as_str(), - Some("backend prompt") - ); + let initial_instructions = websocket_request_instructions(&connection[0]) + .expect("initial session update instructions"); + assert!(initial_instructions.starts_with("backend prompt")); assert_eq!( server.handshakes()[1] .header("x-session-id") @@ -452,19 +504,17 @@ async fn conversation_second_start_replaces_runtime() -> Result<()> { let connections = server.connections(); assert_eq!(connections.len(), 3); assert_eq!(connections[1].len(), 1); - assert_eq!( - connections[1][0].body_json()["session"]["instructions"].as_str(), - Some("old") - ); + let old_instructions = + websocket_request_instructions(&connections[1][0]).expect("old session instructions"); + assert!(old_instructions.starts_with("old")); assert_eq!( server.handshakes()[1].header("x-session-id").as_deref(), Some("conv_old") ); assert_eq!(connections[2].len(), 2); - assert_eq!( - connections[2][0].body_json()["session"]["instructions"].as_str(), - Some("new") - ); + let new_instructions = + websocket_request_instructions(&connections[2][0]).expect("new session instructions"); + assert!(new_instructions.starts_with("new")); assert_eq!( server.handshakes()[2].header("x-session-id").as_deref(), Some("conv_new") @@ -570,9 +620,178 @@ async fn conversation_uses_experimental_realtime_ws_backend_prompt_override() -> let connections = server.connections(); assert_eq!(connections.len(), 2); + let overridden_instructions = websocket_request_instructions(&connections[1][0]) + .expect("overridden session instructions"); + assert!(overridden_instructions.starts_with("prompt from config")); + + server.shutdown().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn conversation_start_injects_startup_context_from_thread_history() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_websocket_server(vec![ + vec![], + vec![vec![json!({ + "type": "session.updated", + "session": { "id": "sess_context", "instructions": "backend prompt" } + })]], + ]) + .await; + + let mut builder = test_codex(); + let test = builder.build_with_websocket_server(&server).await?; + seed_recent_thread( + &test, + "Recent work: cleaned up startup flows and reviewed websocket routing.", + "Investigate realtime startup context", + "latest", + ) + .await?; + fs::create_dir_all(test.workspace_path("docs"))?; + fs::write(test.workspace_path("README.md"), "workspace marker")?; + + test.codex + .submit(Op::RealtimeConversationStart(ConversationStartParams { + prompt: "backend prompt".to_string(), + session_id: None, + })) + .await?; + + wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::SessionUpdated { session_id, .. }, + }) if session_id == "sess_context" => Some(Ok(())), + EventMsg::Error(err) => Some(Err(err.clone())), + _ => None, + }) + .await + .unwrap_or_else(|err: ErrorEvent| panic!("conversation start failed: {err:?}")); + + let startup_context_request = server.wait_for_request(1, 0).await; + let startup_context = websocket_request_instructions(&startup_context_request) + .expect("startup context request should contain instructions"); + + assert!(startup_context.contains(STARTUP_CONTEXT_HEADER)); + assert!(!startup_context.contains("## User")); + assert!(startup_context.contains("### ")); + assert!(startup_context.contains("Recent sessions: 1")); + assert!(startup_context.contains("Latest branch: branch-latest")); + assert!(startup_context.contains("User asks:")); + assert!(startup_context.contains("Investigate realtime startup context")); + assert!(startup_context.contains("## Machine / Workspace Map")); + assert!(startup_context.contains("README.md")); + assert!(!startup_context.contains(MEMORY_PROMPT_PHRASE)); + + server.shutdown().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn conversation_startup_context_falls_back_to_workspace_map() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_websocket_server(vec![ + vec![], + vec![vec![json!({ + "type": "session.updated", + "session": { "id": "sess_workspace", "instructions": "backend prompt" } + })]], + ]) + .await; + + let mut builder = test_codex(); + let test = builder.build_with_websocket_server(&server).await?; + fs::create_dir_all(test.workspace_path("codex-rs/core"))?; + fs::write(test.workspace_path("notes.txt"), "workspace marker")?; + + test.codex + .submit(Op::RealtimeConversationStart(ConversationStartParams { + prompt: "backend prompt".to_string(), + session_id: None, + })) + .await?; + + wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::SessionUpdated { session_id, .. }, + }) if session_id == "sess_workspace" => Some(Ok(())), + EventMsg::Error(err) => Some(Err(err.clone())), + _ => None, + }) + .await + .unwrap_or_else(|err: ErrorEvent| panic!("conversation start failed: {err:?}")); + + let startup_context_request = server.wait_for_request(1, 0).await; + let startup_context = websocket_request_instructions(&startup_context_request) + .expect("startup context request should contain instructions"); + + assert!(startup_context.contains(STARTUP_CONTEXT_HEADER)); + assert!(startup_context.contains("## Machine / Workspace Map")); + assert!(startup_context.contains("notes.txt")); + assert!(startup_context.contains("codex-rs/")); + + server.shutdown().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn conversation_startup_context_is_truncated_and_sent_once_per_start() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_websocket_server(vec![ + vec![], + vec![ + vec![json!({ + "type": "session.updated", + "session": { "id": "sess_truncated", "instructions": "backend prompt" } + })], + vec![], + ], + ]) + .await; + + let oversized_summary = "recent work ".repeat(3_500); + let mut builder = test_codex(); + let test = builder.build_with_websocket_server(&server).await?; + seed_recent_thread(&test, &oversized_summary, "summary", "oversized").await?; + fs::write(test.workspace_path("marker.txt"), "marker")?; + + test.codex + .submit(Op::RealtimeConversationStart(ConversationStartParams { + prompt: "backend prompt".to_string(), + session_id: None, + })) + .await?; + + wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::SessionUpdated { session_id, .. }, + }) if session_id == "sess_truncated" => Some(Ok(())), + EventMsg::Error(err) => Some(Err(err.clone())), + _ => None, + }) + .await + .unwrap_or_else(|err: ErrorEvent| panic!("conversation start failed: {err:?}")); + + let startup_context_request = server.wait_for_request(1, 0).await; + let startup_context = websocket_request_instructions(&startup_context_request) + .expect("startup context request should contain instructions"); + assert!(startup_context.contains(STARTUP_CONTEXT_HEADER)); + assert!(startup_context.len() <= 20_500); + + test.codex + .submit(Op::RealtimeConversationText(ConversationTextParams { + text: "hello".to_string(), + })) + .await?; + + let explicit_text_request = server.wait_for_request(1, 1).await; assert_eq!( - connections[1][0].body_json()["session"]["instructions"].as_str(), - Some("prompt from config") + websocket_request_text(&explicit_text_request), + Some("hello".to_string()) ); server.shutdown().await;