Refine realtime startup context formatting (#13560)
## Summary

- Group recent work by git repo when available, otherwise by directory.
- Render recent work as bounded user asks with per-thread cwd context.
- Exclude hidden files and directories from workspace trees.
This commit is contained in:
parent
c3736cff0a
commit
6cf0ed4e79
5 changed files with 794 additions and 14 deletions
|
|
@ -36,6 +36,7 @@ use tempfile::TempDir;
|
|||
use tokio::time::timeout;
|
||||
|
||||
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex.";
|
||||
|
||||
#[tokio::test]
|
||||
async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
|
||||
|
|
@ -114,6 +115,18 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
|
|||
assert_eq!(started.thread_id, thread_start.thread.id);
|
||||
assert!(started.session_id.is_some());
|
||||
|
||||
let startup_context_request = realtime_server.wait_for_request(0, 0).await;
|
||||
assert_eq!(
|
||||
startup_context_request.body_json()["type"].as_str(),
|
||||
Some("session.update")
|
||||
);
|
||||
assert!(
|
||||
startup_context_request.body_json()["session"]["instructions"]
|
||||
.as_str()
|
||||
.context("expected startup context instructions")?
|
||||
.contains(STARTUP_CONTEXT_HEADER)
|
||||
);
|
||||
|
||||
let audio_append_request_id = mcp
|
||||
.send_thread_realtime_append_audio_request(ThreadRealtimeAppendAudioParams {
|
||||
thread_id: started.thread_id.clone(),
|
||||
|
|
@ -183,6 +196,12 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
|
|||
connection[0].body_json()["type"].as_str(),
|
||||
Some("session.update")
|
||||
);
|
||||
assert!(
|
||||
connection[0].body_json()["session"]["instructions"]
|
||||
.as_str()
|
||||
.context("expected startup context instructions")?
|
||||
.contains(STARTUP_CONTEXT_HEADER)
|
||||
);
|
||||
let mut request_types = [
|
||||
connection[1].body_json()["type"]
|
||||
.as_str()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ pub mod auth;
|
|||
mod client;
|
||||
mod client_common;
|
||||
pub mod codex;
|
||||
mod realtime_context;
|
||||
mod realtime_conversation;
|
||||
pub use codex::SteerInputError;
|
||||
mod codex_thread;
|
||||
|
|
|
|||
532
codex-rs/core/src/realtime_context.rs
Normal file
532
codex-rs/core/src/realtime_context.rs
Normal file
|
|
@ -0,0 +1,532 @@
|
|||
use crate::codex::Session;
|
||||
use crate::git_info::resolve_root_git_project_for_trust;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
use crate::truncate::truncate_text;
|
||||
use chrono::Utc;
|
||||
use codex_state::SortKey;
|
||||
use codex_state::ThreadMetadata;
|
||||
use dirs::home_dir;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::ffi::OsStr;
|
||||
use std::fs::DirEntry;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use tracing::debug;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
/// Banner placed first in the startup context, ahead of every section.
const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex.\nThis is background context about recent work and machine/workspace layout. It may be incomplete or stale. Use it to inform responses, and do not repeat it back unless relevant.";

// Per-section token budgets handed to `format_section`.
/// Token budget for the "Recent Work" section.
const RECENT_WORK_SECTION_TOKEN_BUDGET: usize = 2_200;
/// Token budget for the "Machine / Workspace Map" section.
const WORKSPACE_SECTION_TOKEN_BUDGET: usize = 1_600;
/// Token budget for the "Notes" section.
const NOTES_SECTION_TOKEN_BUDGET: usize = 300;

// Limits applied while summarizing recent threads.
/// Count passed to the state DB when listing recent threads.
const MAX_RECENT_THREADS: usize = 40;
/// Maximum number of thread groups considered for the recent-work section.
const MAX_RECENT_WORK_GROUPS: usize = 8;
/// Maximum number of asks listed for the group matching the current cwd.
const MAX_CURRENT_CWD_ASKS: usize = 8;
/// Maximum number of asks listed for every other group.
const MAX_OTHER_CWD_ASKS: usize = 5;
/// Asks longer than this many characters are shortened with a trailing "...".
const MAX_ASK_CHARS: usize = 240;

// Limits applied while rendering directory trees.
/// Recursion into subdirectories stops once this depth is reached.
const TREE_MAX_DEPTH: usize = 2;
/// Maximum number of entries rendered per directory.
const DIR_ENTRY_LIMIT: usize = 20;

/// Divisor used by `approx_token_count` to turn a byte length into tokens.
const APPROX_BYTES_PER_TOKEN: usize = 4;

/// Entry names that `is_noisy_name` filters out of workspace trees.
const NOISY_DIR_NAMES: &[&str] = &[
    ".git",
    ".next",
    ".pytest_cache",
    ".ruff_cache",
    "__pycache__",
    "build",
    "dist",
    "node_modules",
    "out",
    "target",
];
|
||||
|
||||
pub(crate) async fn build_realtime_startup_context(
|
||||
sess: &Session,
|
||||
budget_tokens: usize,
|
||||
) -> Option<String> {
|
||||
let config = sess.get_config().await;
|
||||
let cwd = config.cwd.clone();
|
||||
let recent_threads = load_recent_threads(sess).await;
|
||||
let recent_work_section = build_recent_work_section(&cwd, &recent_threads);
|
||||
let workspace_section = build_workspace_section(&cwd);
|
||||
|
||||
if recent_work_section.is_none() && workspace_section.is_none() {
|
||||
debug!("realtime startup context unavailable; skipping injection");
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut parts = vec![STARTUP_CONTEXT_HEADER.to_string()];
|
||||
|
||||
let has_recent_work_section = recent_work_section.is_some();
|
||||
let has_workspace_section = workspace_section.is_some();
|
||||
|
||||
if let Some(section) = format_section(
|
||||
"Recent Work",
|
||||
recent_work_section,
|
||||
RECENT_WORK_SECTION_TOKEN_BUDGET,
|
||||
) {
|
||||
parts.push(section);
|
||||
}
|
||||
if let Some(section) = format_section(
|
||||
"Machine / Workspace Map",
|
||||
workspace_section,
|
||||
WORKSPACE_SECTION_TOKEN_BUDGET,
|
||||
) {
|
||||
parts.push(section);
|
||||
}
|
||||
if let Some(section) = format_section(
|
||||
"Notes",
|
||||
Some("Built at realtime startup from persisted thread metadata in the state DB and a bounded local workspace scan. This excludes repo memory instructions, AGENTS files, project-doc prompt blends, and memory summaries.".to_string()),
|
||||
NOTES_SECTION_TOKEN_BUDGET,
|
||||
) {
|
||||
parts.push(section);
|
||||
}
|
||||
|
||||
let context = truncate_text(&parts.join("\n\n"), TruncationPolicy::Tokens(budget_tokens));
|
||||
debug!(
|
||||
approx_tokens = approx_token_count(&context),
|
||||
bytes = context.len(),
|
||||
has_recent_work_section,
|
||||
has_workspace_section,
|
||||
"built realtime startup context"
|
||||
);
|
||||
info!("realtime startup context: {context}");
|
||||
Some(context)
|
||||
}
|
||||
|
||||
async fn load_recent_threads(sess: &Session) -> Vec<ThreadMetadata> {
|
||||
let Some(state_db) = sess.services.state_db.as_ref() else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
match state_db
|
||||
.list_threads(
|
||||
MAX_RECENT_THREADS,
|
||||
None,
|
||||
SortKey::UpdatedAt,
|
||||
&[],
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(page) => page.items,
|
||||
Err(err) => {
|
||||
warn!("failed to load realtime startup threads from state db: {err}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_recent_work_section(cwd: &Path, recent_threads: &[ThreadMetadata]) -> Option<String> {
|
||||
let mut groups: HashMap<PathBuf, Vec<&ThreadMetadata>> = HashMap::new();
|
||||
for entry in recent_threads {
|
||||
let group =
|
||||
resolve_root_git_project_for_trust(&entry.cwd).unwrap_or_else(|| entry.cwd.clone());
|
||||
groups.entry(group).or_default().push(entry);
|
||||
}
|
||||
|
||||
let current_group =
|
||||
resolve_root_git_project_for_trust(cwd).unwrap_or_else(|| cwd.to_path_buf());
|
||||
let mut groups = groups.into_iter().collect::<Vec<_>>();
|
||||
groups.sort_by(|(left_group, left_entries), (right_group, right_entries)| {
|
||||
let left_latest = left_entries
|
||||
.iter()
|
||||
.map(|entry| entry.updated_at)
|
||||
.max()
|
||||
.unwrap_or_else(Utc::now);
|
||||
let right_latest = right_entries
|
||||
.iter()
|
||||
.map(|entry| entry.updated_at)
|
||||
.max()
|
||||
.unwrap_or_else(Utc::now);
|
||||
(
|
||||
*left_group != current_group,
|
||||
Reverse(left_latest),
|
||||
left_group.as_os_str(),
|
||||
)
|
||||
.cmp(&(
|
||||
*right_group != current_group,
|
||||
Reverse(right_latest),
|
||||
right_group.as_os_str(),
|
||||
))
|
||||
});
|
||||
|
||||
let sections = groups
|
||||
.into_iter()
|
||||
.take(MAX_RECENT_WORK_GROUPS)
|
||||
.filter_map(|(group, mut entries)| {
|
||||
entries.sort_by_key(|entry| Reverse(entry.updated_at));
|
||||
format_thread_group(¤t_group, &group, entries)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(!sections.is_empty()).then(|| sections.join("\n\n"))
|
||||
}
|
||||
|
||||
fn build_workspace_section(cwd: &Path) -> Option<String> {
|
||||
build_workspace_section_with_user_root(cwd, home_dir())
|
||||
}
|
||||
|
||||
fn build_workspace_section_with_user_root(
|
||||
cwd: &Path,
|
||||
user_root: Option<PathBuf>,
|
||||
) -> Option<String> {
|
||||
let git_root = resolve_root_git_project_for_trust(cwd);
|
||||
let cwd_tree = render_tree(cwd);
|
||||
let git_root_tree = git_root
|
||||
.as_ref()
|
||||
.filter(|git_root| git_root.as_path() != cwd)
|
||||
.and_then(|git_root| render_tree(git_root));
|
||||
let user_root_tree = user_root
|
||||
.as_ref()
|
||||
.filter(|user_root| user_root.as_path() != cwd)
|
||||
.filter(|user_root| {
|
||||
git_root
|
||||
.as_ref()
|
||||
.is_none_or(|git_root| git_root.as_path() != user_root.as_path())
|
||||
})
|
||||
.and_then(|user_root| render_tree(user_root));
|
||||
|
||||
if cwd_tree.is_none() && git_root.is_none() && user_root_tree.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut lines = vec![
|
||||
format!("Current working directory: {}", cwd.display()),
|
||||
format!("Working directory name: {}", display_name(cwd)),
|
||||
];
|
||||
|
||||
if let Some(git_root) = &git_root {
|
||||
lines.push(format!("Git root: {}", git_root.display()));
|
||||
lines.push(format!("Git project: {}", display_name(git_root)));
|
||||
}
|
||||
if let Some(user_root) = &user_root {
|
||||
lines.push(format!("User root: {}", user_root.display()));
|
||||
}
|
||||
|
||||
if let Some(tree) = cwd_tree {
|
||||
lines.push(String::new());
|
||||
lines.push("Working directory tree:".to_string());
|
||||
lines.extend(tree);
|
||||
}
|
||||
|
||||
if let Some(tree) = git_root_tree {
|
||||
lines.push(String::new());
|
||||
lines.push("Git root tree:".to_string());
|
||||
lines.extend(tree);
|
||||
}
|
||||
|
||||
if let Some(tree) = user_root_tree {
|
||||
lines.push(String::new());
|
||||
lines.push("User root tree:".to_string());
|
||||
lines.extend(tree);
|
||||
}
|
||||
|
||||
Some(lines.join("\n"))
|
||||
}
|
||||
|
||||
fn render_tree(root: &Path) -> Option<Vec<String>> {
|
||||
if !root.is_dir() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut lines = Vec::new();
|
||||
collect_tree_lines(root, 0, &mut lines);
|
||||
(!lines.is_empty()).then_some(lines)
|
||||
}
|
||||
|
||||
fn collect_tree_lines(dir: &Path, depth: usize, lines: &mut Vec<String>) {
|
||||
if depth >= TREE_MAX_DEPTH {
|
||||
return;
|
||||
}
|
||||
|
||||
let entries = match read_sorted_entries(dir) {
|
||||
Ok(entries) => entries,
|
||||
Err(_) => return,
|
||||
};
|
||||
let total_entries = entries.len();
|
||||
|
||||
for entry in entries.into_iter().take(DIR_ENTRY_LIMIT) {
|
||||
let Ok(file_type) = entry.file_type() else {
|
||||
continue;
|
||||
};
|
||||
let name = file_name_string(&entry.path());
|
||||
let indent = " ".repeat(depth);
|
||||
let suffix = if file_type.is_dir() { "/" } else { "" };
|
||||
lines.push(format!("{indent}- {name}{suffix}"));
|
||||
if file_type.is_dir() {
|
||||
collect_tree_lines(&entry.path(), depth + 1, lines);
|
||||
}
|
||||
}
|
||||
|
||||
if total_entries > DIR_ENTRY_LIMIT {
|
||||
lines.push(format!(
|
||||
"{}- ... {} more entries",
|
||||
" ".repeat(depth),
|
||||
total_entries - DIR_ENTRY_LIMIT
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
fn read_sorted_entries(dir: &Path) -> io::Result<Vec<DirEntry>> {
|
||||
let mut entries = std::fs::read_dir(dir)?
|
||||
.filter_map(Result::ok)
|
||||
.filter(|entry| !is_noisy_name(&entry.file_name()))
|
||||
.collect::<Vec<_>>();
|
||||
entries.sort_by(|left, right| {
|
||||
let left_is_dir = left
|
||||
.file_type()
|
||||
.map(|file_type| file_type.is_dir())
|
||||
.unwrap_or(false);
|
||||
let right_is_dir = right
|
||||
.file_type()
|
||||
.map(|file_type| file_type.is_dir())
|
||||
.unwrap_or(false);
|
||||
(!left_is_dir, file_name_string(&left.path()))
|
||||
.cmp(&(!right_is_dir, file_name_string(&right.path())))
|
||||
});
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
fn is_noisy_name(name: &OsStr) -> bool {
|
||||
let name = name.to_string_lossy();
|
||||
name.starts_with('.') || NOISY_DIR_NAMES.iter().any(|noisy| *noisy == name)
|
||||
}
|
||||
|
||||
fn format_section(title: &str, body: Option<String>, budget_tokens: usize) -> Option<String> {
|
||||
let body = body?;
|
||||
let body = body.trim();
|
||||
if body.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(format!(
|
||||
"## {title}\n{}",
|
||||
truncate_text(body, TruncationPolicy::Tokens(budget_tokens))
|
||||
))
|
||||
}
|
||||
|
||||
fn format_thread_group(
|
||||
current_group: &Path,
|
||||
group: &Path,
|
||||
entries: Vec<&ThreadMetadata>,
|
||||
) -> Option<String> {
|
||||
let latest = entries.first()?;
|
||||
let group_label = if resolve_root_git_project_for_trust(latest.cwd.as_path()).is_some() {
|
||||
format!("### Git repo: {}", group.display())
|
||||
} else {
|
||||
format!("### Directory: {}", group.display())
|
||||
};
|
||||
let mut lines = vec![
|
||||
group_label,
|
||||
format!("Recent sessions: {}", entries.len()),
|
||||
format!("Latest activity: {}", latest.updated_at.to_rfc3339()),
|
||||
];
|
||||
|
||||
if let Some(git_branch) = latest
|
||||
.git_branch
|
||||
.as_deref()
|
||||
.filter(|git_branch| !git_branch.is_empty())
|
||||
{
|
||||
lines.push(format!("Latest branch: {git_branch}"));
|
||||
}
|
||||
|
||||
lines.push(String::new());
|
||||
lines.push("User asks:".to_string());
|
||||
|
||||
let mut seen = HashSet::new();
|
||||
let max_asks = if group == current_group {
|
||||
MAX_CURRENT_CWD_ASKS
|
||||
} else {
|
||||
MAX_OTHER_CWD_ASKS
|
||||
};
|
||||
|
||||
for entry in entries {
|
||||
let Some(first_user_message) = entry.first_user_message.as_deref() else {
|
||||
continue;
|
||||
};
|
||||
let ask = first_user_message
|
||||
.split_whitespace()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
let dedupe_key = format!("{}:{ask}", entry.cwd.display());
|
||||
if ask.is_empty() || !seen.insert(dedupe_key) {
|
||||
continue;
|
||||
}
|
||||
let ask = if ask.chars().count() > MAX_ASK_CHARS {
|
||||
format!(
|
||||
"{}...",
|
||||
ask.chars()
|
||||
.take(MAX_ASK_CHARS.saturating_sub(3))
|
||||
.collect::<String>()
|
||||
)
|
||||
} else {
|
||||
ask
|
||||
};
|
||||
lines.push(format!("- {}: {ask}", entry.cwd.display()));
|
||||
if seen.len() == max_asks {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(lines.len() > 5).then(|| lines.join("\n"))
|
||||
}
|
||||
|
||||
/// Returns the final component of `path` as a string, falling back to the
/// whole path when there is no final component or it is not valid UTF-8.
fn display_name(path: &Path) -> String {
    match path.file_name().and_then(OsStr::to_str) {
        Some(name) => name.to_owned(),
        None => path.display().to_string(),
    }
}
|
||||
|
||||
/// Returns the final component of `path` as a string for tree listings.
///
/// A non-UTF-8 name is converted lossily so that the listing still shows only
/// the entry's own name instead of its full path. A path with no final
/// component (such as `/` or `..`) is rendered in full.
fn file_name_string(path: &Path) -> String {
    match path.file_name() {
        Some(name) => name.to_string_lossy().into_owned(),
        None => path.display().to_string(),
    }
}
|
||||
|
||||
fn approx_token_count(text: &str) -> usize {
|
||||
text.len().div_ceil(APPROX_BYTES_PER_TOKEN)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::build_recent_work_section;
    use super::build_workspace_section;
    use super::build_workspace_section_with_user_root;
    use chrono::TimeZone;
    use chrono::Utc;
    use codex_protocol::ThreadId;
    use codex_state::ThreadMetadata;
    use pretty_assertions::assert_eq;
    use std::fs;
    use std::path::PathBuf;
    use std::process::Command;
    use tempfile::TempDir;

    /// Builds a thread record with fixed timestamps, branch `main`, and the
    /// given cwd, title and first user message.
    fn thread_metadata(cwd: &str, title: &str, first_user_message: &str) -> ThreadMetadata {
        let created_at = Utc
            .timestamp_opt(1_709_251_100, 0)
            .single()
            .expect("valid timestamp");
        let updated_at = Utc
            .timestamp_opt(1_709_251_200, 0)
            .single()
            .expect("valid timestamp");
        ThreadMetadata {
            id: ThreadId::new(),
            rollout_path: PathBuf::from("/tmp/rollout.jsonl"),
            created_at,
            updated_at,
            source: "cli".to_string(),
            agent_nickname: None,
            agent_role: None,
            model_provider: "test-provider".to_string(),
            cwd: PathBuf::from(cwd),
            cli_version: "test".to_string(),
            title: title.to_string(),
            sandbox_policy: "workspace-write".to_string(),
            approval_mode: "never".to_string(),
            tokens_used: 0,
            first_user_message: Some(first_user_message.to_string()),
            archived_at: None,
            git_sha: None,
            git_branch: Some("main".to_string()),
            git_origin_url: None,
        }
    }

    /// An empty directory with no user root yields no section.
    #[test]
    fn workspace_section_requires_meaningful_structure() {
        let empty_dir = TempDir::new().expect("tempdir");
        let section = build_workspace_section_with_user_root(empty_dir.path(), None);
        assert_eq!(section, None);
    }

    /// Directories and files in the cwd both appear in the rendered tree.
    #[test]
    fn workspace_section_includes_tree_when_entries_exist() {
        let workdir = TempDir::new().expect("tempdir");
        fs::create_dir(workdir.path().join("docs")).expect("create docs dir");
        fs::write(workdir.path().join("README.md"), "hello").expect("write readme");

        let section = build_workspace_section(workdir.path()).expect("workspace section");
        for expected in ["Working directory tree:", "- docs/", "- README.md"] {
            assert!(section.contains(expected));
        }
    }

    /// A user root distinct from the cwd gets its own tree, without dotfiles.
    #[test]
    fn workspace_section_includes_user_root_tree_when_distinct() {
        let sandbox = TempDir::new().expect("tempdir");
        let base = sandbox.path();
        let cwd = base.join("cwd");
        let git_root = base.join("git");
        let user_root = base.join("home");

        fs::create_dir_all(cwd.join("docs")).expect("create cwd docs dir");
        fs::write(cwd.join("README.md"), "hello").expect("write cwd readme");
        fs::create_dir_all(git_root.join(".git")).expect("create git dir");
        fs::write(git_root.join("Cargo.toml"), "[workspace]").expect("write git root marker");
        fs::create_dir_all(user_root.join("code")).expect("create user root child");
        fs::write(user_root.join(".zshrc"), "export TEST=1").expect("write home file");

        let section = build_workspace_section_with_user_root(cwd.as_path(), Some(user_root))
            .expect("workspace section");
        assert!(section.contains("User root tree:"));
        assert!(section.contains("- code/"));
        assert!(!section.contains("- .zshrc"));
    }

    /// Threads inside one git repo share a group; a thread outside any repo
    /// is grouped under its own directory.
    #[test]
    fn recent_work_section_groups_threads_by_cwd() {
        let sandbox = TempDir::new().expect("tempdir");
        let repo = sandbox.path().join("repo");
        let workspace_a = repo.join("workspace-a");
        let workspace_b = repo.join("workspace-b");
        let outside = sandbox.path().join("outside");

        fs::create_dir(&repo).expect("create repo dir");
        Command::new("git")
            .env("GIT_CONFIG_GLOBAL", "/dev/null")
            .env("GIT_CONFIG_NOSYSTEM", "1")
            .args(["init"])
            .current_dir(&repo)
            .output()
            .expect("git init");
        fs::create_dir_all(&workspace_a).expect("create workspace a");
        fs::create_dir_all(&workspace_b).expect("create workspace b");
        fs::create_dir_all(&outside).expect("create outside dir");

        let recent_threads = vec![
            thread_metadata(
                workspace_a.to_string_lossy().as_ref(),
                "Investigate realtime startup context",
                "Log the startup context before sending it",
            ),
            thread_metadata(
                workspace_b.to_string_lossy().as_ref(),
                "Trim websocket startup payload",
                "Remove memories from the realtime startup context",
            ),
            thread_metadata(outside.to_string_lossy().as_ref(), "", "Inspect flaky test"),
        ];
        let current_cwd = workspace_a;
        let repo = fs::canonicalize(repo).expect("canonicalize repo");

        let section = build_recent_work_section(current_cwd.as_path(), &recent_threads)
            .expect("recent work section");

        let repo_heading = format!("### Git repo: {}", repo.display());
        let current_ask = format!(
            "- {}: Log the startup context before sending it",
            current_cwd.display()
        );
        let outside_heading = format!("### Directory: {}", outside.display());
        let outside_ask = format!("- {}: Inspect flaky test", outside.display());

        assert!(section.contains(&repo_heading));
        assert!(section.contains("Recent sessions: 2"));
        assert!(section.contains("User asks:"));
        assert!(section.contains(&current_ask));
        assert!(section.contains(&outside_heading));
        assert!(section.contains(&outside_ask));
    }
}
|
||||
|
|
@ -5,6 +5,7 @@ use crate::codex::Session;
|
|||
use crate::default_client::default_headers;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::realtime_context::build_realtime_startup_context;
|
||||
use async_channel::Receiver;
|
||||
use async_channel::Sender;
|
||||
use async_channel::TrySendError;
|
||||
|
|
@ -43,6 +44,7 @@ const AUDIO_IN_QUEUE_CAPACITY: usize = 256;
|
|||
const USER_TEXT_IN_QUEUE_CAPACITY: usize = 64;
|
||||
const HANDOFF_OUT_QUEUE_CAPACITY: usize = 64;
|
||||
const OUTPUT_EVENTS_QUEUE_CAPACITY: usize = 256;
|
||||
const REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET: usize = 5_000;
|
||||
|
||||
pub(crate) struct RealtimeConversationManager {
|
||||
state: Mutex<Option<ConversationState>>,
|
||||
|
|
@ -282,6 +284,13 @@ pub(crate) async fn handle_start(
|
|||
.experimental_realtime_ws_backend_prompt
|
||||
.clone()
|
||||
.unwrap_or(params.prompt);
|
||||
let prompt =
|
||||
match build_realtime_startup_context(sess.as_ref(), REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET)
|
||||
.await
|
||||
{
|
||||
Some(context) => format!("{prompt}\n\n{context}"),
|
||||
None => prompt,
|
||||
};
|
||||
let model = config.experimental_realtime_ws_model.clone();
|
||||
|
||||
let requested_session_id = params
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use chrono::Utc;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::auth::OPENAI_API_KEY_ENV_VAR;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::CodexErrorInfo;
|
||||
use codex_protocol::protocol::ConversationAudioParams;
|
||||
use codex_protocol::protocol::ConversationStartParams;
|
||||
|
|
@ -11,6 +14,7 @@ use codex_protocol::protocol::Op;
|
|||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeConversationRealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
|
|
@ -18,6 +22,7 @@ use core_test_support::responses::start_websocket_server;
|
|||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::streaming_sse::StreamingSseChunk;
|
||||
use core_test_support::streaming_sse::start_streaming_sse_server;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
|
|
@ -25,9 +30,57 @@ use pretty_assertions::assert_eq;
|
|||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::ffi::OsString;
|
||||
use std::fs;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex.";
|
||||
const MEMORY_PROMPT_PHRASE: &str =
|
||||
"You have access to a memory folder with guidance from prior runs.";
|
||||
|
||||
fn websocket_request_text(
|
||||
request: &core_test_support::responses::WebSocketRequest,
|
||||
) -> Option<String> {
|
||||
request.body_json()["item"]["content"][0]["text"]
|
||||
.as_str()
|
||||
.map(str::to_owned)
|
||||
}
|
||||
|
||||
fn websocket_request_instructions(
|
||||
request: &core_test_support::responses::WebSocketRequest,
|
||||
) -> Option<String> {
|
||||
request.body_json()["session"]["instructions"]
|
||||
.as_str()
|
||||
.map(str::to_owned)
|
||||
}
|
||||
|
||||
async fn seed_recent_thread(
|
||||
test: &TestCodex,
|
||||
title: &str,
|
||||
first_user_message: &str,
|
||||
slug: &str,
|
||||
) -> Result<()> {
|
||||
let db = test.codex.state_db().context("state db enabled")?;
|
||||
let thread_id = ThreadId::new();
|
||||
let updated_at = Utc::now();
|
||||
let mut metadata_builder = codex_state::ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
test.codex_home_path()
|
||||
.join(format!("rollout-{thread_id}.jsonl")),
|
||||
updated_at,
|
||||
SessionSource::Cli,
|
||||
);
|
||||
metadata_builder.cwd = test.workspace_path(format!("workspace-{slug}"));
|
||||
metadata_builder.model_provider = Some("test-provider".to_string());
|
||||
metadata_builder.git_branch = Some(format!("branch-{slug}"));
|
||||
let mut metadata = metadata_builder.build("test-provider");
|
||||
metadata.title = title.to_string();
|
||||
metadata.first_user_message = Some(first_user_message.to_string());
|
||||
db.upsert_thread(&metadata).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
|
@ -122,10 +175,9 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
|
|||
connection[0].body_json()["type"].as_str(),
|
||||
Some("session.update")
|
||||
);
|
||||
assert_eq!(
|
||||
connection[0].body_json()["session"]["instructions"].as_str(),
|
||||
Some("backend prompt")
|
||||
);
|
||||
let initial_instructions = websocket_request_instructions(&connection[0])
|
||||
.expect("initial session update instructions");
|
||||
assert!(initial_instructions.starts_with("backend prompt"));
|
||||
assert_eq!(
|
||||
server.handshakes()[1]
|
||||
.header("x-session-id")
|
||||
|
|
@ -452,19 +504,17 @@ async fn conversation_second_start_replaces_runtime() -> Result<()> {
|
|||
let connections = server.connections();
|
||||
assert_eq!(connections.len(), 3);
|
||||
assert_eq!(connections[1].len(), 1);
|
||||
assert_eq!(
|
||||
connections[1][0].body_json()["session"]["instructions"].as_str(),
|
||||
Some("old")
|
||||
);
|
||||
let old_instructions =
|
||||
websocket_request_instructions(&connections[1][0]).expect("old session instructions");
|
||||
assert!(old_instructions.starts_with("old"));
|
||||
assert_eq!(
|
||||
server.handshakes()[1].header("x-session-id").as_deref(),
|
||||
Some("conv_old")
|
||||
);
|
||||
assert_eq!(connections[2].len(), 2);
|
||||
assert_eq!(
|
||||
connections[2][0].body_json()["session"]["instructions"].as_str(),
|
||||
Some("new")
|
||||
);
|
||||
let new_instructions =
|
||||
websocket_request_instructions(&connections[2][0]).expect("new session instructions");
|
||||
assert!(new_instructions.starts_with("new"));
|
||||
assert_eq!(
|
||||
server.handshakes()[2].header("x-session-id").as_deref(),
|
||||
Some("conv_new")
|
||||
|
|
@ -570,9 +620,178 @@ async fn conversation_uses_experimental_realtime_ws_backend_prompt_override() ->
|
|||
|
||||
let connections = server.connections();
|
||||
assert_eq!(connections.len(), 2);
|
||||
let overridden_instructions = websocket_request_instructions(&connections[1][0])
|
||||
.expect("overridden session instructions");
|
||||
assert!(overridden_instructions.starts_with("prompt from config"));
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn conversation_start_injects_startup_context_from_thread_history() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_websocket_server(vec![
|
||||
vec![],
|
||||
vec![vec![json!({
|
||||
"type": "session.updated",
|
||||
"session": { "id": "sess_context", "instructions": "backend prompt" }
|
||||
})]],
|
||||
])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
seed_recent_thread(
|
||||
&test,
|
||||
"Recent work: cleaned up startup flows and reviewed websocket routing.",
|
||||
"Investigate realtime startup context",
|
||||
"latest",
|
||||
)
|
||||
.await?;
|
||||
fs::create_dir_all(test.workspace_path("docs"))?;
|
||||
fs::write(test.workspace_path("README.md"), "workspace marker")?;
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationStart(ConversationStartParams {
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: None,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::SessionUpdated { session_id, .. },
|
||||
}) if session_id == "sess_context" => Some(Ok(())),
|
||||
EventMsg::Error(err) => Some(Err(err.clone())),
|
||||
_ => None,
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|err: ErrorEvent| panic!("conversation start failed: {err:?}"));
|
||||
|
||||
let startup_context_request = server.wait_for_request(1, 0).await;
|
||||
let startup_context = websocket_request_instructions(&startup_context_request)
|
||||
.expect("startup context request should contain instructions");
|
||||
|
||||
assert!(startup_context.contains(STARTUP_CONTEXT_HEADER));
|
||||
assert!(!startup_context.contains("## User"));
|
||||
assert!(startup_context.contains("### "));
|
||||
assert!(startup_context.contains("Recent sessions: 1"));
|
||||
assert!(startup_context.contains("Latest branch: branch-latest"));
|
||||
assert!(startup_context.contains("User asks:"));
|
||||
assert!(startup_context.contains("Investigate realtime startup context"));
|
||||
assert!(startup_context.contains("## Machine / Workspace Map"));
|
||||
assert!(startup_context.contains("README.md"));
|
||||
assert!(!startup_context.contains(MEMORY_PROMPT_PHRASE));
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn conversation_startup_context_falls_back_to_workspace_map() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_websocket_server(vec![
|
||||
vec![],
|
||||
vec![vec![json!({
|
||||
"type": "session.updated",
|
||||
"session": { "id": "sess_workspace", "instructions": "backend prompt" }
|
||||
})]],
|
||||
])
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
fs::create_dir_all(test.workspace_path("codex-rs/core"))?;
|
||||
fs::write(test.workspace_path("notes.txt"), "workspace marker")?;
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationStart(ConversationStartParams {
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: None,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::SessionUpdated { session_id, .. },
|
||||
}) if session_id == "sess_workspace" => Some(Ok(())),
|
||||
EventMsg::Error(err) => Some(Err(err.clone())),
|
||||
_ => None,
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|err: ErrorEvent| panic!("conversation start failed: {err:?}"));
|
||||
|
||||
let startup_context_request = server.wait_for_request(1, 0).await;
|
||||
let startup_context = websocket_request_instructions(&startup_context_request)
|
||||
.expect("startup context request should contain instructions");
|
||||
|
||||
assert!(startup_context.contains(STARTUP_CONTEXT_HEADER));
|
||||
assert!(startup_context.contains("## Machine / Workspace Map"));
|
||||
assert!(startup_context.contains("notes.txt"));
|
||||
assert!(startup_context.contains("codex-rs/"));
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn conversation_startup_context_is_truncated_and_sent_once_per_start() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_websocket_server(vec![
|
||||
vec![],
|
||||
vec![
|
||||
vec![json!({
|
||||
"type": "session.updated",
|
||||
"session": { "id": "sess_truncated", "instructions": "backend prompt" }
|
||||
})],
|
||||
vec![],
|
||||
],
|
||||
])
|
||||
.await;
|
||||
|
||||
let oversized_summary = "recent work ".repeat(3_500);
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build_with_websocket_server(&server).await?;
|
||||
seed_recent_thread(&test, &oversized_summary, "summary", "oversized").await?;
|
||||
fs::write(test.workspace_path("marker.txt"), "marker")?;
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationStart(ConversationStartParams {
|
||||
prompt: "backend prompt".to_string(),
|
||||
session_id: None,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::SessionUpdated { session_id, .. },
|
||||
}) if session_id == "sess_truncated" => Some(Ok(())),
|
||||
EventMsg::Error(err) => Some(Err(err.clone())),
|
||||
_ => None,
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|err: ErrorEvent| panic!("conversation start failed: {err:?}"));
|
||||
|
||||
let startup_context_request = server.wait_for_request(1, 0).await;
|
||||
let startup_context = websocket_request_instructions(&startup_context_request)
|
||||
.expect("startup context request should contain instructions");
|
||||
assert!(startup_context.contains(STARTUP_CONTEXT_HEADER));
|
||||
assert!(startup_context.len() <= 20_500);
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationText(ConversationTextParams {
|
||||
text: "hello".to_string(),
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let explicit_text_request = server.wait_for_request(1, 1).await;
|
||||
assert_eq!(
|
||||
connections[1][0].body_json()["session"]["instructions"].as_str(),
|
||||
Some("prompt from config")
|
||||
websocket_request_text(&explicit_text_request),
|
||||
Some("hello".to_string())
|
||||
);
|
||||
|
||||
server.shutdown().await;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue