core-agent-ide/codex-rs/state/src/runtime.rs
jif-oai 1d5eba0090
feat: align memory phase 1 and make it stronger (#11300)
## Align with the new phase-1 design

Basically we now run phase 1 in parallel by considering:
* Max 64 rollouts
* Max 1 month old
* Consider the most recent first

This PR also adds stronger parallelization capabilities by detecting
stale jobs, retry policies, ownership of computation to prevent double
computations etc etc
2026-02-10 13:42:09 +00:00

2278 lines
73 KiB
Rust

use crate::DB_ERROR_METRIC;
use crate::LogEntry;
use crate::LogQuery;
use crate::LogRow;
use crate::SortKey;
use crate::ThreadMemory;
use crate::ThreadMetadata;
use crate::ThreadMetadataBuilder;
use crate::ThreadsPage;
use crate::apply_rollout_item;
use crate::migrations::MIGRATOR;
use crate::model::ThreadMemoryRow;
use crate::model::ThreadRow;
use crate::model::anchor_from_item;
use crate::model::datetime_to_epoch_seconds;
use crate::paths::file_modified_time_utc;
use chrono::DateTime;
use chrono::Utc;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::protocol::RolloutItem;
use log::LevelFilter;
use serde_json::Value;
use sqlx::ConnectOptions;
use sqlx::QueryBuilder;
use sqlx::Row;
use sqlx::Sqlite;
use sqlx::SqlitePool;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::sqlite::SqliteJournalMode;
use sqlx::sqlite::SqlitePoolOptions;
use sqlx::sqlite::SqliteSynchronous;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
use uuid::Uuid;
/// Base filename (without version suffix or extension) of the state database.
pub const STATE_DB_FILENAME: &str = "state";
/// Schema version baked into the on-disk filename; bumping it makes older
/// database files "legacy" and eligible for cleanup at startup.
pub const STATE_DB_VERSION: u32 = 4;
// Scope kind used when thread memory is keyed by the working directory.
const MEMORY_SCOPE_KIND_CWD: &str = "cwd";
// Metric name emitted while opening/creating the database.
const METRIC_DB_INIT: &str = "codex.db.init";
/// Handle to the SQLite-backed state store for a Codex home directory.
///
/// Cloning is cheap: the connection pool is shared behind an `Arc`.
#[derive(Clone)]
pub struct StateRuntime {
// Root directory that contains the state database file.
codex_home: PathBuf,
// Provider name used when a thread's rollout does not specify one.
default_provider: String,
// Shared SQLite connection pool.
pool: Arc<sqlx::SqlitePool>,
}
/// Outcome of attempting to claim a phase-1 memory extraction job.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Phase1JobClaimOutcome {
/// The job was claimed; `ownership_token` must be presented to finalize
/// (`mark_phase1_job_succeeded`/`failed`) or renew the claim.
Claimed { ownership_token: String },
/// A prior attempt failed terminally; callers should not retry.
SkippedTerminalFailure,
/// A prior run already succeeded for the same or newer source data.
SkippedUpToDate,
/// Another owner holds a non-expired `running` claim.
SkippedRunning,
}
impl StateRuntime {
/// Initialize the state runtime using the provided Codex home and default provider.
///
/// This opens (and migrates) the versioned SQLite database inside
/// `codex_home` (see `state_db_path`), after removing database files left
/// over from older schema versions. Emits `codex.db.init` counters for the
/// open/created/error outcomes when an OTel manager is provided.
pub async fn init(
codex_home: PathBuf,
default_provider: String,
otel: Option<OtelManager>,
) -> anyhow::Result<Arc<Self>> {
tokio::fs::create_dir_all(&codex_home).await?;
// Best-effort cleanup of older-version database files; never fails init.
remove_legacy_state_files(&codex_home).await;
let state_path = state_db_path(codex_home.as_path());
// Checked before opening so we can tell "opened" apart from "created".
let existed = tokio::fs::try_exists(&state_path).await.unwrap_or(false);
let pool = match open_sqlite(&state_path).await {
Ok(db) => Arc::new(db),
Err(err) => {
warn!("failed to open state db at {}: {err}", state_path.display());
if let Some(otel) = otel.as_ref() {
otel.counter(METRIC_DB_INIT, 1, &[("status", "open_error")]);
}
return Err(err);
}
};
if let Some(otel) = otel.as_ref() {
otel.counter(METRIC_DB_INIT, 1, &[("status", "opened")]);
}
let runtime = Arc::new(Self {
pool,
codex_home,
default_provider,
});
// Extra counter when the database file did not exist before this call.
if !existed && let Some(otel) = otel.as_ref() {
otel.counter(METRIC_DB_INIT, 1, &[("status", "created")]);
}
Ok(runtime)
}
/// Return the configured Codex home directory for this runtime.
pub fn codex_home(&self) -> &Path {
    &self.codex_home
}
/// Get persisted rollout metadata backfill state.
///
/// Lazily creates the singleton `backfill_state` row (id = 1) first, so the
/// subsequent `fetch_one` cannot fail with "no rows".
pub async fn get_backfill_state(&self) -> anyhow::Result<crate::BackfillState> {
self.ensure_backfill_state_row().await?;
let row = sqlx::query(
r#"
SELECT status, last_watermark, last_success_at
FROM backfill_state
WHERE id = 1
"#,
)
.fetch_one(self.pool.as_ref())
.await?;
crate::BackfillState::try_from_row(&row)
}
/// Mark rollout metadata backfill as running.
///
/// Only updates `status` and `updated_at`; any existing watermark is kept.
pub async fn mark_backfill_running(&self) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
}
/// Persist rollout metadata backfill progress.
///
/// Records `watermark` as the latest processed position while keeping the
/// status at `Running`.
pub async fn checkpoint_backfill(&self, watermark: &str) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
sqlx::query(
r#"
UPDATE backfill_state
SET status = ?, last_watermark = ?, updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Running.as_str())
.bind(watermark)
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
}
/// Mark rollout metadata backfill as complete.
///
/// `last_watermark = None` preserves the previously stored watermark
/// (COALESCE in the statement); `last_success_at` and `updated_at` are both
/// set to the current time.
pub async fn mark_backfill_complete(&self, last_watermark: Option<&str>) -> anyhow::Result<()> {
self.ensure_backfill_state_row().await?;
let now = Utc::now().timestamp();
sqlx::query(
r#"
UPDATE backfill_state
SET
status = ?,
last_watermark = COALESCE(?, last_watermark),
last_success_at = ?,
updated_at = ?
WHERE id = 1
"#,
)
.bind(crate::BackfillStatus::Complete.as_str())
.bind(last_watermark)
.bind(now)
.bind(now)
.execute(self.pool.as_ref())
.await?;
Ok(())
}
/// Load thread metadata by id using the underlying database.
///
/// Returns `Ok(None)` when no row exists; row decoding errors propagate.
pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
let row = sqlx::query(
r#"
SELECT
id,
rollout_path,
created_at,
updated_at,
source,
model_provider,
cwd,
cli_version,
title,
sandbox_policy,
approval_mode,
tokens_used,
first_user_message,
archived_at,
git_sha,
git_branch,
git_origin_url
FROM threads
WHERE id = ?
"#,
)
.bind(id.to_string())
.fetch_optional(self.pool.as_ref())
.await?;
// Decode the raw row into the typed metadata, keeping Option-ness intact.
row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.transpose()
}
/// Get dynamic tools for a thread, if present.
///
/// Returns `None` when no tools were persisted for the thread; rows come
/// back in their stored `position` order.
pub async fn get_dynamic_tools(
    &self,
    thread_id: ThreadId,
) -> anyhow::Result<Option<Vec<DynamicToolSpec>>> {
    let rows = sqlx::query(
        r#"
SELECT name, description, input_schema
FROM thread_dynamic_tools
WHERE thread_id = ?
ORDER BY position ASC
"#,
    )
    .bind(thread_id.to_string())
    .fetch_all(self.pool.as_ref())
    .await?;
    if rows.is_empty() {
        return Ok(None);
    }
    // Decode every row; the first decode/parse failure aborts the whole call.
    let tools = rows
        .iter()
        .map(|row| {
            let raw_schema: String = row.try_get("input_schema")?;
            Ok(DynamicToolSpec {
                name: row.try_get("name")?,
                description: row.try_get("description")?,
                input_schema: serde_json::from_str::<Value>(raw_schema.as_str())?,
            })
        })
        .collect::<anyhow::Result<Vec<_>>>()?;
    Ok(Some(tools))
}
/// Get memory summaries for a thread, if present.
///
/// A thread may have memory rows in several scopes; this returns the single
/// most recently updated one (ties broken by scope_kind/scope_key, DESC).
pub async fn get_thread_memory(
&self,
thread_id: ThreadId,
) -> anyhow::Result<Option<ThreadMemory>> {
let row = sqlx::query(
r#"
SELECT
thread_id,
scope_kind,
scope_key,
raw_memory,
memory_summary,
updated_at,
last_used_at,
used_count,
invalidated_at,
invalid_reason
FROM thread_memory
WHERE thread_id = ?
ORDER BY updated_at DESC, scope_kind DESC, scope_key DESC
LIMIT 1
"#,
)
.bind(thread_id.to_string())
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from))
.transpose()
}
/// Find a rollout path by thread id using the underlying database.
///
/// `archived_only`: `Some(true)` restricts the lookup to archived threads,
/// `Some(false)` to non-archived threads, `None` applies no archive filter.
///
/// Returns `Ok(None)` when no matching row exists. A row whose
/// `rollout_path` column cannot be decoded now surfaces as an error instead
/// of being silently treated as "not found" (the previous `.ok()` swallowed
/// the decode failure).
pub async fn find_rollout_path_by_id(
    &self,
    id: ThreadId,
    archived_only: Option<bool>,
) -> anyhow::Result<Option<PathBuf>> {
    let mut builder =
        QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
    builder.push_bind(id.to_string());
    match archived_only {
        Some(true) => {
            builder.push(" AND archived = 1");
        }
        Some(false) => {
            builder.push(" AND archived = 0");
        }
        None => {}
    }
    let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
    // Propagate column-decode errors; only a genuinely missing row is `None`.
    Ok(row
        .map(|r| r.try_get::<String, _>("rollout_path"))
        .transpose()?
        .map(PathBuf::from))
}
/// List threads using the underlying database.
///
/// Fetches `page_size + 1` rows so the presence of a further page can be
/// detected without a second query: when the extra row comes back it is
/// popped and the anchor of the last retained item becomes `next_anchor`.
/// `num_scanned_rows` reports how many rows the query returned before the
/// sentinel pop.
pub async fn list_threads(
&self,
page_size: usize,
anchor: Option<&crate::Anchor>,
sort_key: crate::SortKey,
allowed_sources: &[String],
model_providers: Option<&[String]>,
archived_only: bool,
) -> anyhow::Result<crate::ThreadsPage> {
// +1 sentinel row for next-page detection; saturating avoids overflow.
let limit = page_size.saturating_add(1);
let mut builder = QueryBuilder::<Sqlite>::new(
r#"
SELECT
id,
rollout_path,
created_at,
updated_at,
source,
model_provider,
cwd,
cli_version,
title,
sandbox_policy,
approval_mode,
tokens_used,
first_user_message,
archived_at,
git_sha,
git_branch,
git_origin_url
FROM threads
"#,
);
push_thread_filters(
&mut builder,
archived_only,
allowed_sources,
model_providers,
anchor,
sort_key,
);
push_thread_order_and_limit(&mut builder, sort_key, limit);
let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
let mut items = rows
.into_iter()
.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
.collect::<Result<Vec<_>, _>>()?;
// Counted before popping the sentinel row below.
let num_scanned_rows = items.len();
let next_anchor = if items.len() > page_size {
items.pop();
items
.last()
.and_then(|item| anchor_from_item(item, sort_key))
} else {
None
};
Ok(ThreadsPage {
items,
next_anchor,
num_scanned_rows,
})
}
/// Insert one log entry into the logs table.
///
/// Delegates to [`Self::insert_logs`]; `slice::from_ref` builds the
/// one-element slice without cloning the entry.
pub async fn insert_log(&self, entry: &LogEntry) -> anyhow::Result<()> {
self.insert_logs(std::slice::from_ref(entry)).await
}
/// Insert a batch of log entries into the logs table.
///
/// Rows are written with multi-row `INSERT`s. Large batches are split into
/// chunks so a single statement never exceeds SQLite's bound-parameter
/// limit (`SQLITE_MAX_VARIABLE_NUMBER`, 32766 by default in modern SQLite);
/// each row binds 9 parameters.
pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> {
    if entries.is_empty() {
        return Ok(());
    }
    // 3000 rows * 9 binds = 27000 parameters, comfortably below the cap.
    const MAX_ROWS_PER_INSERT: usize = 3000;
    for chunk in entries.chunks(MAX_ROWS_PER_INSERT) {
        let mut builder = QueryBuilder::<Sqlite>::new(
            "INSERT INTO logs (ts, ts_nanos, level, target, message, thread_id, module_path, file, line) ",
        );
        builder.push_values(chunk, |mut row, entry| {
            row.push_bind(entry.ts)
                .push_bind(entry.ts_nanos)
                .push_bind(&entry.level)
                .push_bind(&entry.target)
                .push_bind(&entry.message)
                .push_bind(&entry.thread_id)
                .push_bind(&entry.module_path)
                .push_bind(&entry.file)
                .push_bind(entry.line);
        });
        builder.build().execute(self.pool.as_ref()).await?;
    }
    Ok(())
}
/// Delete log rows with `ts` strictly before `cutoff_ts`.
///
/// Returns the number of rows removed.
pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result<u64> {
let result = sqlx::query("DELETE FROM logs WHERE ts < ?")
.bind(cutoff_ts)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected())
}
/// Query logs with optional filters.
///
/// Results are ordered by insertion id — ascending by default, descending
/// when `query.descending` is set — and capped by `query.limit` if present.
pub async fn query_logs(&self, query: &LogQuery) -> anyhow::Result<Vec<LogRow>> {
    let mut builder = QueryBuilder::<Sqlite>::new(
        "SELECT id, ts, ts_nanos, level, target, message, thread_id, file, line FROM logs WHERE 1 = 1",
    );
    push_log_filters(&mut builder, query);
    let order_clause = if query.descending {
        " ORDER BY id DESC"
    } else {
        " ORDER BY id ASC"
    };
    builder.push(order_clause);
    if let Some(limit) = query.limit {
        builder.push(" LIMIT ").push_bind(limit as i64);
    }
    let rows = builder
        .build_query_as::<LogRow>()
        .fetch_all(self.pool.as_ref())
        .await?;
    Ok(rows)
}
/// Return the max log id matching optional filters.
///
/// Returns 0 when no rows match (`MAX` over an empty set yields NULL).
pub async fn max_log_id(&self, query: &LogQuery) -> anyhow::Result<i64> {
let mut builder =
QueryBuilder::<Sqlite>::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1");
push_log_filters(&mut builder, query);
let row = builder.build().fetch_one(self.pool.as_ref()).await?;
let max_id: Option<i64> = row.try_get("max_id")?;
Ok(max_id.unwrap_or(0))
}
/// List thread ids using the underlying database (no rollout scanning).
///
/// Applies the same filtering/ordering as [`Self::list_threads`] but selects
/// only the `id` column; unlike `list_threads` there is no +1 sentinel, so
/// at most `limit` ids are returned.
pub async fn list_thread_ids(
&self,
limit: usize,
anchor: Option<&crate::Anchor>,
sort_key: crate::SortKey,
allowed_sources: &[String],
model_providers: Option<&[String]>,
archived_only: bool,
) -> anyhow::Result<Vec<ThreadId>> {
let mut builder = QueryBuilder::<Sqlite>::new("SELECT id FROM threads");
push_thread_filters(
&mut builder,
archived_only,
allowed_sources,
model_providers,
anchor,
sort_key,
);
push_thread_order_and_limit(&mut builder, sort_key, limit);
let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
rows.into_iter()
.map(|row| {
let id: String = row.try_get("id")?;
Ok(ThreadId::try_from(id)?)
})
.collect()
}
/// Insert or replace thread metadata directly.
///
/// Uses an 18-column upsert keyed on `id`; bind order below must match the
/// column list exactly. Note that `archived` is derived from the presence
/// of `archived_at`, and an absent `first_user_message` is stored as an
/// empty string rather than NULL.
pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO threads (
id,
rollout_path,
created_at,
updated_at,
source,
model_provider,
cwd,
cli_version,
title,
sandbox_policy,
approval_mode,
tokens_used,
first_user_message,
archived,
archived_at,
git_sha,
git_branch,
git_origin_url
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
rollout_path = excluded.rollout_path,
created_at = excluded.created_at,
updated_at = excluded.updated_at,
source = excluded.source,
model_provider = excluded.model_provider,
cwd = excluded.cwd,
cli_version = excluded.cli_version,
title = excluded.title,
sandbox_policy = excluded.sandbox_policy,
approval_mode = excluded.approval_mode,
tokens_used = excluded.tokens_used,
first_user_message = excluded.first_user_message,
archived = excluded.archived,
archived_at = excluded.archived_at,
git_sha = excluded.git_sha,
git_branch = excluded.git_branch,
git_origin_url = excluded.git_origin_url
"#,
)
.bind(metadata.id.to_string())
.bind(metadata.rollout_path.display().to_string())
.bind(datetime_to_epoch_seconds(metadata.created_at))
.bind(datetime_to_epoch_seconds(metadata.updated_at))
.bind(metadata.source.as_str())
.bind(metadata.model_provider.as_str())
.bind(metadata.cwd.display().to_string())
.bind(metadata.cli_version.as_str())
.bind(metadata.title.as_str())
.bind(metadata.sandbox_policy.as_str())
.bind(metadata.approval_mode.as_str())
.bind(metadata.tokens_used)
.bind(metadata.first_user_message.as_deref().unwrap_or_default())
.bind(metadata.archived_at.is_some())
.bind(metadata.archived_at.map(datetime_to_epoch_seconds))
.bind(metadata.git_sha.as_deref())
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.execute(self.pool.as_ref())
.await?;
Ok(())
}
/// Insert or update memory summaries for a thread in the cwd scope.
///
/// The scope key is the thread's working directory. This method always
/// advances `updated_at`, even if summaries are unchanged.
pub async fn upsert_thread_memory(
    &self,
    thread_id: ThreadId,
    raw_memory: &str,
    memory_summary: &str,
) -> anyhow::Result<ThreadMemory> {
    let thread = self
        .get_thread(thread_id)
        .await?
        .ok_or_else(|| anyhow::anyhow!("thread not found: {thread_id}"))?;
    let scope_key = thread.cwd.display().to_string();
    self.upsert_thread_memory_for_scope(
        thread_id,
        MEMORY_SCOPE_KIND_CWD,
        scope_key.as_str(),
        raw_memory,
        memory_summary,
    )
    .await
}
/// Insert or update memory summaries for a thread in an explicit scope.
///
/// The conflict clause guarantees `updated_at` strictly increases on every
/// upsert: if the wall clock did not advance past the stored value, the
/// stored value is bumped by one second instead. The freshly written row is
/// read back and returned.
pub async fn upsert_thread_memory_for_scope(
&self,
thread_id: ThreadId,
scope_kind: &str,
scope_key: &str,
raw_memory: &str,
memory_summary: &str,
) -> anyhow::Result<ThreadMemory> {
// Enforce the thread's existence up front for a clearer error message.
if self.get_thread(thread_id).await?.is_none() {
return Err(anyhow::anyhow!("thread not found: {thread_id}"));
}
let updated_at = Utc::now().timestamp();
sqlx::query(
r#"
INSERT INTO thread_memory (
thread_id,
scope_kind,
scope_key,
raw_memory,
memory_summary,
updated_at
) VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(thread_id, scope_kind, scope_key) DO UPDATE SET
raw_memory = excluded.raw_memory,
memory_summary = excluded.memory_summary,
updated_at = CASE
WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1
ELSE excluded.updated_at
END
"#,
)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.bind(raw_memory)
.bind(memory_summary)
.bind(updated_at)
.execute(self.pool.as_ref())
.await?;
// Read the row back so the caller sees the final persisted state
// (including the possibly-bumped updated_at).
let row = sqlx::query(
r#"
SELECT
thread_id,
scope_kind,
scope_key,
raw_memory,
memory_summary,
updated_at,
last_used_at,
used_count,
invalidated_at,
invalid_reason
FROM thread_memory
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
"#,
)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from))
.transpose()?
.ok_or_else(|| anyhow::anyhow!("failed to load upserted thread memory: {thread_id}"))
}
/// Insert or update memory summaries for a thread/scope only if the caller
/// still owns the corresponding phase-1 running job.
///
/// Ownership is checked inside the statement itself: the INSERT's source
/// `SELECT ... WHERE EXISTS` produces a row only when a `running` job with
/// the caller's `ownership_token` exists, so a caller that lost ownership
/// writes nothing (0 rows affected) and gets `Ok(None)` back. On success
/// the persisted row is read back and returned.
pub async fn upsert_thread_memory_for_scope_if_phase1_owner(
&self,
thread_id: ThreadId,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
raw_memory: &str,
memory_summary: &str,
) -> anyhow::Result<Option<ThreadMemory>> {
if self.get_thread(thread_id).await?.is_none() {
return Err(anyhow::anyhow!("thread not found: {thread_id}"));
}
let updated_at = Utc::now().timestamp();
let rows_affected = sqlx::query(
r#"
INSERT INTO thread_memory (
thread_id,
scope_kind,
scope_key,
raw_memory,
memory_summary,
updated_at
)
SELECT ?, ?, ?, ?, ?, ?
WHERE EXISTS (
SELECT 1
FROM memory_phase1_jobs
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
)
ON CONFLICT(thread_id, scope_kind, scope_key) DO UPDATE SET
raw_memory = excluded.raw_memory,
memory_summary = excluded.memory_summary,
updated_at = CASE
WHEN excluded.updated_at <= thread_memory.updated_at THEN thread_memory.updated_at + 1
ELSE excluded.updated_at
END
"#,
)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.bind(raw_memory)
.bind(memory_summary)
.bind(updated_at)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(self.pool.as_ref())
.await?
.rows_affected();
// 0 rows means the ownership guard failed: not an error, just a lost race.
if rows_affected == 0 {
return Ok(None);
}
let row = sqlx::query(
r#"
SELECT
thread_id,
scope_kind,
scope_key,
raw_memory,
memory_summary,
updated_at,
last_used_at,
used_count,
invalidated_at,
invalid_reason
FROM thread_memory
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
"#,
)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.fetch_optional(self.pool.as_ref())
.await?;
row.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from))
.transpose()
}
/// Get the last `n` memories for threads with an exact cwd match.
pub async fn get_last_n_thread_memories_for_cwd(
    &self,
    cwd: &Path,
    n: usize,
) -> anyhow::Result<Vec<ThreadMemory>> {
    // The cwd scope keys memories by the rendered working-directory path.
    let scope_key = cwd.display().to_string();
    self.get_last_n_thread_memories_for_scope(MEMORY_SCOPE_KIND_CWD, scope_key.as_str(), n)
        .await
}
/// Get the last `n` memories for a specific memory scope.
///
/// Invalidated rows are excluded; results are newest-first by `updated_at`
/// (ties broken by `thread_id` DESC). `n == 0` short-circuits without
/// touching the database.
pub async fn get_last_n_thread_memories_for_scope(
&self,
scope_kind: &str,
scope_key: &str,
n: usize,
) -> anyhow::Result<Vec<ThreadMemory>> {
if n == 0 {
return Ok(Vec::new());
}
let rows = sqlx::query(
r#"
SELECT
thread_id,
scope_kind,
scope_key,
raw_memory,
memory_summary,
updated_at,
last_used_at,
used_count,
invalidated_at,
invalid_reason
FROM thread_memory
WHERE scope_kind = ? AND scope_key = ? AND invalidated_at IS NULL
ORDER BY updated_at DESC, thread_id DESC
LIMIT ?
"#,
)
.bind(scope_kind)
.bind(scope_key)
.bind(n as i64)
.fetch_all(self.pool.as_ref())
.await?;
rows.into_iter()
.map(|row| ThreadMemoryRow::try_from_row(&row).and_then(ThreadMemory::try_from))
.collect()
}
/// Try to claim a phase-1 memory extraction job for `(thread, scope)`.
///
/// The claim is performed inside a transaction as read-then-decide, followed
/// by an optimistic compare-and-swap UPDATE keyed on the previously observed
/// (status, source_updated_at, started_at) triple, so two concurrent
/// claimers cannot both win. A `running` job whose `started_at` is older
/// than `lease_seconds` is treated as stale and may be taken over. On
/// success the caller receives a fresh `ownership_token` to use with
/// finalize/renew calls.
pub async fn try_claim_phase1_job(
&self,
thread_id: ThreadId,
scope_kind: &str,
scope_key: &str,
owner_session_id: ThreadId,
source_updated_at: i64,
lease_seconds: i64,
) -> anyhow::Result<Phase1JobClaimOutcome> {
let now = Utc::now().timestamp();
// Jobs started at or before this instant have an expired lease.
let stale_cutoff = now.saturating_sub(lease_seconds.max(0));
let ownership_token = Uuid::new_v4().to_string();
let thread_id = thread_id.to_string();
let owner_session_id = owner_session_id.to_string();
let mut tx = self.pool.begin().await?;
let existing = sqlx::query(
r#"
SELECT status, source_updated_at, started_at
FROM memory_phase1_jobs
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
"#,
)
.bind(thread_id.as_str())
.bind(scope_kind)
.bind(scope_key)
.fetch_optional(&mut *tx)
.await?;
// No prior job row: insert a fresh `running` claim and return it.
let Some(existing) = existing else {
sqlx::query(
r#"
INSERT INTO memory_phase1_jobs (
thread_id,
scope_kind,
scope_key,
status,
owner_session_id,
started_at,
finished_at,
failure_reason,
source_updated_at,
raw_memory_path,
summary_hash,
ownership_token
) VALUES (?, ?, ?, 'running', ?, ?, NULL, NULL, ?, NULL, NULL, ?)
"#,
)
.bind(thread_id.as_str())
.bind(scope_kind)
.bind(scope_key)
.bind(owner_session_id.as_str())
.bind(now)
.bind(source_updated_at)
.bind(ownership_token.as_str())
.execute(&mut *tx)
.await?;
tx.commit().await?;
return Ok(Phase1JobClaimOutcome::Claimed { ownership_token });
};
let status: String = existing.try_get("status")?;
let existing_source_updated_at: i64 = existing.try_get("source_updated_at")?;
let existing_started_at: Option<i64> = existing.try_get("started_at")?;
// Terminal failure: never retried.
if status == "failed" {
tx.commit().await?;
return Ok(Phase1JobClaimOutcome::SkippedTerminalFailure);
}
// Already succeeded against the same or newer source data: nothing to do.
if status == "succeeded" && existing_source_updated_at >= source_updated_at {
tx.commit().await?;
return Ok(Phase1JobClaimOutcome::SkippedUpToDate);
}
// Another owner's lease is still fresh: back off.
if status == "running" && existing_started_at.is_some_and(|started| started > stale_cutoff)
{
tx.commit().await?;
return Ok(Phase1JobClaimOutcome::SkippedRunning);
}
// Take over the row, but only if it still looks exactly as observed above
// (CAS on status/source_updated_at/started_at). Two branches because SQL
// `=` cannot match a NULL started_at.
let rows_affected = if let Some(existing_started_at) = existing_started_at {
sqlx::query(
r#"
UPDATE memory_phase1_jobs
SET
status = 'running',
owner_session_id = ?,
started_at = ?,
finished_at = NULL,
failure_reason = NULL,
source_updated_at = ?,
raw_memory_path = NULL,
summary_hash = NULL,
ownership_token = ?
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
AND status = ? AND source_updated_at = ? AND started_at = ?
"#,
)
.bind(owner_session_id.as_str())
.bind(now)
.bind(source_updated_at)
.bind(ownership_token.as_str())
.bind(thread_id.as_str())
.bind(scope_kind)
.bind(scope_key)
.bind(status.as_str())
.bind(existing_source_updated_at)
.bind(existing_started_at)
.execute(&mut *tx)
.await?
.rows_affected()
} else {
sqlx::query(
r#"
UPDATE memory_phase1_jobs
SET
status = 'running',
owner_session_id = ?,
started_at = ?,
finished_at = NULL,
failure_reason = NULL,
source_updated_at = ?,
raw_memory_path = NULL,
summary_hash = NULL,
ownership_token = ?
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
AND status = ? AND source_updated_at = ? AND started_at IS NULL
"#,
)
.bind(owner_session_id.as_str())
.bind(now)
.bind(source_updated_at)
.bind(ownership_token.as_str())
.bind(thread_id.as_str())
.bind(scope_kind)
.bind(scope_key)
.bind(status.as_str())
.bind(existing_source_updated_at)
.execute(&mut *tx)
.await?
.rows_affected()
};
tx.commit().await?;
// CAS miss means someone else changed the row concurrently: report busy.
if rows_affected == 0 {
Ok(Phase1JobClaimOutcome::SkippedRunning)
} else {
Ok(Phase1JobClaimOutcome::Claimed { ownership_token })
}
}
/// Finalize a claimed phase-1 job as succeeded.
///
/// The update is guarded on `status = 'running'` plus the caller's
/// `ownership_token`; returns `true` only when the caller still owned the
/// job (i.e. the guard matched and a row was updated).
pub async fn mark_phase1_job_succeeded(
&self,
thread_id: ThreadId,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
raw_memory_path: &str,
summary_hash: &str,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let rows_affected = sqlx::query(
r#"
UPDATE memory_phase1_jobs
SET
status = 'succeeded',
finished_at = ?,
failure_reason = NULL,
raw_memory_path = ?,
summary_hash = ?
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
"#,
)
.bind(now)
.bind(raw_memory_path)
.bind(summary_hash)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(self.pool.as_ref())
.await?
.rows_affected();
Ok(rows_affected > 0)
}
/// Finalize a claimed phase-1 job as failed.
///
/// Same ownership guard as [`Self::mark_phase1_job_succeeded`]; returns
/// `true` only when the caller still owned the running job. A `failed`
/// status is terminal for future claim attempts.
pub async fn mark_phase1_job_failed(
&self,
thread_id: ThreadId,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
failure_reason: &str,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let rows_affected = sqlx::query(
r#"
UPDATE memory_phase1_jobs
SET
status = 'failed',
finished_at = ?,
failure_reason = ?
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
"#,
)
.bind(now)
.bind(failure_reason)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(self.pool.as_ref())
.await?
.rows_affected();
Ok(rows_affected > 0)
}
/// Refresh lease timestamp for a claimed phase-1 job.
///
/// Bumps `started_at` to "now", which is the value stale-job detection in
/// `try_claim_phase1_job` compares against.
///
/// Returns `true` only when the current owner token still matches.
pub async fn renew_phase1_job_lease(
&self,
thread_id: ThreadId,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let rows_affected = sqlx::query(
r#"
UPDATE memory_phase1_jobs
SET started_at = ?
WHERE thread_id = ? AND scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
"#,
)
.bind(now)
.bind(thread_id.to_string())
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(self.pool.as_ref())
.await?
.rows_affected();
Ok(rows_affected > 0)
}
/// Mark a memory scope as dirty/clean for phase-2 consolidation scheduling.
///
/// Unconditional upsert: the flag and `updated_at` are overwritten whether
/// or not the row existed.
pub async fn mark_memory_scope_dirty(
&self,
scope_kind: &str,
scope_key: &str,
dirty: bool,
) -> anyhow::Result<()> {
let now = Utc::now().timestamp();
sqlx::query(
r#"
INSERT INTO memory_scope_dirty (scope_kind, scope_key, dirty, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(scope_kind, scope_key) DO UPDATE SET
dirty = excluded.dirty,
updated_at = excluded.updated_at
"#,
)
.bind(scope_kind)
.bind(scope_key)
.bind(dirty)
.bind(now)
.execute(self.pool.as_ref())
.await?;
Ok(())
}
/// Try to acquire or renew the per-cwd memory consolidation lock.
///
/// Implemented as a conditional upsert: a fresh row always acquires; an
/// existing row is only overwritten (DO UPDATE ... WHERE) when it is held
/// by the same `working_thread_id` (renewal) or its `updated_at` lease has
/// expired. A guarded conflict that matches neither affects 0 rows.
///
/// Returns `true` when the lock is acquired/renewed for `working_thread_id`.
/// Returns `false` when another owner holds a non-expired lease.
pub async fn try_acquire_memory_consolidation_lock(
&self,
cwd: &Path,
working_thread_id: ThreadId,
lease_seconds: i64,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let stale_cutoff = now.saturating_sub(lease_seconds.max(0));
let result = sqlx::query(
r#"
INSERT INTO memory_consolidation_locks (
cwd,
working_thread_id,
updated_at
) VALUES (?, ?, ?)
ON CONFLICT(cwd) DO UPDATE SET
working_thread_id = excluded.working_thread_id,
updated_at = excluded.updated_at
WHERE memory_consolidation_locks.working_thread_id = excluded.working_thread_id
OR memory_consolidation_locks.updated_at <= ?
"#,
)
.bind(cwd.display().to_string())
.bind(working_thread_id.to_string())
.bind(now)
.bind(stale_cutoff)
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
/// Release the per-cwd memory consolidation lock if held by `working_thread_id`.
///
/// Deleting only a matching (cwd, owner) row makes the release a no-op when
/// the lock was already taken over by someone else.
///
/// Returns `true` when a lock row was removed.
pub async fn release_memory_consolidation_lock(
&self,
cwd: &Path,
working_thread_id: ThreadId,
) -> anyhow::Result<bool> {
let result = sqlx::query(
r#"
DELETE FROM memory_consolidation_locks
WHERE cwd = ? AND working_thread_id = ?
"#,
)
.bind(cwd.display().to_string())
.bind(working_thread_id.to_string())
.execute(self.pool.as_ref())
.await?;
Ok(result.rows_affected() > 0)
}
/// Persist dynamic tools for a thread if none have been stored yet.
///
/// Dynamic tools are defined at thread start and should not change afterward.
/// This only writes the first time we see tools for a given thread:
/// `ON CONFLICT ... DO NOTHING` keeps whatever was stored first.
pub async fn persist_dynamic_tools(
    &self,
    thread_id: ThreadId,
    tools: Option<&[DynamicToolSpec]>,
) -> anyhow::Result<()> {
    // Nothing to persist when no tool list was supplied or it is empty.
    let tools = match tools {
        Some(tools) if !tools.is_empty() => tools,
        _ => return Ok(()),
    };
    let thread_id = thread_id.to_string();
    // One transaction so a partially written tool list is never visible.
    let mut tx = self.pool.begin().await?;
    for (idx, tool) in tools.iter().enumerate() {
        let position = i64::try_from(idx).unwrap_or(i64::MAX);
        let input_schema = serde_json::to_string(&tool.input_schema)?;
        sqlx::query(
            r#"
INSERT INTO thread_dynamic_tools (
thread_id,
position,
name,
description,
input_schema
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(thread_id, position) DO NOTHING
"#,
        )
        .bind(thread_id.as_str())
        .bind(position)
        .bind(tool.name.as_str())
        .bind(tool.description.as_str())
        .bind(input_schema)
        .execute(&mut *tx)
        .await?;
    }
    tx.commit().await?;
    Ok(())
}
/// Apply rollout items incrementally using the underlying database.
///
/// Loads (or builds) the thread's metadata, folds every item into it,
/// refreshes `updated_at` from the rollout file's mtime when available, then
/// upserts the thread row and finally any dynamic tools found in a
/// `SessionMeta` item. DB errors are counted via OTel and propagated.
pub async fn apply_rollout_items(
&self,
builder: &ThreadMetadataBuilder,
items: &[RolloutItem],
otel: Option<&OtelManager>,
) -> anyhow::Result<()> {
if items.is_empty() {
return Ok(());
}
let mut metadata = self
.get_thread(builder.id)
.await?
.unwrap_or_else(|| builder.build(&self.default_provider));
metadata.rollout_path = builder.rollout_path.clone();
for item in items {
apply_rollout_item(&mut metadata, item, &self.default_provider);
}
if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await {
metadata.updated_at = updated_at;
}
// Keep the thread upsert before dynamic tools to satisfy the foreign key constraint:
// thread_dynamic_tools.thread_id -> threads.id.
if let Err(err) = self.upsert_thread(&metadata).await {
if let Some(otel) = otel {
otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
}
return Err(err);
}
let dynamic_tools = extract_dynamic_tools(items);
if let Some(dynamic_tools) = dynamic_tools
&& let Err(err) = self
.persist_dynamic_tools(builder.id, dynamic_tools.as_deref())
.await
{
if let Some(otel) = otel {
otel.counter(DB_ERROR_METRIC, 1, &[("stage", "persist_dynamic_tools")]);
}
return Err(err);
}
Ok(())
}
/// Mark a thread as archived using the underlying database.
///
/// No-op when the thread does not exist. Sets `archived_at`, records the
/// (possibly moved) rollout path, and refreshes `updated_at` from the
/// rollout file's mtime when available.
pub async fn mark_archived(
&self,
thread_id: ThreadId,
rollout_path: &Path,
archived_at: DateTime<Utc>,
) -> anyhow::Result<()> {
let Some(mut metadata) = self.get_thread(thread_id).await? else {
return Ok(());
};
metadata.archived_at = Some(archived_at);
metadata.rollout_path = rollout_path.to_path_buf();
if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
metadata.updated_at = updated_at;
}
// Defensive check only: the row was fetched by this exact id.
if metadata.id != thread_id {
warn!(
"thread id mismatch during archive: expected {thread_id}, got {}",
metadata.id
);
}
self.upsert_thread(&metadata).await
}
/// Mark a thread as unarchived using the underlying database.
///
/// Mirror of [`Self::mark_archived`]: clears `archived_at`, records the
/// rollout path, and refreshes `updated_at` from the file mtime when
/// available. No-op when the thread does not exist.
pub async fn mark_unarchived(
&self,
thread_id: ThreadId,
rollout_path: &Path,
) -> anyhow::Result<()> {
let Some(mut metadata) = self.get_thread(thread_id).await? else {
return Ok(());
};
metadata.archived_at = None;
metadata.rollout_path = rollout_path.to_path_buf();
if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
metadata.updated_at = updated_at;
}
// Defensive check only: the row was fetched by this exact id.
if metadata.id != thread_id {
warn!(
"thread id mismatch during unarchive: expected {thread_id}, got {}",
metadata.id
);
}
self.upsert_thread(&metadata).await
}
/// Ensure the singleton `backfill_state` row (id = 1) exists.
///
/// Inserts a `Pending` row on first use; `ON CONFLICT DO NOTHING` makes the
/// call idempotent and safe to run before every backfill operation.
async fn ensure_backfill_state_row(&self) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at)
VALUES (?, ?, NULL, NULL, ?)
ON CONFLICT(id) DO NOTHING
"#,
)
.bind(1_i64)
.bind(crate::BackfillStatus::Pending.as_str())
.bind(Utc::now().timestamp())
.execute(self.pool.as_ref())
.await?;
Ok(())
}
}
/// Append the WHERE-clause fragments for a [`LogQuery`] to `builder`.
///
/// Assumes the statement already ends in a valid predicate (callers start
/// with `WHERE 1 = 1`), so every fragment begins with ` AND `.
fn push_log_filters<'a>(builder: &mut QueryBuilder<'a, Sqlite>, query: &'a LogQuery) {
    // Level comparison is done against the pre-uppercased value.
    if let Some(level_upper) = query.level_upper.as_ref() {
        builder
            .push(" AND UPPER(level) = ")
            .push_bind(level_upper.as_str());
    }
    // Inclusive timestamp window.
    if let Some(from_ts) = query.from_ts {
        builder.push(" AND ts >= ").push_bind(from_ts);
    }
    if let Some(to_ts) = query.to_ts {
        builder.push(" AND ts <= ").push_bind(to_ts);
    }
    push_like_filters(builder, "module_path", &query.module_like);
    push_like_filters(builder, "file", &query.file_like);
    // Thread filter: OR together the requested ids, optionally also
    // admitting rows with no thread id at all.
    if !query.thread_ids.is_empty() || query.include_threadless {
        builder.push(" AND (");
        let mut first = true;
        for thread_id in &query.thread_ids {
            if !first {
                builder.push(" OR ");
            }
            builder.push("thread_id = ").push_bind(thread_id.as_str());
            first = false;
        }
        if query.include_threadless {
            if !first {
                builder.push(" OR ");
            }
            builder.push("thread_id IS NULL");
        }
        builder.push(")");
    }
    // Pagination cursor: rows strictly after the given id.
    if let Some(after_id) = query.after_id {
        builder.push(" AND id > ").push_bind(after_id);
    }
}
/// Append an OR-group of substring (`LIKE '%…%'`) filters on `column`.
///
/// Emits nothing when `filters` is empty, so the surrounding statement
/// never gains an invalid empty `AND ()` group.
fn push_like_filters<'a>(
    builder: &mut QueryBuilder<'a, Sqlite>,
    column: &str,
    filters: &'a [String],
) {
    let mut remaining = filters.iter();
    let Some(first) = remaining.next() else {
        return;
    };
    builder.push(" AND (");
    builder
        .push(column)
        .push(" LIKE '%' || ")
        .push_bind(first.as_str())
        .push(" || '%'");
    for filter in remaining {
        builder
            .push(" OR ")
            .push(column)
            .push(" LIKE '%' || ")
            .push_bind(filter.as_str())
            .push(" || '%'");
    }
    builder.push(")");
}
/// Pull the dynamic-tool list out of the first `SessionMeta` item, if any.
///
/// Outer `None` means no session meta was present among `items`; the inner
/// value is the meta's own (possibly absent) `dynamic_tools` field.
fn extract_dynamic_tools(items: &[RolloutItem]) -> Option<Option<Vec<DynamicToolSpec>>> {
    for item in items {
        // Exhaustive match so a new rollout variant forces a review here.
        match item {
            RolloutItem::SessionMeta(meta_line) => {
                return Some(meta_line.meta.dynamic_tools.clone());
            }
            RolloutItem::ResponseItem(_)
            | RolloutItem::Compacted(_)
            | RolloutItem::TurnContext(_)
            | RolloutItem::EventMsg(_) => {}
        }
    }
    None
}
/// Open (creating if needed) the SQLite database at `path` and run migrations.
async fn open_sqlite(path: &Path) -> anyhow::Result<SqlitePool> {
    // WAL journal + NORMAL sync; the busy timeout smooths over writers from
    // other processes, and statement logging is disabled.
    let connect_options = SqliteConnectOptions::new()
        .filename(path)
        .create_if_missing(true)
        .journal_mode(SqliteJournalMode::Wal)
        .synchronous(SqliteSynchronous::Normal)
        .busy_timeout(Duration::from_secs(5))
        .log_statements(LevelFilter::Off);
    let pool_options = SqlitePoolOptions::new().max_connections(5);
    let pool = pool_options.connect_with(connect_options).await?;
    // Bring the schema up to date before handing the pool out.
    MIGRATOR.run(&pool).await?;
    Ok(pool)
}
/// File name of the current versioned state database, e.g. `state_4.sqlite`.
pub fn state_db_filename() -> String {
    let mut name = String::from(STATE_DB_FILENAME);
    name.push('_');
    name.push_str(&STATE_DB_VERSION.to_string());
    name.push_str(".sqlite");
    name
}
/// Absolute path of the versioned state database inside `codex_home`.
pub fn state_db_path(codex_home: &Path) -> PathBuf {
    let mut path = codex_home.to_path_buf();
    path.push(state_db_filename());
    path
}
/// Deletes state database files left behind by older schema versions
/// (including their WAL/SHM/journal side files) from `codex_home`.
///
/// Cleanup is best-effort: every failure is logged and otherwise ignored.
async fn remove_legacy_state_files(codex_home: &Path) {
    let current_name = state_db_filename();
    let mut dir = match tokio::fs::read_dir(codex_home).await {
        Ok(dir) => dir,
        Err(err) => {
            warn!(
                "failed to read codex_home for state db cleanup {}: {err}",
                codex_home.display()
            );
            return;
        }
    };
    while let Ok(Some(entry)) = dir.next_entry().await {
        // Only plain files are candidates; directories, symlink targets we
        // cannot stat, and metadata errors are all skipped.
        let is_file = entry
            .file_type()
            .await
            .map(|file_type| file_type.is_file())
            .unwrap_or(false);
        if !is_file {
            continue;
        }
        let name = entry.file_name();
        if !should_remove_state_file(name.to_string_lossy().as_ref(), current_name.as_str()) {
            continue;
        }
        let stale_path = entry.path();
        if let Err(err) = tokio::fs::remove_file(&stale_path).await {
            warn!(
                "failed to remove legacy state db file {}: {err}",
                stale_path.display()
            );
        }
    }
}
/// Decides whether `file_name` belongs to an obsolete state database.
///
/// WAL/SHM/journal companions are judged by their base file name. A file is
/// removable when its base is the legacy unversioned `state.sqlite` or any
/// versioned `state_<digits>.sqlite` other than `current_name`.
fn should_remove_state_file(file_name: &str, current_name: &str) -> bool {
    // Reduce side files to the database name they accompany.
    let base_name = ["-wal", "-shm", "-journal"]
        .iter()
        .find_map(|suffix| file_name.strip_suffix(suffix))
        .unwrap_or(file_name);
    if base_name == current_name {
        return false;
    }
    if base_name == format!("{STATE_DB_FILENAME}.sqlite") {
        return true;
    }
    // Anything shaped `state_<digits>.sqlite` is an old versioned database.
    match base_name
        .strip_prefix(&format!("{STATE_DB_FILENAME}_"))
        .and_then(|rest| rest.strip_suffix(".sqlite"))
    {
        Some(version) => !version.is_empty() && version.chars().all(|ch| ch.is_ascii_digit()),
        None => false,
    }
}
/// Appends the shared `WHERE` clause used by thread listing queries.
///
/// Filters on archive state, a non-empty first user message, optional
/// source and model-provider allow-lists, and an optional keyset-pagination
/// anchor (rows strictly "before" the anchor under `sort_key` ordering,
/// with `id` as the tie-breaker).
fn push_thread_filters<'a>(
    builder: &mut QueryBuilder<'a, Sqlite>,
    archived_only: bool,
    allowed_sources: &'a [String],
    model_providers: Option<&'a [String]>,
    anchor: Option<&crate::Anchor>,
    sort_key: SortKey,
) {
    // `1 = 1` lets every later condition be appended uniformly with `AND`.
    builder.push(" WHERE 1 = 1");
    if archived_only {
        builder.push(" AND archived = 1");
    } else {
        builder.push(" AND archived = 0");
    }
    builder.push(" AND first_user_message <> ''");
    // Empty allow-list means "no source filter".
    if !allowed_sources.is_empty() {
        builder.push(" AND source IN (");
        let mut separated = builder.separated(", ");
        for source in allowed_sources {
            separated.push_bind(source);
        }
        separated.push_unseparated(")");
    }
    // `None` and `Some(&[])` are both treated as "no provider filter".
    if let Some(model_providers) = model_providers
        && !model_providers.is_empty()
    {
        builder.push(" AND model_provider IN (");
        let mut separated = builder.separated(", ");
        for provider in model_providers {
            separated.push_bind(provider);
        }
        separated.push_unseparated(")");
    }
    if let Some(anchor) = anchor {
        let anchor_ts = datetime_to_epoch_seconds(anchor.ts);
        let column = match sort_key {
            SortKey::CreatedAt => "created_at",
            SortKey::UpdatedAt => "updated_at",
        };
        // Keyset pagination: accept rows whose (timestamp, id) pair is
        // strictly less than the anchor's pair.
        builder.push(" AND (");
        builder.push(column);
        builder.push(" < ");
        builder.push_bind(anchor_ts);
        builder.push(" OR (");
        builder.push(column);
        builder.push(" = ");
        builder.push_bind(anchor_ts);
        builder.push(" AND id < ");
        builder.push_bind(anchor.id.to_string());
        builder.push("))");
    }
}
/// Appends the `ORDER BY <col> DESC, id DESC LIMIT ?` tail used by thread
/// listing queries; the trailing `id DESC` breaks timestamp ties so paging
/// stays deterministic.
fn push_thread_order_and_limit(
    builder: &mut QueryBuilder<'_, Sqlite>,
    sort_key: SortKey,
    limit: usize,
) {
    let column = match sort_key {
        SortKey::CreatedAt => "created_at",
        SortKey::UpdatedAt => "updated_at",
    };
    builder
        .push(" ORDER BY ")
        .push(column)
        .push(" DESC, id DESC")
        .push(" LIMIT ")
        .push_bind(limit as i64);
}
#[cfg(test)]
mod tests {
    //! Integration-style tests that exercise `StateRuntime` against a real
    //! on-disk SQLite database rooted in a unique temp directory. Each test
    //! creates its own directory and best-effort removes it at the end.
    use super::Phase1JobClaimOutcome;
    use super::STATE_DB_FILENAME;
    use super::STATE_DB_VERSION;
    use super::StateRuntime;
    use super::ThreadMetadata;
    use super::state_db_filename;
    use chrono::DateTime;
    use chrono::Utc;
    use codex_protocol::ThreadId;
    use codex_protocol::protocol::AskForApproval;
    use codex_protocol::protocol::SandboxPolicy;
    use pretty_assertions::assert_eq;
    use sqlx::Row;
    use std::path::Path;
    use std::path::PathBuf;
    use std::time::SystemTime;
    use std::time::UNIX_EPOCH;
    use uuid::Uuid;

    /// Builds a unique temp-directory path (nanos + UUID) so concurrently
    /// running tests never collide. The directory itself is not created here.
    fn unique_temp_dir() -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |duration| duration.as_nanos());
        std::env::temp_dir().join(format!(
            "codex-state-runtime-test-{nanos}-{}",
            Uuid::new_v4()
        ))
    }

    /// Init deletes unversioned and older-versioned db files (plus WAL/SHM/
    /// journal side files) while leaving unrelated files untouched.
    #[tokio::test]
    async fn init_removes_legacy_state_db_files() {
        let codex_home = unique_temp_dir();
        tokio::fs::create_dir_all(&codex_home)
            .await
            .expect("create codex_home");
        let current_name = state_db_filename();
        let previous_version = STATE_DB_VERSION.saturating_sub(1);
        let unversioned_name = format!("{STATE_DB_FILENAME}.sqlite");
        // Seed legacy unversioned and previous-version files with every
        // SQLite side-file suffix.
        for suffix in ["", "-wal", "-shm", "-journal"] {
            let path = codex_home.join(format!("{unversioned_name}{suffix}"));
            tokio::fs::write(path, b"legacy")
                .await
                .expect("write legacy");
            let old_version_path = codex_home.join(format!(
                "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}"
            ));
            tokio::fs::write(old_version_path, b"old_version")
                .await
                .expect("write old version");
        }
        // These two must survive cleanup: a non-matching name and a bare
        // numeric file name.
        let unrelated_path = codex_home.join("state.sqlite_backup");
        tokio::fs::write(&unrelated_path, b"keep")
            .await
            .expect("write unrelated");
        let numeric_path = codex_home.join("123");
        tokio::fs::write(&numeric_path, b"keep")
            .await
            .expect("write numeric");
        let _runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        for suffix in ["", "-wal", "-shm", "-journal"] {
            let legacy_path = codex_home.join(format!("{unversioned_name}{suffix}"));
            assert_eq!(
                tokio::fs::try_exists(&legacy_path)
                    .await
                    .expect("check legacy path"),
                false
            );
            let old_version_path = codex_home.join(format!(
                "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}"
            ));
            assert_eq!(
                tokio::fs::try_exists(&old_version_path)
                    .await
                    .expect("check old version path"),
                false
            );
        }
        assert_eq!(
            tokio::fs::try_exists(codex_home.join(current_name))
                .await
                .expect("check new db path"),
            true
        );
        assert_eq!(
            tokio::fs::try_exists(&unrelated_path)
                .await
                .expect("check unrelated path"),
            true
        );
        assert_eq!(
            tokio::fs::try_exists(&numeric_path)
                .await
                .expect("check numeric path"),
            true
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Backfill state transitions Pending -> Running (with watermark
    /// checkpoints) -> Complete, and the success timestamp is only set on
    /// completion.
    #[tokio::test]
    async fn backfill_state_persists_progress_and_completion() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let initial = runtime
            .get_backfill_state()
            .await
            .expect("get initial backfill state");
        assert_eq!(initial.status, crate::BackfillStatus::Pending);
        assert_eq!(initial.last_watermark, None);
        assert_eq!(initial.last_success_at, None);
        runtime
            .mark_backfill_running()
            .await
            .expect("mark backfill running");
        runtime
            .checkpoint_backfill("sessions/2026/01/27/rollout-a.jsonl")
            .await
            .expect("checkpoint backfill");
        let running = runtime
            .get_backfill_state()
            .await
            .expect("get running backfill state");
        assert_eq!(running.status, crate::BackfillStatus::Running);
        assert_eq!(
            running.last_watermark,
            Some("sessions/2026/01/27/rollout-a.jsonl".to_string())
        );
        assert_eq!(running.last_success_at, None);
        runtime
            .mark_backfill_complete(Some("sessions/2026/01/28/rollout-b.jsonl"))
            .await
            .expect("mark backfill complete");
        let completed = runtime
            .get_backfill_state()
            .await
            .expect("get completed backfill state");
        assert_eq!(completed.status, crate::BackfillStatus::Complete);
        assert_eq!(
            completed.last_watermark,
            Some("sessions/2026/01/28/rollout-b.jsonl".to_string())
        );
        assert!(completed.last_success_at.is_some());
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Memory upsert inserts on first call, overwrites on the second, and
    /// never moves `updated_at` backward.
    #[tokio::test]
    async fn upsert_and_get_thread_memory() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let metadata = test_thread_metadata(&codex_home, thread_id, codex_home.join("a"));
        runtime
            .upsert_thread(&metadata)
            .await
            .expect("upsert thread");
        assert_eq!(
            runtime
                .get_thread_memory(thread_id)
                .await
                .expect("get memory before insert"),
            None
        );
        let inserted = runtime
            .upsert_thread_memory(thread_id, "trace one", "memory one")
            .await
            .expect("upsert memory");
        assert_eq!(inserted.thread_id, thread_id);
        assert_eq!(inserted.raw_memory, "trace one");
        assert_eq!(inserted.memory_summary, "memory one");
        let updated = runtime
            .upsert_thread_memory(thread_id, "trace two", "memory two")
            .await
            .expect("update memory");
        assert_eq!(updated.thread_id, thread_id);
        assert_eq!(updated.raw_memory, "trace two");
        assert_eq!(updated.memory_summary, "memory two");
        assert!(
            updated.updated_at >= inserted.updated_at,
            "updated_at should not move backward"
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Listing memories is scoped to the exact cwd, honors the limit, and
    /// returns most-recently-updated first.
    #[tokio::test]
    async fn get_last_n_thread_memories_for_cwd_matches_exactly() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let cwd_a = codex_home.join("workspace-a");
        let cwd_b = codex_home.join("workspace-b");
        let t1 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let t2 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let t3 = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, t1, cwd_a.clone()))
            .await
            .expect("upsert thread t1");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, t2, cwd_a.clone()))
            .await
            .expect("upsert thread t2");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, t3, cwd_b.clone()))
            .await
            .expect("upsert thread t3");
        let first = runtime
            .upsert_thread_memory(t1, "trace-1", "memory-1")
            .await
            .expect("upsert t1 memory");
        runtime
            .upsert_thread_memory(t2, "trace-2", "memory-2")
            .await
            .expect("upsert t2 memory");
        runtime
            .upsert_thread_memory(t3, "trace-3", "memory-3")
            .await
            .expect("upsert t3 memory");
        // Ensure deterministic ordering even when updates happen in the same second.
        runtime
            .upsert_thread_memory(t1, "trace-1b", "memory-1b")
            .await
            .expect("upsert t1 memory again");
        let cwd_a_memories = runtime
            .get_last_n_thread_memories_for_cwd(cwd_a.as_path(), 2)
            .await
            .expect("list cwd a memories");
        assert_eq!(cwd_a_memories.len(), 2);
        assert_eq!(cwd_a_memories[0].thread_id, t1);
        assert_eq!(cwd_a_memories[0].raw_memory, "trace-1b");
        assert_eq!(cwd_a_memories[0].memory_summary, "memory-1b");
        assert_eq!(cwd_a_memories[1].thread_id, t2);
        assert!(cwd_a_memories[0].updated_at >= first.updated_at);
        let cwd_b_memories = runtime
            .get_last_n_thread_memories_for_cwd(cwd_b.as_path(), 10)
            .await
            .expect("list cwd b memories");
        assert_eq!(cwd_b_memories.len(), 1);
        assert_eq!(cwd_b_memories[0].thread_id, t3);
        let none = runtime
            .get_last_n_thread_memories_for_cwd(codex_home.join("missing").as_path(), 10)
            .await
            .expect("list missing cwd memories");
        assert_eq!(none, Vec::new());
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Upserting memory for a thread id with no `threads` row is an error.
    #[tokio::test]
    async fn upsert_thread_memory_errors_for_unknown_thread() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let unknown_thread_id =
            ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let err = runtime
            .upsert_thread_memory(unknown_thread_id, "trace", "memory")
            .await
            .expect_err("unknown thread should fail");
        assert!(
            err.to_string().contains("thread not found"),
            "error should mention missing thread: {err}"
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// A limit of zero yields an empty result even when memories exist.
    #[tokio::test]
    async fn get_last_n_thread_memories_for_cwd_zero_returns_empty() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let cwd = codex_home.join("workspace");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd.clone()))
            .await
            .expect("upsert thread");
        runtime
            .upsert_thread_memory(thread_id, "trace", "memory")
            .await
            .expect("upsert memory");
        let memories = runtime
            .get_last_n_thread_memories_for_cwd(cwd.as_path(), 0)
            .await
            .expect("query memories");
        assert_eq!(memories, Vec::new());
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// cwd matching is exact: `workspace` must not pick up memories stored
    /// under `workspace-child`.
    #[tokio::test]
    async fn get_last_n_thread_memories_for_cwd_does_not_prefix_match() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let cwd_exact = codex_home.join("workspace");
        let cwd_prefix = codex_home.join("workspace-child");
        let t_exact = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let t_prefix = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        runtime
            .upsert_thread(&test_thread_metadata(
                &codex_home,
                t_exact,
                cwd_exact.clone(),
            ))
            .await
            .expect("upsert exact thread");
        runtime
            .upsert_thread(&test_thread_metadata(
                &codex_home,
                t_prefix,
                cwd_prefix.clone(),
            ))
            .await
            .expect("upsert prefix thread");
        runtime
            .upsert_thread_memory(t_exact, "trace-exact", "memory-exact")
            .await
            .expect("upsert exact memory");
        runtime
            .upsert_thread_memory(t_prefix, "trace-prefix", "memory-prefix")
            .await
            .expect("upsert prefix memory");
        let exact_only = runtime
            .get_last_n_thread_memories_for_cwd(cwd_exact.as_path(), 10)
            .await
            .expect("query exact cwd");
        assert_eq!(exact_only.len(), 1);
        assert_eq!(exact_only[0].thread_id, t_exact);
        assert_eq!(exact_only[0].memory_summary, "memory-exact");
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// The consolidation lock: only the holder can renew or release it, and
    /// a released lock becomes available to a different owner.
    #[tokio::test]
    async fn memory_consolidation_lock_enforces_owner_and_release() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let cwd = codex_home.join("workspace");
        let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        assert!(
            runtime
                .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
                .await
                .expect("acquire for owner_a"),
            "owner_a should acquire lock"
        );
        assert!(
            !runtime
                .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 600)
                .await
                .expect("acquire for owner_b should fail"),
            "owner_b should not steal active lock"
        );
        assert!(
            runtime
                .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
                .await
                .expect("owner_a should renew lock"),
            "owner_a should renew lock"
        );
        assert!(
            !runtime
                .release_memory_consolidation_lock(cwd.as_path(), owner_b)
                .await
                .expect("owner_b release should be no-op"),
            "non-owner release should not remove lock"
        );
        assert!(
            runtime
                .release_memory_consolidation_lock(cwd.as_path(), owner_a)
                .await
                .expect("owner_a release"),
            "owner_a should release lock"
        );
        assert!(
            runtime
                .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 600)
                .await
                .expect("owner_b acquire after release"),
            "owner_b should acquire released lock"
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// A lease window of zero seconds treats the existing lock as stale, so
    /// a second owner may take it over immediately.
    #[tokio::test]
    async fn memory_consolidation_lock_can_be_stolen_when_lease_expired() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let cwd = codex_home.join("workspace");
        let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        assert!(
            runtime
                .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
                .await
                .expect("owner_a acquire")
        );
        assert!(
            runtime
                .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 0)
                .await
                .expect("owner_b steal with expired lease"),
            "owner_b should steal lock when lease cutoff marks previous lock stale"
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Finalizing a phase-1 job requires the ownership token issued by the
    /// claim; a wrong token is rejected.
    #[tokio::test]
    async fn phase1_job_claim_and_success_require_current_owner_token() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let cwd = codex_home.join("workspace");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd))
            .await
            .expect("upsert thread");
        let claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner, 100, 3600)
            .await
            .expect("claim phase1 job");
        let ownership_token = match claim {
            Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected claim outcome: {other:?}"),
        };
        assert!(
            !runtime
                .mark_phase1_job_succeeded(
                    thread_id,
                    "cwd",
                    "scope",
                    "wrong-token",
                    "/tmp/path",
                    "summary-hash"
                )
                .await
                .expect("mark succeeded wrong token should fail"),
            "wrong token should not finalize the job"
        );
        assert!(
            runtime
                .mark_phase1_job_succeeded(
                    thread_id,
                    "cwd",
                    "scope",
                    ownership_token.as_str(),
                    "/tmp/path",
                    "summary-hash"
                )
                .await
                .expect("mark succeeded with current token"),
            "current token should finalize the job"
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// A freshly claimed running job is skipped by other claimants, but a
    /// stale running job (zero lease window) can be stolen.
    #[tokio::test]
    async fn phase1_job_running_stale_can_be_stolen_but_fresh_running_is_skipped() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let cwd = codex_home.join("workspace");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd))
            .await
            .expect("upsert thread");
        let first_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600)
            .await
            .expect("first claim");
        assert!(
            matches!(first_claim, Phase1JobClaimOutcome::Claimed { .. }),
            "first claim should acquire"
        );
        let fresh_second_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 3600)
            .await
            .expect("fresh second claim");
        assert_eq!(fresh_second_claim, Phase1JobClaimOutcome::SkippedRunning);
        let stale_second_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0)
            .await
            .expect("stale second claim");
        assert!(
            matches!(stale_second_claim, Phase1JobClaimOutcome::Claimed { .. }),
            "stale running job should be stealable"
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// After a job is stolen, only the new owner's token can renew the
    /// lease; the previous owner's token is rejected.
    #[tokio::test]
    async fn phase1_job_lease_renewal_requires_current_owner_token() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let cwd = codex_home.join("workspace");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd))
            .await
            .expect("upsert thread");
        let first_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600)
            .await
            .expect("first claim");
        let owner_a_token = match first_claim {
            Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected claim outcome: {other:?}"),
        };
        let stolen_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0)
            .await
            .expect("stolen claim");
        let owner_b_token = match stolen_claim {
            Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected claim outcome: {other:?}"),
        };
        assert!(
            !runtime
                .renew_phase1_job_lease(thread_id, "cwd", "scope", owner_a_token.as_str())
                .await
                .expect("old owner lease renewal should fail"),
            "stale owner token should not renew lease"
        );
        assert!(
            runtime
                .renew_phase1_job_lease(thread_id, "cwd", "scope", owner_b_token.as_str())
                .await
                .expect("current owner lease renewal should succeed"),
            "current owner token should renew lease"
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Owner-guarded memory upserts are ignored (return `None`) for a stale
    /// token and applied only for the current job owner.
    #[tokio::test]
    async fn phase1_owner_guarded_upsert_rejects_stale_owner() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let cwd = codex_home.join("workspace");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd))
            .await
            .expect("upsert thread");
        let first_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600)
            .await
            .expect("first claim");
        let owner_a_token = match first_claim {
            Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected claim outcome: {other:?}"),
        };
        let stolen_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 100, 0)
            .await
            .expect("stolen claim");
        let owner_b_token = match stolen_claim {
            Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected claim outcome: {other:?}"),
        };
        let stale_upsert = runtime
            .upsert_thread_memory_for_scope_if_phase1_owner(
                thread_id,
                "cwd",
                "scope",
                owner_a_token.as_str(),
                "stale raw memory",
                "stale summary",
            )
            .await
            .expect("stale owner upsert");
        assert!(
            stale_upsert.is_none(),
            "stale owner token should not upsert thread memory"
        );
        let current_upsert = runtime
            .upsert_thread_memory_for_scope_if_phase1_owner(
                thread_id,
                "cwd",
                "scope",
                owner_b_token.as_str(),
                "fresh raw memory",
                "fresh summary",
            )
            .await
            .expect("current owner upsert");
        let current_upsert = current_upsert.expect("current owner should upsert");
        assert_eq!(current_upsert.raw_memory, "fresh raw memory");
        assert_eq!(current_upsert.memory_summary, "fresh summary");
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// A failed job is terminal: later claims report SkippedTerminalFailure
    /// instead of re-running it.
    #[tokio::test]
    async fn phase1_job_failed_is_terminal() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
        let cwd = codex_home.join("workspace");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd))
            .await
            .expect("upsert thread");
        let claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_a, 100, 3600)
            .await
            .expect("claim");
        let ownership_token = match claim {
            Phase1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
            other => panic!("unexpected claim outcome: {other:?}"),
        };
        assert!(
            runtime
                .mark_phase1_job_failed(
                    thread_id,
                    "cwd",
                    "scope",
                    ownership_token.as_str(),
                    "prompt failed"
                )
                .await
                .expect("mark failed"),
            "owner token should be able to fail job"
        );
        let second_claim = runtime
            .try_claim_phase1_job(thread_id, "cwd", "scope", owner_b, 101, 3600)
            .await
            .expect("second claim");
        assert_eq!(second_claim, Phase1JobClaimOutcome::SkippedTerminalFailure);
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Deleting the `threads` row cascades to `thread_memory` (foreign-key
    /// ON DELETE CASCADE), verified both by raw count and the runtime API.
    #[tokio::test]
    async fn deleting_thread_cascades_thread_memory() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");
        let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
        let cwd = codex_home.join("workspace");
        runtime
            .upsert_thread(&test_thread_metadata(&codex_home, thread_id, cwd))
            .await
            .expect("upsert thread");
        runtime
            .upsert_thread_memory(thread_id, "trace", "memory")
            .await
            .expect("upsert memory");
        let count_before =
            sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?")
                .bind(thread_id.to_string())
                .fetch_one(runtime.pool.as_ref())
                .await
                .expect("count before delete")
                .try_get::<i64, _>("count")
                .expect("count value");
        assert_eq!(count_before, 1);
        sqlx::query("DELETE FROM threads WHERE id = ?")
            .bind(thread_id.to_string())
            .execute(runtime.pool.as_ref())
            .await
            .expect("delete thread");
        let count_after =
            sqlx::query("SELECT COUNT(*) AS count FROM thread_memory WHERE thread_id = ?")
                .bind(thread_id.to_string())
                .fetch_one(runtime.pool.as_ref())
                .await
                .expect("count after delete")
                .try_get::<i64, _>("count")
                .expect("count value");
        assert_eq!(count_after, 0);
        assert_eq!(
            runtime
                .get_thread_memory(thread_id)
                .await
                .expect("get memory after delete"),
            None
        );
        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Builds minimal valid `ThreadMetadata` for tests: fixed timestamps, a
    /// non-empty first user message (required by the listing filter), and a
    /// rollout path derived from the thread id.
    fn test_thread_metadata(
        codex_home: &Path,
        thread_id: ThreadId,
        cwd: PathBuf,
    ) -> ThreadMetadata {
        let now = DateTime::<Utc>::from_timestamp(1_700_000_000, 0).expect("timestamp");
        ThreadMetadata {
            id: thread_id,
            rollout_path: codex_home.join(format!("rollout-{thread_id}.jsonl")),
            created_at: now,
            updated_at: now,
            source: "cli".to_string(),
            model_provider: "test-provider".to_string(),
            cwd,
            cli_version: "0.0.0".to_string(),
            title: String::new(),
            sandbox_policy: crate::extract::enum_to_string(&SandboxPolicy::ReadOnly),
            approval_mode: crate::extract::enum_to_string(&AskForApproval::OnRequest),
            tokens_used: 0,
            first_user_message: Some("hello".to_string()),
            archived_at: None,
            git_sha: None,
            git_branch: None,
            git_origin_url: None,
        }
    }
}