core-agent-ide/codex-rs/state/src/runtime.rs
jif-oai 3878c3dc7c
feat: sqlite 1 (#10004)
Add a `.sqlite` database to be used to store rollout metadata (and
later logs)
This PR is phase 1:
* Add the database and the required infrastructure
* Add a backfill of the database
* Persist the newly created rollout both in files and in the DB
* When we need to get metadata or a rollout, consider the `JSONL` as the
source of truth but compare the results with the DB and show any errors
2026-01-28 15:29:14 +01:00

458 lines
14 KiB
Rust

use crate::DB_ERROR_METRIC;
use crate::SortKey;
use crate::ThreadMetadata;
use crate::ThreadMetadataBuilder;
use crate::ThreadsPage;
use crate::apply_rollout_item;
use crate::migrations::MIGRATOR;
use crate::model::ThreadRow;
use crate::model::anchor_from_item;
use crate::model::datetime_to_epoch_seconds;
use crate::paths::file_modified_time_utc;
use chrono::DateTime;
use chrono::Utc;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::protocol::RolloutItem;
use sqlx::QueryBuilder;
use sqlx::Row;
use sqlx::Sqlite;
use sqlx::SqlitePool;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::sqlite::SqliteJournalMode;
use sqlx::sqlite::SqlitePoolOptions;
use sqlx::sqlite::SqliteSynchronous;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
/// File name of the SQLite state database inside the Codex home directory.
pub const STATE_DB_FILENAME: &str = "state.sqlite";
/// Counter metric emitted while opening/creating the state database.
const METRIC_DB_INIT: &str = "codex.db.init";
/// Handle to the on-disk SQLite state store plus the configuration needed to
/// build and persist thread metadata. Cheap to clone: the pool is shared.
#[derive(Clone)]
pub struct StateRuntime {
    // Root directory containing `state.sqlite` (and, per callers, rollouts).
    codex_home: PathBuf,
    // Model provider used when rollout items do not specify one.
    default_provider: String,
    // Shared SQLite connection pool; `Arc` keeps `Clone` allocation-free.
    pool: Arc<sqlx::SqlitePool>,
}
impl StateRuntime {
    /// Initialize the state runtime using the provided Codex home and default provider.
    ///
    /// This opens (and migrates) the SQLite database at `codex_home/state.sqlite`,
    /// emitting `codex.db.init` counters for the opened/created/error outcomes.
    pub async fn init(
        codex_home: PathBuf,
        default_provider: String,
        otel: Option<OtelManager>,
    ) -> anyhow::Result<Arc<Self>> {
        tokio::fs::create_dir_all(&codex_home).await?;
        let state_path = codex_home.join(STATE_DB_FILENAME);
        // Check before opening: `open_sqlite` creates the file when missing, so
        // this is the only point where "first run" is observable.
        let existed = tokio::fs::try_exists(&state_path).await.unwrap_or(false);
        let pool = match open_sqlite(&state_path).await {
            Ok(db) => Arc::new(db),
            Err(err) => {
                warn!("failed to open state db at {}: {err}", state_path.display());
                if let Some(otel) = otel.as_ref() {
                    otel.counter(METRIC_DB_INIT, 1, &[("status", "open_error")]);
                }
                return Err(err);
            }
        };
        if let Some(otel) = otel.as_ref() {
            otel.counter(METRIC_DB_INIT, 1, &[("status", "opened")]);
        }
        let runtime = Arc::new(Self {
            pool,
            codex_home,
            default_provider,
        });
        if !existed && let Some(otel) = otel.as_ref() {
            otel.counter(METRIC_DB_INIT, 1, &[("status", "created")]);
        }
        Ok(runtime)
    }

    /// Return the configured Codex home directory for this runtime.
    pub fn codex_home(&self) -> &Path {
        self.codex_home.as_path()
    }

    /// Load thread metadata by id using the underlying database.
    ///
    /// Returns `Ok(None)` when no row exists; row-decoding or metadata
    /// conversion failures surface as errors.
    pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
        let row = sqlx::query(
            r#"
            SELECT
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                model_provider,
                cwd,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                has_user_event,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            FROM threads
            WHERE id = ?
            "#,
        )
        .bind(id.to_string())
        .fetch_optional(self.pool.as_ref())
        .await?;
        row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .transpose()
    }

    /// Find a rollout path by thread id using the underlying database.
    ///
    /// `archived_only` narrows the lookup to archived (`Some(true)`) or
    /// non-archived (`Some(false)`) threads; `None` matches either.
    pub async fn find_rollout_path_by_id(
        &self,
        id: ThreadId,
        archived_only: Option<bool>,
    ) -> anyhow::Result<Option<PathBuf>> {
        let mut builder =
            QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
        builder.push_bind(id.to_string());
        match archived_only {
            Some(true) => {
                builder.push(" AND archived = 1");
            }
            Some(false) => {
                builder.push(" AND archived = 0");
            }
            None => {}
        }
        let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
        // Propagate decode errors instead of silently treating a corrupt
        // `rollout_path` column as "no row found" (previously `.ok()` collapsed
        // both cases into `None`).
        Ok(row
            .map(|r| r.try_get::<String, _>("rollout_path"))
            .transpose()?
            .map(PathBuf::from))
    }

    /// List threads using the underlying database.
    ///
    /// Fetches `page_size + 1` rows so an extra row signals that another page
    /// exists; the extra row is dropped and the last remaining item becomes
    /// the next pagination anchor.
    pub async fn list_threads(
        &self,
        page_size: usize,
        anchor: Option<&crate::Anchor>,
        sort_key: crate::SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
    ) -> anyhow::Result<crate::ThreadsPage> {
        // Look-ahead row: one more than requested, to detect a next page.
        let limit = page_size.saturating_add(1);
        let mut builder = QueryBuilder::<Sqlite>::new(
            r#"
            SELECT
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                model_provider,
                cwd,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                has_user_event,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            FROM threads
            "#,
        );
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            anchor,
            sort_key,
        );
        push_thread_order_and_limit(&mut builder, sort_key, limit);
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        let mut items = rows
            .into_iter()
            .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .collect::<Result<Vec<_>, _>>()?;
        // Counted before the look-ahead row is dropped, i.e. rows fetched.
        let num_scanned_rows = items.len();
        let next_anchor = if items.len() > page_size {
            items.pop();
            items
                .last()
                .and_then(|item| anchor_from_item(item, sort_key))
        } else {
            None
        };
        Ok(ThreadsPage {
            items,
            next_anchor,
            num_scanned_rows,
        })
    }

    /// List thread ids using the underlying database (no rollout scanning).
    ///
    /// Applies the same filters and ordering as [`StateRuntime::list_threads`]
    /// but selects only the `id` column; no look-ahead row is requested here,
    /// so at most `limit` ids are returned.
    pub async fn list_thread_ids(
        &self,
        limit: usize,
        anchor: Option<&crate::Anchor>,
        sort_key: crate::SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
    ) -> anyhow::Result<Vec<ThreadId>> {
        let mut builder = QueryBuilder::<Sqlite>::new("SELECT id FROM threads");
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            anchor,
            sort_key,
        );
        push_thread_order_and_limit(&mut builder, sort_key, limit);
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        rows.into_iter()
            .map(|row| {
                let id: String = row.try_get("id")?;
                Ok(ThreadId::try_from(id)?)
            })
            .collect()
    }

    /// Insert or replace thread metadata directly.
    ///
    /// Uses an `ON CONFLICT(id) DO UPDATE` upsert so repeated calls for the
    /// same thread id overwrite every column. `archived` is derived from
    /// `archived_at.is_some()` so the flag and timestamp cannot diverge.
    pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
        sqlx::query(
            r#"
            INSERT INTO threads (
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                model_provider,
                cwd,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                has_user_event,
                archived,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(id) DO UPDATE SET
                rollout_path = excluded.rollout_path,
                created_at = excluded.created_at,
                updated_at = excluded.updated_at,
                source = excluded.source,
                model_provider = excluded.model_provider,
                cwd = excluded.cwd,
                title = excluded.title,
                sandbox_policy = excluded.sandbox_policy,
                approval_mode = excluded.approval_mode,
                tokens_used = excluded.tokens_used,
                has_user_event = excluded.has_user_event,
                archived = excluded.archived,
                archived_at = excluded.archived_at,
                git_sha = excluded.git_sha,
                git_branch = excluded.git_branch,
                git_origin_url = excluded.git_origin_url
            "#,
        )
        // Bind order must match the column list above exactly.
        .bind(metadata.id.to_string())
        .bind(metadata.rollout_path.display().to_string())
        .bind(datetime_to_epoch_seconds(metadata.created_at))
        .bind(datetime_to_epoch_seconds(metadata.updated_at))
        .bind(metadata.source.as_str())
        .bind(metadata.model_provider.as_str())
        .bind(metadata.cwd.display().to_string())
        .bind(metadata.title.as_str())
        .bind(metadata.sandbox_policy.as_str())
        .bind(metadata.approval_mode.as_str())
        .bind(metadata.tokens_used)
        .bind(metadata.has_user_event)
        .bind(metadata.archived_at.is_some())
        .bind(metadata.archived_at.map(datetime_to_epoch_seconds))
        .bind(metadata.git_sha.as_deref())
        .bind(metadata.git_branch.as_deref())
        .bind(metadata.git_origin_url.as_deref())
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }

    /// Apply rollout items incrementally using the underlying database.
    ///
    /// Starts from the stored metadata (or a fresh build from `builder` when
    /// the thread is not yet in the DB), folds each item in, refreshes
    /// `updated_at` from the rollout file's mtime when available, and upserts
    /// the result. Emits a `DB_ERROR_METRIC` counter on upsert failure.
    pub async fn apply_rollout_items(
        &self,
        builder: &ThreadMetadataBuilder,
        items: &[RolloutItem],
        otel: Option<&OtelManager>,
    ) -> anyhow::Result<()> {
        if items.is_empty() {
            return Ok(());
        }
        let mut metadata = self
            .get_thread(builder.id)
            .await?
            .unwrap_or_else(|| builder.build(&self.default_provider));
        metadata.rollout_path = builder.rollout_path.clone();
        for item in items {
            apply_rollout_item(&mut metadata, item, &self.default_provider);
        }
        if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await {
            metadata.updated_at = updated_at;
        }
        if let Err(err) = self.upsert_thread(&metadata).await {
            if let Some(otel) = otel {
                otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
            }
            return Err(err);
        }
        Ok(())
    }

    /// Mark a thread as archived using the underlying database.
    ///
    /// No-op when the thread is unknown. Also updates the stored rollout path
    /// (the file is presumably moved on archive — confirm with callers) and
    /// refreshes `updated_at` from the file's mtime when available.
    pub async fn mark_archived(
        &self,
        thread_id: ThreadId,
        rollout_path: &Path,
        archived_at: DateTime<Utc>,
    ) -> anyhow::Result<()> {
        let Some(mut metadata) = self.get_thread(thread_id).await? else {
            return Ok(());
        };
        metadata.archived_at = Some(archived_at);
        metadata.rollout_path = rollout_path.to_path_buf();
        if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
            metadata.updated_at = updated_at;
        }
        // Defensive: the row was fetched by `thread_id`, so a mismatch would
        // indicate a corrupt row rather than a caller error.
        if metadata.id != thread_id {
            warn!(
                "thread id mismatch during archive: expected {thread_id}, got {}",
                metadata.id
            );
        }
        self.upsert_thread(&metadata).await
    }

    /// Mark a thread as unarchived using the underlying database.
    ///
    /// Mirror of [`StateRuntime::mark_archived`]: clears `archived_at`,
    /// updates the stored rollout path, and refreshes `updated_at` from the
    /// file's mtime when available. No-op when the thread is unknown.
    pub async fn mark_unarchived(
        &self,
        thread_id: ThreadId,
        rollout_path: &Path,
    ) -> anyhow::Result<()> {
        let Some(mut metadata) = self.get_thread(thread_id).await? else {
            return Ok(());
        };
        metadata.archived_at = None;
        metadata.rollout_path = rollout_path.to_path_buf();
        if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
            metadata.updated_at = updated_at;
        }
        // Defensive: see `mark_archived` — a mismatch indicates a corrupt row.
        if metadata.id != thread_id {
            warn!(
                "thread id mismatch during unarchive: expected {thread_id}, got {}",
                metadata.id
            );
        }
        self.upsert_thread(&metadata).await
    }
}
/// Open (creating if needed) and migrate the SQLite database at `path`.
///
/// Configured for WAL journaling with `NORMAL` synchronous mode and a 5s busy
/// timeout, pooled across at most five connections.
async fn open_sqlite(path: &Path) -> anyhow::Result<SqlitePool> {
    let connect_options = SqliteConnectOptions::new()
        .filename(path)
        .create_if_missing(true)
        .journal_mode(SqliteJournalMode::Wal)
        .synchronous(SqliteSynchronous::Normal)
        .busy_timeout(Duration::from_secs(5));

    let db = SqlitePoolOptions::new()
        .max_connections(5)
        .connect_with(connect_options)
        .await?;

    // Bring the schema up to date before handing the pool to callers.
    MIGRATOR.run(&db).await?;
    Ok(db)
}
/// Append the shared `WHERE` clauses for thread listing queries.
///
/// Filters by archived state, presence of a user event, allowed sources,
/// model providers, and (optionally) a keyset-pagination anchor on the column
/// selected by `sort_key`.
fn push_thread_filters<'a>(
    builder: &mut QueryBuilder<'a, Sqlite>,
    archived_only: bool,
    allowed_sources: &'a [String],
    model_providers: Option<&'a [String]>,
    anchor: Option<&crate::Anchor>,
    sort_key: SortKey,
) {
    // `1 = 1` lets every subsequent clause start uniformly with ` AND`.
    builder.push(" WHERE 1 = 1");
    builder.push(if archived_only {
        " AND archived = 1"
    } else {
        " AND archived = 0"
    });
    builder.push(" AND has_user_event = 1");
    if !allowed_sources.is_empty() {
        builder.push(" AND source IN (");
        let mut list = builder.separated(", ");
        for src in allowed_sources {
            list.push_bind(src);
        }
        list.push_unseparated(")");
    }
    // An absent list and an empty list both mean "no provider filter".
    if let Some(providers) = model_providers.filter(|p| !p.is_empty()) {
        builder.push(" AND model_provider IN (");
        let mut list = builder.separated(", ");
        for provider in providers {
            list.push_bind(provider);
        }
        list.push_unseparated(")");
    }
    if let Some(anchor) = anchor {
        // Keyset pagination: strictly older than the anchor timestamp, or the
        // same timestamp with a smaller id — matching the `DESC, id DESC`
        // ordering applied by `push_thread_order_and_limit`.
        let anchor_ts = datetime_to_epoch_seconds(anchor.ts);
        let column = match sort_key {
            SortKey::CreatedAt => "created_at",
            SortKey::UpdatedAt => "updated_at",
        };
        builder.push(" AND (");
        builder.push(column);
        builder.push(" < ");
        builder.push_bind(anchor_ts);
        builder.push(" OR (");
        builder.push(column);
        builder.push(" = ");
        builder.push_bind(anchor_ts);
        builder.push(" AND id < ");
        builder.push_bind(anchor.id.to_string());
        builder.push("))");
    }
}
/// Append `ORDER BY <sort column> DESC, id DESC LIMIT ?` to a thread query.
///
/// Rows come back newest-first, with `id` as a deterministic tiebreaker that
/// matches the keyset-pagination predicate built by `push_thread_filters`.
fn push_thread_order_and_limit(
    builder: &mut QueryBuilder<'_, Sqlite>,
    sort_key: SortKey,
    limit: usize,
) {
    let order_column = match sort_key {
        SortKey::CreatedAt => "created_at",
        SortKey::UpdatedAt => "updated_at",
    };
    builder.push(" ORDER BY ");
    builder.push(order_column);
    builder.push(" DESC, id DESC");
    builder.push(" LIMIT ");
    // `limit as i64` would silently wrap for values above i64::MAX, producing
    // a negative LIMIT; clamp instead so an absurd request degrades to
    // "effectively unlimited".
    builder.push_bind(i64::try_from(limit).unwrap_or(i64::MAX));
}