diff --git a/codex-rs/core/src/state_db.rs b/codex-rs/core/src/state_db.rs index 1cfb8f156..c36d8afe4 100644 --- a/codex-rs/core/src/state_db.rs +++ b/codex-rs/core/src/state_db.rs @@ -280,6 +280,60 @@ pub async fn persist_dynamic_tools( } } +/// Get memory summaries for a thread id using SQLite. +pub async fn get_thread_memory( + context: Option<&codex_state::StateRuntime>, + thread_id: ThreadId, + stage: &str, +) -> Option { + let ctx = context?; + match ctx.get_thread_memory(thread_id).await { + Ok(memory) => memory, + Err(err) => { + warn!("state db get_thread_memory failed during {stage}: {err}"); + None + } + } +} + +/// Upsert memory summaries for a thread id using SQLite. +pub async fn upsert_thread_memory( + context: Option<&codex_state::StateRuntime>, + thread_id: ThreadId, + trace_summary: &str, + memory_summary: &str, + stage: &str, +) -> Option { + let ctx = context?; + match ctx + .upsert_thread_memory(thread_id, trace_summary, memory_summary) + .await + { + Ok(memory) => Some(memory), + Err(err) => { + warn!("state db upsert_thread_memory failed during {stage}: {err}"); + None + } + } +} + +/// Get the last N memories corresponding to a cwd using an exact path match. +pub async fn get_last_n_thread_memories_for_cwd( + context: Option<&codex_state::StateRuntime>, + cwd: &Path, + n: usize, + stage: &str, +) -> Option> { + let ctx = context?; + match ctx.get_last_n_thread_memories_for_cwd(cwd, n).await { + Ok(memories) => Some(memories), + Err(err) => { + warn!("state db get_last_n_thread_memories_for_cwd failed during {stage}: {err}"); + None + } + } +} + /// Reconcile rollout items into SQLite, falling back to scanning the rollout file. pub async fn reconcile_rollout( context: Option<&codex_state::StateRuntime>, diff --git a/codex-rs/state/migrations/0006_thread_memory.sql b/codex-rs/state/migrations/0006_thread_memory.sql new file mode 100644 index 000000000..fe90ab667 --- /dev/null +++ b/codex-rs/state/migrations/0006_thread_memory.sql @@ -0,0 +1,9 @@ +CREATE TABLE thread_memory ( + thread_id TEXT PRIMARY KEY, + trace_summary TEXT NOT NULL, + memory_summary TEXT NOT NULL, + updated_at INTEGER NOT NULL, + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE +); + +CREATE INDEX idx_thread_memory_updated_at ON thread_memory(updated_at DESC, thread_id DESC); diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index 2d37ecee9..7df337409 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -25,6 +25,7 @@ pub use model::Anchor; pub use model::BackfillStats; pub use model::ExtractionOutcome; pub use model::SortKey; +pub use model::ThreadMemory; pub use model::ThreadMetadata; pub use model::ThreadMetadataBuilder; pub use model::ThreadsPage; diff --git a/codex-rs/state/src/model/mod.rs b/codex-rs/state/src/model/mod.rs index bd615d756..d937019ec 100644 --- a/codex-rs/state/src/model/mod.rs +++ b/codex-rs/state/src/model/mod.rs @@ -1,9 +1,11 @@ mod log; +mod thread_memory; mod thread_metadata; pub use log::LogEntry; pub use log::LogQuery; pub use log::LogRow; +pub use thread_memory::ThreadMemory; pub use thread_metadata::Anchor; pub use thread_metadata::BackfillStats; pub use thread_metadata::ExtractionOutcome; @@ -12,6 +14,7 @@ pub use thread_metadata::ThreadMetadata; pub use thread_metadata::ThreadMetadataBuilder; pub use thread_metadata::ThreadsPage; +pub(crate) use thread_memory::ThreadMemoryRow; pub(crate) use thread_metadata::ThreadRow; pub(crate) use thread_metadata::anchor_from_item; pub(crate) use thread_metadata::datetime_to_epoch_seconds; diff --git a/codex-rs/state/src/model/thread_memory.rs b/codex-rs/state/src/model/thread_memory.rs new file mode 100644 index 000000000..6e3a34c21 --- /dev/null +++ b/codex-rs/state/src/model/thread_memory.rs @@ -0,0 +1,52 @@ +use anyhow::Result; +use chrono::DateTime; +use chrono::Utc; +use codex_protocol::ThreadId; +use sqlx::Row; +use sqlx::sqlite::SqliteRow; + +/// Stored memory summaries for a single thread. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ThreadMemory { + pub thread_id: ThreadId, + pub trace_summary: String, + pub memory_summary: String, + pub updated_at: DateTime, +} + +#[derive(Debug)] +pub(crate) struct ThreadMemoryRow { + thread_id: String, + trace_summary: String, + memory_summary: String, + updated_at: i64, +} + +impl ThreadMemoryRow { + pub(crate) fn try_from_row(row: &SqliteRow) -> Result { + Ok(Self { + thread_id: row.try_get("thread_id")?, + trace_summary: row.try_get("trace_summary")?, + memory_summary: row.try_get("memory_summary")?, + updated_at: row.try_get("updated_at")?, + }) + } +} + +impl TryFrom for ThreadMemory { + type Error = anyhow::Error; + + fn try_from(row: ThreadMemoryRow) -> std::result::Result { + Ok(Self { + thread_id: ThreadId::try_from(row.thread_id)?, + trace_summary: row.trace_summary, + memory_summary: row.memory_summary, + updated_at: epoch_seconds_to_datetime(row.updated_at)?, + }) + } +} + +fn epoch_seconds_to_datetime(secs: i64) -> Result> { + DateTime::::from_timestamp(secs, 0) + .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) +} diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index f4c2d76e8..25c666fa6 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -3,11 +3,13 @@ use crate::LogEntry; use crate::LogQuery; use crate::LogRow; use crate::SortKey; +use crate::ThreadMemory; use crate::ThreadMetadata; use crate::ThreadMetadataBuilder; use crate::ThreadsPage; use crate::apply_rollout_item; use crate::migrations::MIGRATOR; +use crate::model::ThreadMemoryRow; use crate::model::ThreadRow; use crate::model::anchor_from_item; use crate::model::datetime_to_epoch_seconds; @@ -153,6 +155,26 @@ ORDER BY position ASC Ok(Some(tools)) } + /// Get memory summaries for a thread, if present. + pub async fn get_thread_memory( + &self, + thread_id: ThreadId, + ) -> anyhow::Result> { + let row = sqlx::query( + r#" +SELECT thread_id, trace_summary, memory_summary, updated_at +FROM thread_memory +WHERE thread_id = ? + "#, + ) + .bind(thread_id.to_string()) + .fetch_optional(self.pool.as_ref()) + .await?; + + row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) + .transpose() + } + /// Find a rollout path by thread id using the underlying database. pub async fn find_rollout_path_by_id( &self, @@ -405,6 +427,83 @@ ON CONFLICT(id) DO UPDATE SET Ok(()) } + /// Insert or update memory summaries for a thread. + /// + /// This method always advances `updated_at`, even if summaries are unchanged. + pub async fn upsert_thread_memory( + &self, + thread_id: ThreadId, + trace_summary: &str, + memory_summary: &str, + ) -> anyhow::Result { + if self.get_thread(thread_id).await?.is_none() { + return Err(anyhow::anyhow!("thread not found: {thread_id}")); + } + + let updated_at = Utc::now().timestamp(); + sqlx::query( + r#" +INSERT INTO thread_memory ( + thread_id, + trace_summary, + memory_summary, + updated_at +) VALUES (?, ?, ?, ?) +ON CONFLICT(thread_id) DO UPDATE SET + trace_summary = excluded.trace_summary, + memory_summary = excluded.memory_summary, + updated_at = CASE + WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1 + ELSE excluded.updated_at + END + "#, + ) + .bind(thread_id.to_string()) + .bind(trace_summary) + .bind(memory_summary) + .bind(updated_at) + .execute(self.pool.as_ref()) + .await?; + + self.get_thread_memory(thread_id) + .await? + .ok_or_else(|| anyhow::anyhow!("failed to load upserted thread memory: {thread_id}")) + } + + /// Get the last `n` memories for threads with an exact cwd match. + pub async fn get_last_n_thread_memories_for_cwd( + &self, + cwd: &Path, + n: usize, + ) -> anyhow::Result> { + if n == 0 { + return Ok(Vec::new()); + } + + let rows = sqlx::query( + r#" +SELECT + m.thread_id, + m.trace_summary, + m.memory_summary, + m.updated_at +FROM thread_memory AS m +INNER JOIN threads AS t ON t.id = m.thread_id +WHERE t.cwd = ? +ORDER BY m.updated_at DESC, m.thread_id DESC +LIMIT ? + "#, + ) + .bind(cwd.display().to_string()) + .bind(n as i64) + .fetch_all(self.pool.as_ref()) + .await?; + + rows.into_iter() + .map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from)) + .collect() + } + /// Persist dynamic tools for a thread if none have been stored yet. /// /// Dynamic tools are defined at thread start and should not change afterward. @@ -771,11 +870,20 @@ mod tests { use super::STATE_DB_FILENAME; use super::STATE_DB_VERSION; use super::StateRuntime; + use super::ThreadMetadata; use super::state_db_filename; + use chrono::DateTime; + use chrono::Utc; + use codex_protocol::ThreadId; + use codex_protocol::protocol::AskForApproval; + use codex_protocol::protocol::SandboxPolicy; use pretty_assertions::assert_eq; + use sqlx::Row; + use std::path::Path; use std::path::PathBuf; use std::time::SystemTime; use std::time::UNIX_EPOCH; + use uuid::Uuid; fn unique_temp_dir() -> PathBuf { let nanos = SystemTime::now() @@ -858,4 +966,294 @@ mod tests { let _ = tokio::fs::remove_dir_all(codex_home).await; } + + #[tokio::test] + async fn upsert_and_get_thread_memory() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.join("a")); + runtime + .upsert_thread(&metadata) + .await + .expect("upsert thread"); + + assert_eq!( + runtime + .get_thread_memory(thread_id) + .await + .expect("get memory before insert"), + None + ); + + let inserted = runtime + .upsert_thread_memory(thread_id, "trace one", "memory one") + .await + .expect("upsert memory"); + assert_eq!(inserted.thread_id, thread_id); + assert_eq!(inserted.trace_summary, "trace one"); + assert_eq!(inserted.memory_summary, "memory one"); + + let updated = runtime + .upsert_thread_memory(thread_id, "trace two", "memory two") + .await + .expect("update memory"); + assert_eq!(updated.thread_id, thread_id); + assert_eq!(updated.trace_summary, "trace two"); + assert_eq!(updated.memory_summary, "memory two"); + assert!( + updated.updated_at >= inserted.updated_at, + "updated_at should not move backward" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_last_n_thread_memories_for_cwd_matches_exactly() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let cwd_a = codex_home.join("workspace-a"); + let cwd_b = codex_home.join("workspace-b"); + let t1 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let t2 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let t3 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, t1, cwd_a.clone())) + .await + .expect("upsert thread t1"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, t2, cwd_a.clone())) + .await + .expect("upsert thread t2"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, t3, cwd_b.clone())) + .await + .expect("upsert thread t3"); + + let first = runtime + .upsert_thread_memory(t1, "trace-1", "memory-1") + .await + .expect("upsert t1 memory"); + runtime + .upsert_thread_memory(t2, "trace-2", "memory-2") + .await + .expect("upsert t2 memory"); + runtime + .upsert_thread_memory(t3, "trace-3", "memory-3") + .await + .expect("upsert t3 memory"); + // Ensure deterministic ordering even when updates happen in the same second. + runtime + .upsert_thread_memory(t1, "trace-1b", "memory-1b") + .await + .expect("upsert t1 memory again"); + + let cwd_a_memories = runtime + .get_last_n_thread_memories_for_cwd(cwd_a.as_path(), 2) + .await + .expect("list cwd a memories"); + assert_eq!(cwd_a_memories.len(), 2); + assert_eq!(cwd_a_memories[0].thread_id, t1); + assert_eq!(cwd_a_memories[0].trace_summary, "trace-1b"); + assert_eq!(cwd_a_memories[0].memory_summary, "memory-1b"); + assert_eq!(cwd_a_memories[1].thread_id, t2); + assert!(cwd_a_memories[0].updated_at >= first.updated_at); + + let cwd_b_memories = runtime + .get_last_n_thread_memories_for_cwd(cwd_b.as_path(), 10) + .await + .expect("list cwd b memories"); + assert_eq!(cwd_b_memories.len(), 1); + assert_eq!(cwd_b_memories[0].thread_id, t3); + + let none = runtime + .get_last_n_thread_memories_for_cwd(codex_home.join("missing").as_path(), 10) + .await + .expect("list missing cwd memories"); + assert_eq!(none, Vec::new()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn upsert_thread_memory_errors_for_unknown_thread() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let unknown_thread_id = + ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let err = runtime + .upsert_thread_memory(unknown_thread_id, "trace", "memory") + .await + .expect_err("unknown thread should fail"); + assert!( + err.to_string().contains("thread not found"), + "error should mention missing thread: {err}" + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_last_n_thread_memories_for_cwd_zero_returns_empty() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd.clone())) + .await + .expect("upsert thread"); + runtime + .upsert_thread_memory(thread_id, "trace", "memory") + .await + .expect("upsert memory"); + + let memories = runtime + .get_last_n_thread_memories_for_cwd(cwd.as_path(), 0) + .await + .expect("query memories"); + assert_eq!(memories, Vec::new()); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn get_last_n_thread_memories_for_cwd_does_not_prefix_match() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let cwd_exact = codex_home.join("workspace"); + let cwd_prefix = codex_home.join("workspace-child"); + let t_exact = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let t_prefix = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + t_exact, + cwd_exact.clone(), + )) + .await + .expect("upsert exact thread"); + runtime + .upsert_thread(&test_thread_metadata( + &codex_home, + t_prefix, + cwd_prefix.clone(), + )) + .await + .expect("upsert prefix thread"); + runtime + .upsert_thread_memory(t_exact, "trace-exact", "memory-exact") + .await + .expect("upsert exact memory"); + runtime + .upsert_thread_memory(t_prefix, "trace-prefix", "memory-prefix") + .await + .expect("upsert prefix memory"); + + let exact_only = runtime + .get_last_n_thread_memories_for_cwd(cwd_exact.as_path(), 10) + .await + .expect("query exact cwd"); + assert_eq!(exact_only.len(), 1); + assert_eq!(exact_only[0].thread_id, t_exact); + assert_eq!(exact_only[0].memory_summary, "memory-exact"); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn deleting_thread_cascades_thread_memory() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None) + .await + .expect("initialize runtime"); + + let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id"); + let cwd = codex_home.join("workspace"); + runtime + .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd)) + .await + .expect("upsert thread"); + runtime + .upsert_thread_memory(thread_id, "trace", "memory") + .await + .expect("upsert memory"); + + let count_before = + sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?") + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("count before delete") + .try_get::("count") + .expect("count value"); + assert_eq!(count_before, 1); + + sqlx::query("DELETE FROM threads WHERE id = ?") + .bind(thread_id.to_string()) + .execute(runtime.pool.as_ref()) + .await + .expect("delete thread"); + + let count_after = + sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?") + .bind(thread_id.to_string()) + .fetch_one(runtime.pool.as_ref()) + .await + .expect("count after delete") + .try_get::("count") + .expect("count value"); + assert_eq!(count_after, 0); + assert_eq!( + runtime + .get_thread_memory(thread_id) + .await + .expect("get memory after delete"), + None + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + fn test_thread_metadata( + codex_home: &Path, + thread_id: ThreadId, + cwd: PathBuf, + ) -> ThreadMetadata { + let now = DateTime::::from_timestamp(1_700_000_000, 0).expect("timestamp"); + ThreadMetadata { + id: thread_id, + rollout_path: codex_home.join(format!("rollout-{thread_id}.jsonl")), + created_at: now, + updated_at: now, + source: "cli".to_string(), + model_provider: "test-provider".to_string(), + cwd, + title: String::new(), + sandbox_policy: crate::extract::enum_to_string(&SandboxPolicy::ReadOnly), + approval_mode: crate::extract::enum_to_string(&AskForApproval::OnRequest), + tokens_used: 0, + has_user_event: true, + archived_at: None, + git_sha: None, + git_branch: None, + git_origin_url: None, + } + } }