feat: add phase 1 mem db (#10634)

- Schema: thread_id (PK, FK to threads.id with cascade delete),
trace_summary, memory_summary, updated_at.
- Migration: creates the table and an index on (updated_at DESC,
thread_id DESC) for efficient recent-first reads.
  - Runtime API (DB-only):
      - `get_thread_memory(thread_id)`: fetch one memory row.
- `upsert_thread_memory(thread_id, trace_summary, memory_summary)`:
insert/update by thread id and always advance updated_at.
- `get_last_n_thread_memories_for_cwd(cwd, n)`: join thread_memory with
threads and return newest n rows for an exact cwd match.
- Model layer: introduced ThreadMemory and row conversion types to keep
query decoding typed and consistent with existing state models.
This commit is contained in:
jif-oai 2026-02-04 21:38:39 +00:00 committed by GitHub
parent 7a253076fe
commit 4922b3e571
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 517 additions and 0 deletions

View file

@ -280,6 +280,60 @@ pub async fn persist_dynamic_tools(
}
}
/// Get memory summaries for a thread id using SQLite.
pub async fn get_thread_memory(
context: Option<&codex_state::StateRuntime>,
thread_id: ThreadId,
stage: &str,
) -> Option<codex_state::ThreadMemory> {
let ctx = context?;
match ctx.get_thread_memory(thread_id).await {
Ok(memory) => memory,
Err(err) => {
warn!("state db get_thread_memory failed during {stage}: {err}");
None
}
}
}
/// Upsert memory summaries for a thread id using SQLite.
pub async fn upsert_thread_memory(
context: Option<&codex_state::StateRuntime>,
thread_id: ThreadId,
trace_summary: &str,
memory_summary: &str,
stage: &str,
) -> Option<codex_state::ThreadMemory> {
let ctx = context?;
match ctx
.upsert_thread_memory(thread_id, trace_summary, memory_summary)
.await
{
Ok(memory) => Some(memory),
Err(err) => {
warn!("state db upsert_thread_memory failed during {stage}: {err}");
None
}
}
}
/// Get the last N memories corresponding to a cwd using an exact path match.
pub async fn get_last_n_thread_memories_for_cwd(
context: Option<&codex_state::StateRuntime>,
cwd: &Path,
n: usize,
stage: &str,
) -> Option<Vec<codex_state::ThreadMemory>> {
let ctx = context?;
match ctx.get_last_n_thread_memories_for_cwd(cwd, n).await {
Ok(memories) => Some(memories),
Err(err) => {
warn!("state db get_last_n_thread_memories_for_cwd failed during {stage}: {err}");
None
}
}
}
/// Reconcile rollout items into SQLite, falling back to scanning the rollout file.
pub async fn reconcile_rollout(
context: Option<&codex_state::StateRuntime>,

View file

@ -0,0 +1,9 @@
CREATE TABLE thread_memory (
thread_id TEXT PRIMARY KEY,
trace_summary TEXT NOT NULL,
memory_summary TEXT NOT NULL,
updated_at INTEGER NOT NULL,
FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE
);
CREATE INDEX idx_thread_memory_updated_at ON thread_memory(updated_at DESC, thread_id DESC);

View file

@ -25,6 +25,7 @@ pub use model::Anchor;
pub use model::BackfillStats;
pub use model::ExtractionOutcome;
pub use model::SortKey;
pub use model::ThreadMemory;
pub use model::ThreadMetadata;
pub use model::ThreadMetadataBuilder;
pub use model::ThreadsPage;

View file

@ -1,9 +1,11 @@
mod log;
mod thread_memory;
mod thread_metadata;
pub use log::LogEntry;
pub use log::LogQuery;
pub use log::LogRow;
pub use thread_memory::ThreadMemory;
pub use thread_metadata::Anchor;
pub use thread_metadata::BackfillStats;
pub use thread_metadata::ExtractionOutcome;
@ -12,6 +14,7 @@ pub use thread_metadata::ThreadMetadata;
pub use thread_metadata::ThreadMetadataBuilder;
pub use thread_metadata::ThreadsPage;
pub(crate) use thread_memory::ThreadMemoryRow;
pub(crate) use thread_metadata::ThreadRow;
pub(crate) use thread_metadata::anchor_from_item;
pub(crate) use thread_metadata::datetime_to_epoch_seconds;

View file

@ -0,0 +1,52 @@
use anyhow::Result;
use chrono::DateTime;
use chrono::Utc;
use codex_protocol::ThreadId;
use sqlx::Row;
use sqlx::sqlite::SqliteRow;
/// Stored memory summaries for a single thread.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadMemory {
pub thread_id: ThreadId,
pub trace_summary: String,
pub memory_summary: String,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug)]
pub(crate) struct ThreadMemoryRow {
thread_id: String,
trace_summary: String,
memory_summary: String,
updated_at: i64,
}
impl ThreadMemoryRow {
pub(crate) fn try_from_row(row: &SqliteRow) -> Result<Self> {
Ok(Self {
thread_id: row.try_get("thread_id")?,
trace_summary: row.try_get("trace_summary")?,
memory_summary: row.try_get("memory_summary")?,
updated_at: row.try_get("updated_at")?,
})
}
}
impl TryFrom<ThreadMemoryRow> for ThreadMemory {
type Error = anyhow::Error;
fn try_from(row: ThreadMemoryRow) -> std::result::Result<Self, Self::Error> {
Ok(Self {
thread_id: ThreadId::try_from(row.thread_id)?,
trace_summary: row.trace_summary,
memory_summary: row.memory_summary,
updated_at: epoch_seconds_to_datetime(row.updated_at)?,
})
}
}
fn epoch_seconds_to_datetime(secs: i64) -> Result<DateTime<Utc>> {
DateTime::<Utc>::from_timestamp(secs, 0)
.ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}"))
}

View file

@ -3,11 +3,13 @@ use crate::LogEntry;
use crate::LogQuery;
use crate::LogRow;
use crate::SortKey;
use crate::ThreadMemory;
use crate::ThreadMetadata;
use crate::ThreadMetadataBuilder;
use crate::ThreadsPage;
use crate::apply_rollout_item;
use crate::migrations::MIGRATOR;
use crate::model::ThreadMemoryRow;
use crate::model::ThreadRow;
use crate::model::anchor_from_item;
use crate::model::datetime_to_epoch_seconds;
@ -153,6 +155,26 @@ ORDER BY position ASC
Ok(Some(tools))
}
/// Get memory summaries for a thread, if present.
pub async fn get_thread_memory(
&self,
thread_id: ThreadId,
) -> anyhow::Result<Option<ThreadMemory>> {
let row = sqlx::query(
r#"
SELECT thread_id, trace_summary, memory_summary, updated_at
FROM thread_memory
WHERE thread_id = ?
"#,
)
.bind(thread_id.to_string())
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from))
.transpose()
}
/// Find a rollout path by thread id using the underlying database.
pub async fn find_rollout_path_by_id(
&self,
@ -405,6 +427,83 @@ ON CONFLICT(id) DO UPDATE SET
Ok(())
}
/// Insert or update memory summaries for a thread.
///
/// This method always advances `updated_at`, even if summaries are unchanged.
pub async fn upsert_thread_memory(
&self,
thread_id: ThreadId,
trace_summary: &str,
memory_summary: &str,
) -> anyhow::Result<ThreadMemory> {
if self.get_thread(thread_id).await?.is_none() {
return Err(anyhow::anyhow!("thread not found: {thread_id}"));
}
let updated_at = Utc::now().timestamp();
sqlx::query(
r#"
INSERT INTO thread_memory (
thread_id,
trace_summary,
memory_summary,
updated_at
) VALUES (?, ?, ?, ?)
ON CONFLICT(thread_id) DO UPDATE SET
trace_summary = excluded.trace_summary,
memory_summary = excluded.memory_summary,
updated_at = CASE
WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1
ELSE excluded.updated_at
END
"#,
)
.bind(thread_id.to_string())
.bind(trace_summary)
.bind(memory_summary)
.bind(updated_at)
.execute(self.pool.as_ref())
.await?;
self.get_thread_memory(thread_id)
.await?
.ok_or_else(|| anyhow::anyhow!("failed to load upserted thread memory: {thread_id}"))
}
/// Get the last `n` memories for threads with an exact cwd match.
pub async fn get_last_n_thread_memories_for_cwd(
&self,
cwd: &Path,
n: usize,
) -> anyhow::Result<Vec<ThreadMemory>> {
if n == 0 {
return Ok(Vec::new());
}
let rows = sqlx::query(
r#"
SELECT
m.thread_id,
m.trace_summary,
m.memory_summary,
m.updated_at
FROM thread_memory AS m
INNER JOIN threads AS t ON t.id = m.thread_id
WHERE t.cwd = ?
ORDER BY m.updated_at DESC, m.thread_id DESC
LIMIT ?
"#,
)
.bind(cwd.display().to_string())
.bind(n as i64)
.fetch_all(self.pool.as_ref())
.await?;
rows.into_iter()
.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from))
.collect()
}
/// Persist dynamic tools for a thread if none have been stored yet.
///
/// Dynamic tools are defined at thread start and should not change afterward.
@ -771,11 +870,20 @@ mod tests {
use super::STATE_DB_FILENAME;
use super::STATE_DB_VERSION;
use super::StateRuntime;
use super::ThreadMetadata;
use super::state_db_filename;
use chrono::DateTime;
use chrono::Utc;
use codex_protocol::ThreadId;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
use pretty_assertions::assert_eq;
use sqlx::Row;
use std::path::Path;
use std::path::PathBuf;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use uuid::Uuid;
fn unique_temp_dir() -> PathBuf {
let nanos = SystemTime::now()
@ -858,4 +966,294 @@ mod tests {
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn upsert_and_get_thread_memory() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.join("a"));
runtime
.upsert_thread(&metadata)
.await
.expect("upsert thread");
assert_eq!(
runtime
.get_thread_memory(thread_id)
.await
.expect("get memory before insert"),
None
);
let inserted = runtime
.upsert_thread_memory(thread_id, "trace one", "memory one")
.await
.expect("upsert memory");
assert_eq!(inserted.thread_id, thread_id);
assert_eq!(inserted.trace_summary, "trace one");
assert_eq!(inserted.memory_summary, "memory one");
let updated = runtime
.upsert_thread_memory(thread_id, "trace two", "memory two")
.await
.expect("update memory");
assert_eq!(updated.thread_id, thread_id);
assert_eq!(updated.trace_summary, "trace two");
assert_eq!(updated.memory_summary, "memory two");
assert!(
updated.updated_at >= inserted.updated_at,
"updated_at should not move backward"
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_last_n_thread_memories_for_cwd_matches_exactly() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let cwd_a = codex_home.join("workspace-a");
let cwd_b = codex_home.join("workspace-b");
let t1 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let t2 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let t3 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
runtime
.upsert_thread(&test_thread_metadata(&codex_home, t1, cwd_a.clone()))
.await
.expect("upsert thread t1");
runtime
.upsert_thread(&test_thread_metadata(&codex_home, t2, cwd_a.clone()))
.await
.expect("upsert thread t2");
runtime
.upsert_thread(&test_thread_metadata(&codex_home, t3, cwd_b.clone()))
.await
.expect("upsert thread t3");
let first = runtime
.upsert_thread_memory(t1, "trace-1", "memory-1")
.await
.expect("upsert t1 memory");
runtime
.upsert_thread_memory(t2, "trace-2", "memory-2")
.await
.expect("upsert t2 memory");
runtime
.upsert_thread_memory(t3, "trace-3", "memory-3")
.await
.expect("upsert t3 memory");
// Ensure deterministic ordering even when updates happen in the same second.
runtime
.upsert_thread_memory(t1, "trace-1b", "memory-1b")
.await
.expect("upsert t1 memory again");
let cwd_a_memories = runtime
.get_last_n_thread_memories_for_cwd(cwd_a.as_path(), 2)
.await
.expect("list cwd a memories");
assert_eq!(cwd_a_memories.len(), 2);
assert_eq!(cwd_a_memories[0].thread_id, t1);
assert_eq!(cwd_a_memories[0].trace_summary, "trace-1b");
assert_eq!(cwd_a_memories[0].memory_summary, "memory-1b");
assert_eq!(cwd_a_memories[1].thread_id, t2);
assert!(cwd_a_memories[0].updated_at >= first.updated_at);
let cwd_b_memories = runtime
.get_last_n_thread_memories_for_cwd(cwd_b.as_path(), 10)
.await
.expect("list cwd b memories");
assert_eq!(cwd_b_memories.len(), 1);
assert_eq!(cwd_b_memories[0].thread_id, t3);
let none = runtime
.get_last_n_thread_memories_for_cwd(codex_home.join("missing").as_path(), 10)
.await
.expect("list missing cwd memories");
assert_eq!(none, Vec::new());
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn upsert_thread_memory_errors_for_unknown_thread() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let unknown_thread_id =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let err = runtime
.upsert_thread_memory(unknown_thread_id, "trace", "memory")
.await
.expect_err("unknown thread should fail");
assert!(
err.to_string().contains("thread not found"),
"error should mention missing thread: {err}"
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_last_n_thread_memories_for_cwd_zero_returns_empty() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let cwd = codex_home.join("workspace");
runtime
.upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd.clone()))
.await
.expect("upsert thread");
runtime
.upsert_thread_memory(thread_id, "trace", "memory")
.await
.expect("upsert memory");
let memories = runtime
.get_last_n_thread_memories_for_cwd(cwd.as_path(), 0)
.await
.expect("query memories");
assert_eq!(memories, Vec::new());
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_last_n_thread_memories_for_cwd_does_not_prefix_match() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let cwd_exact = codex_home.join("workspace");
let cwd_prefix = codex_home.join("workspace-child");
let t_exact = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let t_prefix = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
t_exact,
cwd_exact.clone(),
))
.await
.expect("upsert exact thread");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
t_prefix,
cwd_prefix.clone(),
))
.await
.expect("upsert prefix thread");
runtime
.upsert_thread_memory(t_exact, "trace-exact", "memory-exact")
.await
.expect("upsert exact memory");
runtime
.upsert_thread_memory(t_prefix, "trace-prefix", "memory-prefix")
.await
.expect("upsert prefix memory");
let exact_only = runtime
.get_last_n_thread_memories_for_cwd(cwd_exact.as_path(), 10)
.await
.expect("query exact cwd");
assert_eq!(exact_only.len(), 1);
assert_eq!(exact_only[0].thread_id, t_exact);
assert_eq!(exact_only[0].memory_summary, "memory-exact");
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn deleting_thread_cascades_thread_memory() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let cwd = codex_home.join("workspace");
runtime
.upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd))
.await
.expect("upsert thread");
runtime
.upsert_thread_memory(thread_id, "trace", "memory")
.await
.expect("upsert memory");
let count_before =
sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("count before delete")
.try_get::<i64, _>("count")
.expect("count value");
assert_eq!(count_before, 1);
sqlx::query("DELETE FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.execute(runtime.pool.as_ref())
.await
.expect("delete thread");
let count_after =
sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("count after delete")
.try_get::<i64, _>("count")
.expect("count value");
assert_eq!(count_after, 0);
assert_eq!(
runtime
.get_thread_memory(thread_id)
.await
.expect("get memory after delete"),
None
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
fn test_thread_metadata(
codex_home: &Path,
thread_id: ThreadId,
cwd: PathBuf,
) -> ThreadMetadata {
let now = DateTime::<Utc>::from_timestamp(1_700_000_000, 0).expect("timestamp");
ThreadMetadata {
id: thread_id,
rollout_path: codex_home.join(format!("rollout-{thread_id}.jsonl")),
created_at: now,
updated_at: now,
source: "cli".to_string(),
model_provider: "test-provider".to_string(),
cwd,
title: String::new(),
sandbox_policy: crate::extract::enum_to_string(&SandboxPolicy::ReadOnly),
approval_mode: crate::extract::enum_to_string(&AskForApproval::OnRequest),
tokens_used: 0,
has_user_event: true,
archived_at: None,
git_sha: None,
git_branch: None,
git_origin_url: None,
}
}
}