chore: clean DB runtime (#12905)
This commit is contained in:
parent
382fa338b3
commit
79d6f80e41
7 changed files with 4406 additions and 4357 deletions
File diff suppressed because it is too large
Load diff
562
codex-rs/state/src/runtime/agent_jobs.rs
Normal file
562
codex-rs/state/src/runtime/agent_jobs.rs
Normal file
|
|
@@ -0,0 +1,562 @@
|
|||
use super::*;
|
||||
use crate::model::AgentJobItemRow;
|
||||
|
||||
impl StateRuntime {
|
||||
pub async fn create_agent_job(
|
||||
&self,
|
||||
params: &AgentJobCreateParams,
|
||||
items: &[AgentJobItemCreateParams],
|
||||
) -> anyhow::Result<AgentJob> {
|
||||
let now = Utc::now().timestamp();
|
||||
let input_headers_json = serde_json::to_string(¶ms.input_headers)?;
|
||||
let output_schema_json = params
|
||||
.output_schema_json
|
||||
.as_ref()
|
||||
.map(serde_json::to_string)
|
||||
.transpose()?;
|
||||
let max_runtime_seconds = params
|
||||
.max_runtime_seconds
|
||||
.map(i64::try_from)
|
||||
.transpose()
|
||||
.map_err(|_| anyhow::anyhow!("invalid max_runtime_seconds value"))?;
|
||||
let mut tx = self.pool.begin().await?;
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO agent_jobs (
|
||||
id,
|
||||
name,
|
||||
status,
|
||||
instruction,
|
||||
auto_export,
|
||||
max_runtime_seconds,
|
||||
output_schema_json,
|
||||
input_headers_json,
|
||||
input_csv_path,
|
||||
output_csv_path,
|
||||
created_at,
|
||||
updated_at,
|
||||
started_at,
|
||||
completed_at,
|
||||
last_error
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL)
|
||||
"#,
|
||||
)
|
||||
.bind(params.id.as_str())
|
||||
.bind(params.name.as_str())
|
||||
.bind(AgentJobStatus::Pending.as_str())
|
||||
.bind(params.instruction.as_str())
|
||||
.bind(i64::from(params.auto_export))
|
||||
.bind(max_runtime_seconds)
|
||||
.bind(output_schema_json)
|
||||
.bind(input_headers_json)
|
||||
.bind(params.input_csv_path.as_str())
|
||||
.bind(params.output_csv_path.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
for item in items {
|
||||
let row_json = serde_json::to_string(&item.row_json)?;
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO agent_job_items (
|
||||
job_id,
|
||||
item_id,
|
||||
row_index,
|
||||
source_id,
|
||||
row_json,
|
||||
status,
|
||||
assigned_thread_id,
|
||||
attempt_count,
|
||||
result_json,
|
||||
last_error,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
reported_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, NULL, 0, NULL, NULL, ?, ?, NULL, NULL)
|
||||
"#,
|
||||
)
|
||||
.bind(params.id.as_str())
|
||||
.bind(item.item_id.as_str())
|
||||
.bind(item.row_index)
|
||||
.bind(item.source_id.as_deref())
|
||||
.bind(row_json)
|
||||
.bind(AgentJobItemStatus::Pending.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
let job_id = params.id.as_str();
|
||||
self.get_agent_job(job_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to load created agent job {job_id}"))
|
||||
}
|
||||
|
||||
pub async fn get_agent_job(&self, job_id: &str) -> anyhow::Result<Option<AgentJob>> {
|
||||
let row = sqlx::query_as::<_, AgentJobRow>(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
status,
|
||||
instruction,
|
||||
auto_export,
|
||||
max_runtime_seconds,
|
||||
output_schema_json,
|
||||
input_headers_json,
|
||||
input_csv_path,
|
||||
output_csv_path,
|
||||
created_at,
|
||||
updated_at,
|
||||
started_at,
|
||||
completed_at,
|
||||
last_error
|
||||
FROM agent_jobs
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
row.map(AgentJob::try_from).transpose()
|
||||
}
|
||||
|
||||
pub async fn list_agent_job_items(
|
||||
&self,
|
||||
job_id: &str,
|
||||
status: Option<AgentJobItemStatus>,
|
||||
limit: Option<usize>,
|
||||
) -> anyhow::Result<Vec<AgentJobItem>> {
|
||||
let mut builder = QueryBuilder::<Sqlite>::new(
|
||||
r#"
|
||||
SELECT
|
||||
job_id,
|
||||
item_id,
|
||||
row_index,
|
||||
source_id,
|
||||
row_json,
|
||||
status,
|
||||
assigned_thread_id,
|
||||
attempt_count,
|
||||
result_json,
|
||||
last_error,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
reported_at
|
||||
FROM agent_job_items
|
||||
WHERE job_id =
|
||||
"#,
|
||||
);
|
||||
builder.push_bind(job_id);
|
||||
if let Some(status) = status {
|
||||
builder.push(" AND status = ");
|
||||
builder.push_bind(status.as_str());
|
||||
}
|
||||
builder.push(" ORDER BY row_index ASC");
|
||||
if let Some(limit) = limit {
|
||||
builder.push(" LIMIT ");
|
||||
builder.push_bind(limit as i64);
|
||||
}
|
||||
let rows: Vec<AgentJobItemRow> = builder
|
||||
.build_query_as::<AgentJobItemRow>()
|
||||
.fetch_all(self.pool.as_ref())
|
||||
.await?;
|
||||
rows.into_iter().map(AgentJobItem::try_from).collect()
|
||||
}
|
||||
|
||||
pub async fn get_agent_job_item(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
) -> anyhow::Result<Option<AgentJobItem>> {
|
||||
let row: Option<AgentJobItemRow> = sqlx::query_as::<_, AgentJobItemRow>(
|
||||
r#"
|
||||
SELECT
|
||||
job_id,
|
||||
item_id,
|
||||
row_index,
|
||||
source_id,
|
||||
row_json,
|
||||
status,
|
||||
assigned_thread_id,
|
||||
attempt_count,
|
||||
result_json,
|
||||
last_error,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
reported_at
|
||||
FROM agent_job_items
|
||||
WHERE job_id = ? AND item_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
row.map(AgentJobItem::try_from).transpose()
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_running(&self, job_id: &str) -> anyhow::Result<()> {
|
||||
let now = Utc::now().timestamp();
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_jobs
|
||||
SET
|
||||
status = ?,
|
||||
updated_at = ?,
|
||||
started_at = COALESCE(started_at, ?),
|
||||
completed_at = NULL,
|
||||
last_error = NULL
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobStatus::Running.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_completed(&self, job_id: &str) -> anyhow::Result<()> {
|
||||
let now = Utc::now().timestamp();
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_jobs
|
||||
SET status = ?, updated_at = ?, completed_at = ?, last_error = NULL
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobStatus::Completed.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_failed(
|
||||
&self,
|
||||
job_id: &str,
|
||||
error_message: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let now = Utc::now().timestamp();
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_jobs
|
||||
SET status = ?, updated_at = ?, completed_at = ?, last_error = ?
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobStatus::Failed.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(error_message)
|
||||
.bind(job_id)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_cancelled(
|
||||
&self,
|
||||
job_id: &str,
|
||||
reason: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_jobs
|
||||
SET status = ?, updated_at = ?, completed_at = ?, last_error = ?
|
||||
WHERE id = ? AND status IN (?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobStatus::Cancelled.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(reason)
|
||||
.bind(job_id)
|
||||
.bind(AgentJobStatus::Pending.as_str())
|
||||
.bind(AgentJobStatus::Running.as_str())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn is_agent_job_cancelled(&self, job_id: &str) -> anyhow::Result<bool> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT status
|
||||
FROM agent_jobs
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
let Some(row) = row else {
|
||||
return Ok(false);
|
||||
};
|
||||
let status: String = row.try_get("status")?;
|
||||
Ok(AgentJobStatus::parse(status.as_str())? == AgentJobStatus::Cancelled)
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_item_running(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
assigned_thread_id = NULL,
|
||||
attempt_count = attempt_count + 1,
|
||||
updated_at = ?,
|
||||
last_error = NULL
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Pending.as_str())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_item_running_with_thread(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
thread_id: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
assigned_thread_id = ?,
|
||||
attempt_count = attempt_count + 1,
|
||||
updated_at = ?,
|
||||
last_error = NULL
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.bind(thread_id)
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Pending.as_str())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_item_pending(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
error_message: Option<&str>,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
assigned_thread_id = NULL,
|
||||
updated_at = ?,
|
||||
last_error = ?
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Pending.as_str())
|
||||
.bind(now)
|
||||
.bind(error_message)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn set_agent_job_item_thread(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
thread_id: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET assigned_thread_id = ?, updated_at = ?
|
||||
WHERE job_id = ? AND item_id = ? AND status = ?
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id)
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn report_agent_job_item_result(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
reporting_thread_id: &str,
|
||||
result_json: &Value,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let serialized = serde_json::to_string(result_json)?;
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
result_json = ?,
|
||||
reported_at = ?,
|
||||
updated_at = ?,
|
||||
last_error = NULL
|
||||
WHERE
|
||||
job_id = ?
|
||||
AND item_id = ?
|
||||
AND status = ?
|
||||
AND assigned_thread_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(serialized)
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.bind(reporting_thread_id)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_item_completed(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
completed_at = ?,
|
||||
updated_at = ?,
|
||||
assigned_thread_id = NULL
|
||||
WHERE
|
||||
job_id = ?
|
||||
AND item_id = ?
|
||||
AND status = ?
|
||||
AND result_json IS NOT NULL
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Completed.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn mark_agent_job_item_failed(
|
||||
&self,
|
||||
job_id: &str,
|
||||
item_id: &str,
|
||||
error_message: &str,
|
||||
) -> anyhow::Result<bool> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
UPDATE agent_job_items
|
||||
SET
|
||||
status = ?,
|
||||
completed_at = ?,
|
||||
updated_at = ?,
|
||||
last_error = ?,
|
||||
assigned_thread_id = NULL
|
||||
WHERE
|
||||
job_id = ?
|
||||
AND item_id = ?
|
||||
AND status = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Failed.as_str())
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.bind(error_message)
|
||||
.bind(job_id)
|
||||
.bind(item_id)
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn get_agent_job_progress(&self, job_id: &str) -> anyhow::Result<AgentJobProgress> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) AS total_items,
|
||||
SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS pending_items,
|
||||
SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS running_items,
|
||||
SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS completed_items,
|
||||
SUM(CASE WHEN status = ? THEN 1 ELSE 0 END) AS failed_items
|
||||
FROM agent_job_items
|
||||
WHERE job_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(AgentJobItemStatus::Pending.as_str())
|
||||
.bind(AgentJobItemStatus::Running.as_str())
|
||||
.bind(AgentJobItemStatus::Completed.as_str())
|
||||
.bind(AgentJobItemStatus::Failed.as_str())
|
||||
.bind(job_id)
|
||||
.fetch_one(self.pool.as_ref())
|
||||
.await?;
|
||||
|
||||
let total_items: i64 = row.try_get("total_items")?;
|
||||
let pending_items: Option<i64> = row.try_get("pending_items")?;
|
||||
let running_items: Option<i64> = row.try_get("running_items")?;
|
||||
let completed_items: Option<i64> = row.try_get("completed_items")?;
|
||||
let failed_items: Option<i64> = row.try_get("failed_items")?;
|
||||
Ok(AgentJobProgress {
|
||||
total_items: usize::try_from(total_items).unwrap_or_default(),
|
||||
pending_items: usize::try_from(pending_items.unwrap_or_default()).unwrap_or_default(),
|
||||
running_items: usize::try_from(running_items.unwrap_or_default()).unwrap_or_default(),
|
||||
completed_items: usize::try_from(completed_items.unwrap_or_default())
|
||||
.unwrap_or_default(),
|
||||
failed_items: usize::try_from(failed_items.unwrap_or_default()).unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
311
codex-rs/state/src/runtime/backfill.rs
Normal file
311
codex-rs/state/src/runtime/backfill.rs
Normal file
|
|
@@ -0,0 +1,311 @@
|
|||
use super::*;
|
||||
|
||||
impl StateRuntime {
    /// Read the singleton rollout-metadata backfill row (id = 1),
    /// creating it with default values first if it does not yet exist.
    pub async fn get_backfill_state(&self) -> anyhow::Result<crate::BackfillState> {
        self.ensure_backfill_state_row().await?;
        let row = sqlx::query(
            r#"
            SELECT status, last_watermark, last_success_at
            FROM backfill_state
            WHERE id = 1
            "#,
        )
        .fetch_one(self.pool.as_ref())
        .await?;
        crate::BackfillState::try_from_row(&row)
    }

    /// Attempt to claim ownership of rollout metadata backfill.
    ///
    /// Returns `true` when this runtime claimed the backfill worker slot.
    /// Returns `false` if backfill is already complete or currently owned by a
    /// non-expired worker.
    pub async fn try_claim_backfill(&self, lease_seconds: i64) -> anyhow::Result<bool> {
        self.ensure_backfill_state_row().await?;
        let now = Utc::now().timestamp();
        // A running claim counts as expired once `updated_at` is older than
        // the lease window; negative lease values are clamped to 0.
        let lease_cutoff = now.saturating_sub(lease_seconds.max(0));
        // Single conditional UPDATE = atomic claim: exactly one concurrent
        // caller can flip the row to Running with a fresh `updated_at`.
        let result = sqlx::query(
            r#"
            UPDATE backfill_state
            SET status = ?, updated_at = ?
            WHERE id = 1
              AND status != ?
              AND (status != ? OR updated_at <= ?)
            "#,
        )
        .bind(crate::BackfillStatus::Running.as_str())
        .bind(now)
        .bind(crate::BackfillStatus::Complete.as_str())
        .bind(crate::BackfillStatus::Running.as_str())
        .bind(lease_cutoff)
        .execute(self.pool.as_ref())
        .await?;
        Ok(result.rows_affected() == 1)
    }

    /// Mark rollout metadata backfill as running.
    pub async fn mark_backfill_running(&self) -> anyhow::Result<()> {
        self.ensure_backfill_state_row().await?;
        sqlx::query(
            r#"
            UPDATE backfill_state
            SET status = ?, updated_at = ?
            WHERE id = 1
            "#,
        )
        .bind(crate::BackfillStatus::Running.as_str())
        .bind(Utc::now().timestamp())
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }

    /// Persist rollout metadata backfill progress.
    ///
    /// Also refreshes `updated_at`, which doubles as the lease heartbeat
    /// consulted by `try_claim_backfill`.
    pub async fn checkpoint_backfill(&self, watermark: &str) -> anyhow::Result<()> {
        self.ensure_backfill_state_row().await?;
        sqlx::query(
            r#"
            UPDATE backfill_state
            SET status = ?, last_watermark = ?, updated_at = ?
            WHERE id = 1
            "#,
        )
        .bind(crate::BackfillStatus::Running.as_str())
        .bind(watermark)
        .bind(Utc::now().timestamp())
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }

    /// Mark rollout metadata backfill as complete.
    ///
    /// A `None` watermark keeps whatever watermark was last checkpointed
    /// (COALESCE falls back to the stored value).
    pub async fn mark_backfill_complete(&self, last_watermark: Option<&str>) -> anyhow::Result<()> {
        self.ensure_backfill_state_row().await?;
        let now = Utc::now().timestamp();
        sqlx::query(
            r#"
            UPDATE backfill_state
            SET
                status = ?,
                last_watermark = COALESCE(?, last_watermark),
                last_success_at = ?,
                updated_at = ?
            WHERE id = 1
            "#,
        )
        .bind(crate::BackfillStatus::Complete.as_str())
        .bind(last_watermark)
        .bind(now)
        .bind(now)
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }

    /// Insert the singleton backfill row (id = 1, status Pending) if missing;
    /// a no-op when it already exists (`ON CONFLICT ... DO NOTHING`).
    async fn ensure_backfill_state_row(&self) -> anyhow::Result<()> {
        sqlx::query(
            r#"
            INSERT INTO backfill_state (id, status, last_watermark, last_success_at, updated_at)
            VALUES (?, ?, NULL, NULL, ?)
            ON CONFLICT(id) DO NOTHING
            "#,
        )
        .bind(1_i64)
        .bind(crate::BackfillStatus::Pending.as_str())
        .bind(Utc::now().timestamp())
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::StateRuntime;
    use super::state_db_filename;
    use super::test_support::unique_temp_dir;
    use crate::STATE_DB_FILENAME;
    use crate::STATE_DB_VERSION;
    use chrono::Utc;
    use pretty_assertions::assert_eq;

    /// Runtime init should delete legacy/previous-version DB files (including
    /// WAL/SHM/journal sidecars) while leaving unrelated files untouched.
    #[tokio::test]
    async fn init_removes_legacy_state_db_files() {
        let codex_home = unique_temp_dir();
        tokio::fs::create_dir_all(&codex_home)
            .await
            .expect("create codex_home");

        let current_name = state_db_filename();
        let previous_version = STATE_DB_VERSION.saturating_sub(1);
        let unversioned_name = format!("{STATE_DB_FILENAME}.sqlite");
        for suffix in ["", "-wal", "-shm", "-journal"] {
            let path = codex_home.join(format!("{unversioned_name}{suffix}"));
            tokio::fs::write(path, b"legacy")
                .await
                .expect("write legacy");
            let old_version_path = codex_home.join(format!(
                "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}"
            ));
            tokio::fs::write(old_version_path, b"old_version")
                .await
                .expect("write old version");
        }
        // Files that merely look similar must survive cleanup.
        let unrelated_path = codex_home.join("state.sqlite_backup");
        tokio::fs::write(&unrelated_path, b"keep")
            .await
            .expect("write unrelated");
        let numeric_path = codex_home.join("123");
        tokio::fs::write(&numeric_path, b"keep")
            .await
            .expect("write numeric");

        let _runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        for suffix in ["", "-wal", "-shm", "-journal"] {
            let legacy_path = codex_home.join(format!("{unversioned_name}{suffix}"));
            assert!(
                !tokio::fs::try_exists(&legacy_path)
                    .await
                    .expect("check legacy path")
            );
            let old_version_path = codex_home.join(format!(
                "{STATE_DB_FILENAME}_{previous_version}.sqlite{suffix}"
            ));
            assert!(
                !tokio::fs::try_exists(&old_version_path)
                    .await
                    .expect("check old version path")
            );
        }
        assert!(
            tokio::fs::try_exists(codex_home.join(current_name))
                .await
                .expect("check new db path")
        );
        assert!(
            tokio::fs::try_exists(&unrelated_path)
                .await
                .expect("check unrelated path")
        );
        assert!(
            tokio::fs::try_exists(&numeric_path)
                .await
                .expect("check numeric path")
        );

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// Backfill state should round-trip through Pending -> Running (with a
    /// watermark checkpoint) -> Complete, persisting each field.
    #[tokio::test]
    async fn backfill_state_persists_progress_and_completion() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let initial = runtime
            .get_backfill_state()
            .await
            .expect("get initial backfill state");
        assert_eq!(initial.status, crate::BackfillStatus::Pending);
        assert_eq!(initial.last_watermark, None);
        assert_eq!(initial.last_success_at, None);

        runtime
            .mark_backfill_running()
            .await
            .expect("mark backfill running");
        runtime
            .checkpoint_backfill("sessions/2026/01/27/rollout-a.jsonl")
            .await
            .expect("checkpoint backfill");

        let running = runtime
            .get_backfill_state()
            .await
            .expect("get running backfill state");
        assert_eq!(running.status, crate::BackfillStatus::Running);
        assert_eq!(
            running.last_watermark,
            Some("sessions/2026/01/27/rollout-a.jsonl".to_string())
        );
        assert_eq!(running.last_success_at, None);

        runtime
            .mark_backfill_complete(Some("sessions/2026/01/28/rollout-b.jsonl"))
            .await
            .expect("mark backfill complete");
        let completed = runtime
            .get_backfill_state()
            .await
            .expect("get completed backfill state");
        assert_eq!(completed.status, crate::BackfillStatus::Complete);
        assert_eq!(
            completed.last_watermark,
            Some("sessions/2026/01/28/rollout-b.jsonl".to_string())
        );
        assert!(completed.last_success_at.is_some());

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    /// The claim is exclusive while the lease is fresh, reclaimable once the
    /// lease is stale, and permanently blocked after completion.
    #[tokio::test]
    async fn backfill_claim_is_singleton_until_stale_and_blocked_when_complete() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let claimed = runtime
            .try_claim_backfill(3600)
            .await
            .expect("initial backfill claim");
        assert!(claimed);

        let duplicate_claim = runtime
            .try_claim_backfill(3600)
            .await
            .expect("duplicate backfill claim");
        assert!(!duplicate_claim);

        // Age the lease far past the window so the next claim sees it stale.
        let stale_updated_at = Utc::now().timestamp().saturating_sub(10_000);
        sqlx::query(
            r#"
            UPDATE backfill_state
            SET status = ?, updated_at = ?
            WHERE id = 1
            "#,
        )
        .bind(crate::BackfillStatus::Running.as_str())
        .bind(stale_updated_at)
        .execute(runtime.pool.as_ref())
        .await
        .expect("force stale backfill lease");

        let stale_claim = runtime
            .try_claim_backfill(10)
            .await
            .expect("stale backfill claim");
        assert!(stale_claim);

        runtime
            .mark_backfill_complete(None)
            .await
            .expect("mark complete");
        let claim_after_complete = runtime
            .try_claim_backfill(3600)
            .await
            .expect("claim after complete");
        assert!(!claim_after_complete);

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }
}
|
||||
715
codex-rs/state/src/runtime/logs.rs
Normal file
715
codex-rs/state/src/runtime/logs.rs
Normal file
|
|
@@ -0,0 +1,715 @@
|
|||
use super::*;
|
||||
|
||||
impl StateRuntime {
|
||||
/// Insert a single log entry; thin wrapper over [`Self::insert_logs`]
/// (uses `slice::from_ref` to avoid cloning the entry).
pub async fn insert_log(&self, entry: &LogEntry) -> anyhow::Result<()> {
    self.insert_logs(std::slice::from_ref(entry)).await
}
|
||||
|
||||
/// Insert a batch of log entries into the logs table.
///
/// All rows go in via one multi-row INSERT, and per-partition size caps are
/// enforced in the same transaction (`prune_logs_after_insert`), so callers
/// never observe inserted-but-not-yet-pruned rows. Empty input is a no-op.
pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> {
    if entries.is_empty() {
        return Ok(());
    }

    let mut tx = self.pool.begin().await?;
    let mut builder = QueryBuilder::<Sqlite>::new(
        "INSERT INTO logs (ts, ts_nanos, level, target, message, thread_id, process_uuid, module_path, file, line, estimated_bytes) ",
    );
    builder.push_values(entries, |mut row, entry| {
        // Rough per-row byte estimate (text field lengths only) persisted so
        // pruning can budget partitions without re-measuring rows.
        let estimated_bytes = entry.message.as_ref().map_or(0, String::len) as i64
            + entry.level.len() as i64
            + entry.target.len() as i64
            + entry.module_path.as_ref().map_or(0, String::len) as i64
            + entry.file.as_ref().map_or(0, String::len) as i64;
        row.push_bind(entry.ts)
            .push_bind(entry.ts_nanos)
            .push_bind(&entry.level)
            .push_bind(&entry.target)
            .push_bind(&entry.message)
            .push_bind(&entry.thread_id)
            .push_bind(&entry.process_uuid)
            .push_bind(&entry.module_path)
            .push_bind(&entry.file)
            .push_bind(entry.line)
            .push_bind(estimated_bytes);
    });
    builder.build().execute(&mut *tx).await?;
    self.prune_logs_after_insert(entries, &mut tx).await?;
    tx.commit().await?;
    Ok(())
}
|
||||
|
||||
/// Enforce per-partition log size caps after a successful batch insert.
|
||||
///
|
||||
/// We maintain two independent budgets:
|
||||
/// - Thread logs: rows with `thread_id IS NOT NULL`, capped per `thread_id`.
|
||||
/// - Threadless process logs: rows with `thread_id IS NULL` ("threadless"),
|
||||
/// capped per `process_uuid` (including `process_uuid IS NULL` as its own
|
||||
/// threadless partition).
|
||||
///
|
||||
/// "Threadless" means the log row is not associated with any conversation
|
||||
/// thread, so retention is keyed by process identity instead.
|
||||
///
|
||||
/// This runs inside the same transaction as the insert so callers never
|
||||
/// observe "inserted but not yet pruned" rows.
|
||||
async fn prune_logs_after_insert(
|
||||
&self,
|
||||
entries: &[LogEntry],
|
||||
tx: &mut SqliteConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
let thread_ids: BTreeSet<&str> = entries
|
||||
.iter()
|
||||
.filter_map(|entry| entry.thread_id.as_deref())
|
||||
.collect();
|
||||
if !thread_ids.is_empty() {
|
||||
// Cheap precheck: only run the heavier window-function prune for
|
||||
// threads that are currently above the cap.
|
||||
let mut over_limit_threads_query =
|
||||
QueryBuilder::<Sqlite>::new("SELECT thread_id FROM logs WHERE thread_id IN (");
|
||||
{
|
||||
let mut separated = over_limit_threads_query.separated(", ");
|
||||
for thread_id in &thread_ids {
|
||||
separated.push_bind(*thread_id);
|
||||
}
|
||||
}
|
||||
over_limit_threads_query.push(") GROUP BY thread_id HAVING SUM(");
|
||||
over_limit_threads_query.push("estimated_bytes");
|
||||
over_limit_threads_query.push(") > ");
|
||||
over_limit_threads_query.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES);
|
||||
let over_limit_thread_ids: Vec<String> = over_limit_threads_query
|
||||
.build()
|
||||
.fetch_all(&mut *tx)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|row| row.try_get("thread_id"))
|
||||
.collect::<Result<_, _>>()?;
|
||||
if !over_limit_thread_ids.is_empty() {
|
||||
// Enforce a strict per-thread cap by deleting every row whose
|
||||
// newest-first cumulative bytes exceed the partition budget.
|
||||
let mut prune_threads = QueryBuilder::<Sqlite>::new(
|
||||
r#"
|
||||
DELETE FROM logs
|
||||
WHERE id IN (
|
||||
SELECT id
|
||||
FROM (
|
||||
SELECT
|
||||
id,
|
||||
SUM(
|
||||
"#,
|
||||
);
|
||||
prune_threads.push("estimated_bytes");
|
||||
prune_threads.push(
|
||||
r#"
|
||||
) OVER (
|
||||
PARTITION BY thread_id
|
||||
ORDER BY ts DESC, ts_nanos DESC, id DESC
|
||||
) AS cumulative_bytes
|
||||
FROM logs
|
||||
WHERE thread_id IN (
|
||||
"#,
|
||||
);
|
||||
{
|
||||
let mut separated = prune_threads.separated(", ");
|
||||
for thread_id in &over_limit_thread_ids {
|
||||
separated.push_bind(thread_id);
|
||||
}
|
||||
}
|
||||
prune_threads.push(
|
||||
r#"
|
||||
)
|
||||
)
|
||||
WHERE cumulative_bytes >
|
||||
"#,
|
||||
);
|
||||
prune_threads.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES);
|
||||
prune_threads.push("\n)");
|
||||
prune_threads.build().execute(&mut *tx).await?;
|
||||
}
|
||||
}
|
||||
|
||||
let threadless_process_uuids: BTreeSet<&str> = entries
|
||||
.iter()
|
||||
.filter(|entry| entry.thread_id.is_none())
|
||||
.filter_map(|entry| entry.process_uuid.as_deref())
|
||||
.collect();
|
||||
let has_threadless_null_process_uuid = entries
|
||||
.iter()
|
||||
.any(|entry| entry.thread_id.is_none() && entry.process_uuid.is_none());
|
||||
if !threadless_process_uuids.is_empty() {
|
||||
// Threadless logs are budgeted separately per process UUID.
|
||||
let mut over_limit_processes_query = QueryBuilder::<Sqlite>::new(
|
||||
"SELECT process_uuid FROM logs WHERE thread_id IS NULL AND process_uuid IN (",
|
||||
);
|
||||
{
|
||||
let mut separated = over_limit_processes_query.separated(", ");
|
||||
for process_uuid in &threadless_process_uuids {
|
||||
separated.push_bind(*process_uuid);
|
||||
}
|
||||
}
|
||||
over_limit_processes_query.push(") GROUP BY process_uuid HAVING SUM(");
|
||||
over_limit_processes_query.push("estimated_bytes");
|
||||
over_limit_processes_query.push(") > ");
|
||||
over_limit_processes_query.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES);
|
||||
let over_limit_process_uuids: Vec<String> = over_limit_processes_query
|
||||
.build()
|
||||
.fetch_all(&mut *tx)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|row| row.try_get("process_uuid"))
|
||||
.collect::<Result<_, _>>()?;
|
||||
if !over_limit_process_uuids.is_empty() {
|
||||
// Same strict cap policy as thread pruning, but only for
|
||||
// threadless rows in the affected process UUIDs.
|
||||
let mut prune_threadless_process_logs = QueryBuilder::<Sqlite>::new(
|
||||
r#"
|
||||
DELETE FROM logs
|
||||
WHERE id IN (
|
||||
SELECT id
|
||||
FROM (
|
||||
SELECT
|
||||
id,
|
||||
SUM(
|
||||
"#,
|
||||
);
|
||||
prune_threadless_process_logs.push("estimated_bytes");
|
||||
prune_threadless_process_logs.push(
|
||||
r#"
|
||||
) OVER (
|
||||
PARTITION BY process_uuid
|
||||
ORDER BY ts DESC, ts_nanos DESC, id DESC
|
||||
) AS cumulative_bytes
|
||||
FROM logs
|
||||
WHERE thread_id IS NULL
|
||||
AND process_uuid IN (
|
||||
"#,
|
||||
);
|
||||
{
|
||||
let mut separated = prune_threadless_process_logs.separated(", ");
|
||||
for process_uuid in &over_limit_process_uuids {
|
||||
separated.push_bind(process_uuid);
|
||||
}
|
||||
}
|
||||
prune_threadless_process_logs.push(
|
||||
r#"
|
||||
)
|
||||
)
|
||||
WHERE cumulative_bytes >
|
||||
"#,
|
||||
);
|
||||
prune_threadless_process_logs.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES);
|
||||
prune_threadless_process_logs.push("\n)");
|
||||
prune_threadless_process_logs
|
||||
.build()
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
if has_threadless_null_process_uuid {
|
||||
// Rows without a process UUID still need a cap; treat NULL as its
|
||||
// own threadless partition.
|
||||
let mut null_process_usage_query = QueryBuilder::<Sqlite>::new("SELECT SUM(");
|
||||
null_process_usage_query.push("estimated_bytes");
|
||||
null_process_usage_query.push(
|
||||
") AS total_bytes FROM logs WHERE thread_id IS NULL AND process_uuid IS NULL",
|
||||
);
|
||||
let total_null_process_bytes: Option<i64> = null_process_usage_query
|
||||
.build()
|
||||
.fetch_one(&mut *tx)
|
||||
.await?
|
||||
.try_get("total_bytes")?;
|
||||
|
||||
if total_null_process_bytes.unwrap_or(0) > LOG_PARTITION_SIZE_LIMIT_BYTES {
|
||||
let mut prune_threadless_null_process_logs = QueryBuilder::<Sqlite>::new(
|
||||
r#"
|
||||
DELETE FROM logs
|
||||
WHERE id IN (
|
||||
SELECT id
|
||||
FROM (
|
||||
SELECT
|
||||
id,
|
||||
SUM(
|
||||
"#,
|
||||
);
|
||||
prune_threadless_null_process_logs.push("estimated_bytes");
|
||||
prune_threadless_null_process_logs.push(
|
||||
r#"
|
||||
) OVER (
|
||||
PARTITION BY process_uuid
|
||||
ORDER BY ts DESC, ts_nanos DESC, id DESC
|
||||
) AS cumulative_bytes
|
||||
FROM logs
|
||||
WHERE thread_id IS NULL
|
||||
AND process_uuid IS NULL
|
||||
)
|
||||
WHERE cumulative_bytes >
|
||||
"#,
|
||||
);
|
||||
prune_threadless_null_process_logs.push_bind(LOG_PARTITION_SIZE_LIMIT_BYTES);
|
||||
prune_threadless_null_process_logs.push("\n)");
|
||||
prune_threadless_null_process_logs
|
||||
.build()
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result<u64> {
|
||||
let result = sqlx::query("DELETE FROM logs WHERE ts < ?")
|
||||
.bind(cutoff_ts)
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Query logs with optional filters.
|
||||
pub async fn query_logs(&self, query: &LogQuery) -> anyhow::Result<Vec<LogRow>> {
|
||||
let mut builder = QueryBuilder::<Sqlite>::new(
|
||||
"SELECT id, ts, ts_nanos, level, target, message, thread_id, process_uuid, file, line FROM logs WHERE 1 = 1",
|
||||
);
|
||||
push_log_filters(&mut builder, query);
|
||||
if query.descending {
|
||||
builder.push(" ORDER BY id DESC");
|
||||
} else {
|
||||
builder.push(" ORDER BY id ASC");
|
||||
}
|
||||
if let Some(limit) = query.limit {
|
||||
builder.push(" LIMIT ").push_bind(limit as i64);
|
||||
}
|
||||
|
||||
let rows = builder
|
||||
.build_query_as::<LogRow>()
|
||||
.fetch_all(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
/// Return the max log id matching optional filters.
|
||||
pub async fn max_log_id(&self, query: &LogQuery) -> anyhow::Result<i64> {
|
||||
let mut builder =
|
||||
QueryBuilder::<Sqlite>::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1");
|
||||
push_log_filters(&mut builder, query);
|
||||
let row = builder.build().fetch_one(self.pool.as_ref()).await?;
|
||||
let max_id: Option<i64> = row.try_get("max_id")?;
|
||||
Ok(max_id.unwrap_or(0))
|
||||
}
|
||||
}
|
||||
|
||||
fn push_log_filters<'a>(builder: &mut QueryBuilder<'a, Sqlite>, query: &'a LogQuery) {
|
||||
if let Some(level_upper) = query.level_upper.as_ref() {
|
||||
builder
|
||||
.push(" AND UPPER(level) = ")
|
||||
.push_bind(level_upper.as_str());
|
||||
}
|
||||
if let Some(from_ts) = query.from_ts {
|
||||
builder.push(" AND ts >= ").push_bind(from_ts);
|
||||
}
|
||||
if let Some(to_ts) = query.to_ts {
|
||||
builder.push(" AND ts <= ").push_bind(to_ts);
|
||||
}
|
||||
push_like_filters(builder, "module_path", &query.module_like);
|
||||
push_like_filters(builder, "file", &query.file_like);
|
||||
let has_thread_filter = !query.thread_ids.is_empty() || query.include_threadless;
|
||||
if has_thread_filter {
|
||||
builder.push(" AND (");
|
||||
let mut needs_or = false;
|
||||
for thread_id in &query.thread_ids {
|
||||
if needs_or {
|
||||
builder.push(" OR ");
|
||||
}
|
||||
builder.push("thread_id = ").push_bind(thread_id.as_str());
|
||||
needs_or = true;
|
||||
}
|
||||
if query.include_threadless {
|
||||
if needs_or {
|
||||
builder.push(" OR ");
|
||||
}
|
||||
builder.push("thread_id IS NULL");
|
||||
}
|
||||
builder.push(")");
|
||||
}
|
||||
if let Some(after_id) = query.after_id {
|
||||
builder.push(" AND id > ").push_bind(after_id);
|
||||
}
|
||||
if let Some(search) = query.search.as_ref() {
|
||||
builder.push(" AND INSTR(message, ");
|
||||
builder.push_bind(search.as_str());
|
||||
builder.push(") > 0");
|
||||
}
|
||||
}
|
||||
|
||||
fn push_like_filters<'a>(
|
||||
builder: &mut QueryBuilder<'a, Sqlite>,
|
||||
column: &str,
|
||||
filters: &'a [String],
|
||||
) {
|
||||
if filters.is_empty() {
|
||||
return;
|
||||
}
|
||||
builder.push(" AND (");
|
||||
for (idx, filter) in filters.iter().enumerate() {
|
||||
if idx > 0 {
|
||||
builder.push(" OR ");
|
||||
}
|
||||
builder
|
||||
.push(column)
|
||||
.push(" LIKE '%' || ")
|
||||
.push_bind(filter.as_str())
|
||||
.push(" || '%'");
|
||||
}
|
||||
builder.push(")");
|
||||
}
|
||||
|
||||
// Integration-style tests for log insertion, querying, and the per-partition
// size-cap pruning (threads, threadless-per-process, and threadless-with-NULL
// process UUID partitions). Each test spins up a fresh runtime in a unique
// temp dir and removes it afterwards.
#[cfg(test)]
mod tests {
    use super::StateRuntime;
    use super::test_support::unique_temp_dir;
    use crate::LogEntry;
    use crate::LogQuery;
    use pretty_assertions::assert_eq;

    // `search` must match message substrings exactly ("alphab" hits
    // "alphabet" but not "alpha").
    #[tokio::test]
    async fn query_logs_with_search_matches_substring() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        runtime
            .insert_logs(&[
                LogEntry {
                    ts: 1_700_000_001,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some("alpha".to_string()),
                    thread_id: Some("thread-1".to_string()),
                    process_uuid: None,
                    file: Some("main.rs".to_string()),
                    line: Some(42),
                    module_path: None,
                },
                LogEntry {
                    ts: 1_700_000_002,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some("alphabet".to_string()),
                    thread_id: Some("thread-1".to_string()),
                    process_uuid: None,
                    file: Some("main.rs".to_string()),
                    line: Some(43),
                    module_path: None,
                },
            ])
            .await
            .expect("insert test logs");

        let rows = runtime
            .query_logs(&LogQuery {
                search: Some("alphab".to_string()),
                ..Default::default()
            })
            .await
            .expect("query matching logs");

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].message.as_deref(), Some("alphabet"));

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    // Two 6 MiB rows on one thread exceed the partition cap, so the older row
    // (ts = 1) is pruned and only the newest survives.
    #[tokio::test]
    async fn insert_logs_prunes_old_rows_when_thread_exceeds_size_limit() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let six_mebibytes = "a".repeat(6 * 1024 * 1024);
        runtime
            .insert_logs(&[
                LogEntry {
                    ts: 1,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some(six_mebibytes.clone()),
                    thread_id: Some("thread-1".to_string()),
                    process_uuid: Some("proc-1".to_string()),
                    file: Some("main.rs".to_string()),
                    line: Some(1),
                    module_path: Some("mod".to_string()),
                },
                LogEntry {
                    ts: 2,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some(six_mebibytes.clone()),
                    thread_id: Some("thread-1".to_string()),
                    process_uuid: Some("proc-1".to_string()),
                    file: Some("main.rs".to_string()),
                    line: Some(2),
                    module_path: Some("mod".to_string()),
                },
            ])
            .await
            .expect("insert test logs");

        let rows = runtime
            .query_logs(&LogQuery {
                thread_ids: vec!["thread-1".to_string()],
                ..Default::default()
            })
            .await
            .expect("query thread logs");

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].ts, 2);

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    // A single row larger than the cap is itself pruned — the strict cap
    // leaves the thread partition empty rather than oversized.
    #[tokio::test]
    async fn insert_logs_prunes_single_thread_row_when_it_exceeds_size_limit() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let eleven_mebibytes = "d".repeat(11 * 1024 * 1024);
        runtime
            .insert_logs(&[LogEntry {
                ts: 1,
                ts_nanos: 0,
                level: "INFO".to_string(),
                target: "cli".to_string(),
                message: Some(eleven_mebibytes),
                thread_id: Some("thread-oversized".to_string()),
                process_uuid: Some("proc-1".to_string()),
                file: Some("main.rs".to_string()),
                line: Some(1),
                module_path: Some("mod".to_string()),
            }])
            .await
            .expect("insert test log");

        let rows = runtime
            .query_logs(&LogQuery {
                thread_ids: vec!["thread-oversized".to_string()],
                ..Default::default()
            })
            .await
            .expect("query thread logs");

        assert!(rows.is_empty());

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    // Threadless rows are budgeted per process UUID, independently of the
    // thread partitions: the oldest threadless row (ts = 1) is pruned, while
    // the newer threadless row (ts = 2) and the thread-owned row (ts = 3)
    // both survive.
    #[tokio::test]
    async fn insert_logs_prunes_threadless_rows_per_process_uuid_only() {
        let codex_home = unique_temp_dir()
;
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let six_mebibytes = "b".repeat(6 * 1024 * 1024);
        runtime
            .insert_logs(&[
                LogEntry {
                    ts: 1,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some(six_mebibytes.clone()),
                    thread_id: None,
                    process_uuid: Some("proc-1".to_string()),
                    file: Some("main.rs".to_string()),
                    line: Some(1),
                    module_path: Some("mod".to_string()),
                },
                LogEntry {
                    ts: 2,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some(six_mebibytes.clone()),
                    thread_id: None,
                    process_uuid: Some("proc-1".to_string()),
                    file: Some("main.rs".to_string()),
                    line: Some(2),
                    module_path: Some("mod".to_string()),
                },
                LogEntry {
                    ts: 3,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some(six_mebibytes),
                    thread_id: Some("thread-1".to_string()),
                    process_uuid: Some("proc-1".to_string()),
                    file: Some("main.rs".to_string()),
                    line: Some(3),
                    module_path: Some("mod".to_string()),
                },
            ])
            .await
            .expect("insert test logs");

        let rows = runtime
            .query_logs(&LogQuery {
                thread_ids: vec!["thread-1".to_string()],
                include_threadless: true,
                ..Default::default()
            })
            .await
            .expect("query thread and threadless logs");

        let mut timestamps: Vec<i64> = rows.into_iter().map(|row| row.ts).collect();
        timestamps.sort_unstable();
        assert_eq!(timestamps, vec![2, 3]);

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    // Strict cap also applies to a single oversized threadless row keyed by
    // a process UUID.
    #[tokio::test]
    async fn insert_logs_prunes_single_threadless_process_row_when_it_exceeds_size_limit() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let eleven_mebibytes = "e".repeat(11 * 1024 * 1024);
        runtime
            .insert_logs(&[LogEntry {
                ts: 1,
                ts_nanos: 0,
                level: "INFO".to_string(),
                target: "cli".to_string(),
                message: Some(eleven_mebibytes),
                thread_id: None,
                process_uuid: Some("proc-oversized".to_string()),
                file: Some("main.rs".to_string()),
                line: Some(1),
                module_path: Some("mod".to_string()),
            }])
            .await
            .expect("insert test log");

        let rows = runtime
            .query_logs(&LogQuery {
                include_threadless: true,
                ..Default::default()
            })
            .await
            .expect("query threadless logs");

        assert!(rows.is_empty());

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    // Threadless rows with a NULL process UUID form their own partition:
    // the oldest NULL-process row is pruned, but the small row belonging to
    // "proc-1" is untouched.
    #[tokio::test]
    async fn insert_logs_prunes_threadless_rows_with_null_process_uuid() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let six_mebibytes = "c".repeat(6 * 1024 * 1024);
        runtime
            .insert_logs(&[
                LogEntry {
                    ts: 1,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some(six_mebibytes.clone()),
                    thread_id: None,
                    process_uuid: None,
                    file: Some("main.rs".to_string()),
                    line: Some(1),
                    module_path: Some("mod".to_string()),
                },
                LogEntry {
                    ts: 2,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some(six_mebibytes),
                    thread_id: None,
                    process_uuid: None,
                    file: Some("main.rs".to_string()),
                    line: Some(2),
                    module_path: Some("mod".to_string()),
                },
                LogEntry {
                    ts: 3,
                    ts_nanos: 0,
                    level: "INFO".to_string(),
                    target: "cli".to_string(),
                    message: Some("small".to_string()),
                    thread_id: None,
                    process_uuid: Some("proc-1".to_string()),
                    file: Some("main.rs".to_string()),
                    line: Some(3),
                    module_path: Some("mod".to_string()),
                },
            ])
            .await
            .expect("insert test logs");

        let rows = runtime
            .query_logs(&LogQuery {
                include_threadless: true,
                ..Default::default()
            })
            .await
            .expect("query threadless logs");

        let mut timestamps: Vec<i64> = rows.into_iter().map(|row| row.ts).collect();
        timestamps.sort_unstable();
        assert_eq!(timestamps, vec![2, 3]);

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }

    // Strict cap for the NULL-process threadless partition: one oversized row
    // is removed outright.
    #[tokio::test]
    async fn insert_logs_prunes_single_threadless_null_process_row_when_it_exceeds_limit() {
        let codex_home = unique_temp_dir();
        let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
            .await
            .expect("initialize runtime");

        let eleven_mebibytes = "f".repeat(11 * 1024 * 1024);
        runtime
            .insert_logs(&[LogEntry {
                ts: 1,
                ts_nanos: 0,
                level: "INFO".to_string(),
                target: "cli".to_string(),
                message: Some(eleven_mebibytes),
                thread_id: None,
                process_uuid: None,
                file: Some("main.rs".to_string()),
                line: Some(1),
                module_path: Some("mod".to_string()),
            }])
            .await
            .expect("insert test log");

        let rows = runtime
            .query_logs(&LogQuery {
                include_threadless: true,
                ..Default::default()
            })
            .await
            .expect("query threadless logs");

        assert!(rows.is_empty());

        let _ = tokio::fs::remove_dir_all(codex_home).await;
    }
}
|
||||
File diff suppressed because it is too large
Load diff
64
codex-rs/state/src/runtime/test_support.rs
Normal file
64
codex-rs/state/src/runtime/test_support.rs
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
#[cfg(test)]
|
||||
use chrono::DateTime;
|
||||
#[cfg(test)]
|
||||
use chrono::Utc;
|
||||
#[cfg(test)]
|
||||
use codex_protocol::ThreadId;
|
||||
#[cfg(test)]
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
#[cfg(test)]
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
#[cfg(test)]
|
||||
use std::path::Path;
|
||||
#[cfg(test)]
|
||||
use std::path::PathBuf;
|
||||
#[cfg(test)]
|
||||
use std::time::SystemTime;
|
||||
#[cfg(test)]
|
||||
use std::time::UNIX_EPOCH;
|
||||
#[cfg(test)]
|
||||
use uuid::Uuid;
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::ThreadMetadata;
|
||||
|
||||
#[cfg(test)]
/// Build a unique, test-scoped directory path under the system temp dir.
/// Uniqueness comes from a wall-clock nanosecond stamp plus a fresh UUID.
pub(super) fn unique_temp_dir() -> PathBuf {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos());
    let dir_name = format!("codex-state-runtime-test-{nanos}-{}", Uuid::new_v4());
    std::env::temp_dir().join(dir_name)
}
|
||||
|
||||
#[cfg(test)]
/// Construct a deterministic `ThreadMetadata` fixture rooted at `codex_home`.
pub(super) fn test_thread_metadata(
    codex_home: &Path,
    thread_id: ThreadId,
    cwd: PathBuf,
) -> ThreadMetadata {
    // Fixed timestamp keeps the fixture stable across runs.
    let timestamp = DateTime::<Utc>::from_timestamp(1_700_000_000, 0).expect("timestamp");
    let rollout_path = codex_home.join(format!("rollout-{thread_id}.jsonl"));
    ThreadMetadata {
        id: thread_id,
        rollout_path,
        created_at: timestamp,
        updated_at: timestamp,
        source: "cli".to_string(),
        agent_nickname: None,
        agent_role: None,
        model_provider: "test-provider".to_string(),
        cwd,
        cli_version: "0.0.0".to_string(),
        title: String::new(),
        sandbox_policy: crate::extract::enum_to_string(&SandboxPolicy::new_read_only_policy()),
        approval_mode: crate::extract::enum_to_string(&AskForApproval::OnRequest),
        tokens_used: 0,
        first_user_message: Some("hello".to_string()),
        archived_at: None,
        git_sha: None,
        git_branch: None,
        git_origin_url: None,
    }
}
|
||||
496
codex-rs/state/src/runtime/threads.rs
Normal file
496
codex-rs/state/src/runtime/threads.rs
Normal file
|
|
@ -0,0 +1,496 @@
|
|||
use super::*;
|
||||
|
||||
impl StateRuntime {
|
||||
    /// Fetch the metadata row for a single thread, if it exists.
    ///
    /// Returns `Ok(None)` when no row matches `id`; decode failures surface
    /// as errors.
    pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
        let row = sqlx::query(
            r#"
            SELECT
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                agent_nickname,
                agent_role,
                model_provider,
                cwd,
                cli_version,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                first_user_message,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            FROM threads
            WHERE id = ?
            "#,
        )
        .bind(id.to_string())
        .fetch_optional(self.pool.as_ref())
        .await?;
        // Decode the raw row into typed metadata; absence stays `None`.
        row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .transpose()
    }
|
||||
|
||||
/// Get dynamic tools for a thread, if present.
|
||||
pub async fn get_dynamic_tools(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
) -> anyhow::Result<Option<Vec<DynamicToolSpec>>> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT name, description, input_schema
|
||||
FROM thread_dynamic_tools
|
||||
WHERE thread_id = ?
|
||||
ORDER BY position ASC
|
||||
"#,
|
||||
)
|
||||
.bind(thread_id.to_string())
|
||||
.fetch_all(self.pool.as_ref())
|
||||
.await?;
|
||||
if rows.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
let mut tools = Vec::with_capacity(rows.len());
|
||||
for row in rows {
|
||||
let input_schema: String = row.try_get("input_schema")?;
|
||||
let input_schema = serde_json::from_str::<Value>(input_schema.as_str())?;
|
||||
tools.push(DynamicToolSpec {
|
||||
name: row.try_get("name")?,
|
||||
description: row.try_get("description")?,
|
||||
input_schema,
|
||||
});
|
||||
}
|
||||
Ok(Some(tools))
|
||||
}
|
||||
|
||||
/// Find a rollout path by thread id using the underlying database.
|
||||
pub async fn find_rollout_path_by_id(
|
||||
&self,
|
||||
id: ThreadId,
|
||||
archived_only: Option<bool>,
|
||||
) -> anyhow::Result<Option<PathBuf>> {
|
||||
let mut builder =
|
||||
QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
|
||||
builder.push_bind(id.to_string());
|
||||
match archived_only {
|
||||
Some(true) => {
|
||||
builder.push(" AND archived = 1");
|
||||
}
|
||||
Some(false) => {
|
||||
builder.push(" AND archived = 0");
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
|
||||
Ok(row
|
||||
.and_then(|r| r.try_get::<String, _>("rollout_path").ok())
|
||||
.map(PathBuf::from))
|
||||
}
|
||||
|
||||
    /// List threads using the underlying database.
    ///
    /// Pagination: one extra row beyond `page_size` is requested; if it comes
    /// back, another page exists and `next_anchor` is derived from the last
    /// item kept. `num_scanned_rows` counts the rows actually fetched
    /// (including the sentinel row, when present).
    #[allow(clippy::too_many_arguments)]
    pub async fn list_threads(
        &self,
        page_size: usize,
        anchor: Option<&crate::Anchor>,
        sort_key: crate::SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
        search_term: Option<&str>,
    ) -> anyhow::Result<crate::ThreadsPage> {
        // Sentinel row: fetch page_size + 1 to detect a following page.
        let limit = page_size.saturating_add(1);

        let mut builder = QueryBuilder::<Sqlite>::new(
            r#"
            SELECT
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                agent_nickname,
                agent_role,
                model_provider,
                cwd,
                cli_version,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                first_user_message,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            FROM threads
            "#,
        );
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            anchor,
            sort_key,
            search_term,
        );
        push_thread_order_and_limit(&mut builder, sort_key, limit);

        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        let mut items = rows
            .into_iter()
            .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .collect::<Result<Vec<_>, _>>()?;
        let num_scanned_rows = items.len();
        // Drop the sentinel row (if present) and compute the next anchor from
        // the last item that remains on this page.
        let next_anchor = if items.len() > page_size {
            items.pop();
            items
                .last()
                .and_then(|item| anchor_from_item(item, sort_key))
        } else {
            None
        };
        Ok(ThreadsPage {
            items,
            next_anchor,
            num_scanned_rows,
        })
    }
|
||||
|
||||
/// List thread ids using the underlying database (no rollout scanning).
|
||||
pub async fn list_thread_ids(
|
||||
&self,
|
||||
limit: usize,
|
||||
anchor: Option<&crate::Anchor>,
|
||||
sort_key: crate::SortKey,
|
||||
allowed_sources: &[String],
|
||||
model_providers: Option<&[String]>,
|
||||
archived_only: bool,
|
||||
) -> anyhow::Result<Vec<ThreadId>> {
|
||||
let mut builder = QueryBuilder::<Sqlite>::new("SELECT id FROM threads");
|
||||
push_thread_filters(
|
||||
&mut builder,
|
||||
archived_only,
|
||||
allowed_sources,
|
||||
model_providers,
|
||||
anchor,
|
||||
sort_key,
|
||||
None,
|
||||
);
|
||||
push_thread_order_and_limit(&mut builder, sort_key, limit);
|
||||
|
||||
let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
|
||||
rows.into_iter()
|
||||
.map(|row| {
|
||||
let id: String = row.try_get("id")?;
|
||||
Ok(ThreadId::try_from(id)?)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
    /// Insert or replace thread metadata directly.
    ///
    /// Uses an `ON CONFLICT(id) DO UPDATE` upsert so an existing row is
    /// overwritten column-for-column. The `archived` flag is derived from
    /// whether `archived_at` is set.
    pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
        sqlx::query(
            r#"
            INSERT INTO threads (
                id,
                rollout_path,
                created_at,
                updated_at,
                source,
                agent_nickname,
                agent_role,
                model_provider,
                cwd,
                cli_version,
                title,
                sandbox_policy,
                approval_mode,
                tokens_used,
                first_user_message,
                archived,
                archived_at,
                git_sha,
                git_branch,
                git_origin_url
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(id) DO UPDATE SET
                rollout_path = excluded.rollout_path,
                created_at = excluded.created_at,
                updated_at = excluded.updated_at,
                source = excluded.source,
                agent_nickname = excluded.agent_nickname,
                agent_role = excluded.agent_role,
                model_provider = excluded.model_provider,
                cwd = excluded.cwd,
                cli_version = excluded.cli_version,
                title = excluded.title,
                sandbox_policy = excluded.sandbox_policy,
                approval_mode = excluded.approval_mode,
                tokens_used = excluded.tokens_used,
                first_user_message = excluded.first_user_message,
                archived = excluded.archived,
                archived_at = excluded.archived_at,
                git_sha = excluded.git_sha,
                git_branch = excluded.git_branch,
                git_origin_url = excluded.git_origin_url
            "#,
        )
        .bind(metadata.id.to_string())
        .bind(metadata.rollout_path.display().to_string())
        .bind(datetime_to_epoch_seconds(metadata.created_at))
        .bind(datetime_to_epoch_seconds(metadata.updated_at))
        .bind(metadata.source.as_str())
        .bind(metadata.agent_nickname.as_deref())
        .bind(metadata.agent_role.as_deref())
        .bind(metadata.model_provider.as_str())
        .bind(metadata.cwd.display().to_string())
        .bind(metadata.cli_version.as_str())
        .bind(metadata.title.as_str())
        .bind(metadata.sandbox_policy.as_str())
        .bind(metadata.approval_mode.as_str())
        .bind(metadata.tokens_used)
        // A missing first user message is stored as the empty string.
        .bind(metadata.first_user_message.as_deref().unwrap_or_default())
        // `archived` mirrors the presence of `archived_at`.
        .bind(metadata.archived_at.is_some())
        .bind(metadata.archived_at.map(datetime_to_epoch_seconds))
        .bind(metadata.git_sha.as_deref())
        .bind(metadata.git_branch.as_deref())
        .bind(metadata.git_origin_url.as_deref())
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }
|
||||
|
||||
    /// Persist dynamic tools for a thread if none have been stored yet.
    ///
    /// Dynamic tools are defined at thread start and should not change afterward.
    /// This only writes the first time we see tools for a given thread.
    pub async fn persist_dynamic_tools(
        &self,
        thread_id: ThreadId,
        tools: Option<&[DynamicToolSpec]>,
    ) -> anyhow::Result<()> {
        // Nothing to store when tools are absent or the list is empty.
        let Some(tools) = tools else {
            return Ok(());
        };
        if tools.is_empty() {
            return Ok(());
        }
        let thread_id = thread_id.to_string();
        // All rows for one thread are written in a single transaction.
        let mut tx = self.pool.begin().await?;
        for (idx, tool) in tools.iter().enumerate() {
            // `position` preserves declaration order; saturate instead of
            // failing on an (unrealistic) index overflow.
            let position = i64::try_from(idx).unwrap_or(i64::MAX);
            let input_schema = serde_json::to_string(&tool.input_schema)?;
            // DO NOTHING on conflict implements the write-once policy above.
            sqlx::query(
                r#"
                INSERT INTO thread_dynamic_tools (
                    thread_id,
                    position,
                    name,
                    description,
                    input_schema
                ) VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(thread_id, position) DO NOTHING
                "#,
            )
            .bind(thread_id.as_str())
            .bind(position)
            .bind(tool.name.as_str())
            .bind(tool.description.as_str())
            .bind(input_schema)
            .execute(&mut *tx)
            .await?;
        }
        tx.commit().await?;
        Ok(())
    }
|
||||
|
||||
    /// Apply rollout items incrementally using the underlying database.
    ///
    /// Folds `items` into the thread's existing metadata (or a fresh record
    /// built from `builder` if none exists), then upserts the thread and
    /// persists any dynamic tools found in the items. Failures are counted on
    /// the optional `otel` manager, tagged by stage.
    pub async fn apply_rollout_items(
        &self,
        builder: &ThreadMetadataBuilder,
        items: &[RolloutItem],
        otel: Option<&OtelManager>,
    ) -> anyhow::Result<()> {
        if items.is_empty() {
            return Ok(());
        }
        let mut metadata = self
            .get_thread(builder.id)
            .await?
            .unwrap_or_else(|| builder.build(&self.default_provider));
        metadata.rollout_path = builder.rollout_path.clone();
        for item in items {
            apply_rollout_item(&mut metadata, item, &self.default_provider);
        }
        // Prefer the rollout file's mtime as the freshest updated_at signal.
        if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await {
            metadata.updated_at = updated_at;
        }
        // Keep the thread upsert before dynamic tools to satisfy the foreign key constraint:
        // thread_dynamic_tools.thread_id -> threads.id.
        if let Err(err) = self.upsert_thread(&metadata).await {
            if let Some(otel) = otel {
                otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
            }
            return Err(err);
        }
        let dynamic_tools = extract_dynamic_tools(items);
        if let Some(dynamic_tools) = dynamic_tools
            && let Err(err) = self
                .persist_dynamic_tools(builder.id, dynamic_tools.as_deref())
                .await
        {
            if let Some(otel) = otel {
                otel.counter(DB_ERROR_METRIC, 1, &[("stage", "persist_dynamic_tools")]);
            }
            return Err(err);
        }
        Ok(())
    }
|
||||
|
||||
/// Mark a thread as archived using the underlying database.
|
||||
pub async fn mark_archived(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
rollout_path: &Path,
|
||||
archived_at: DateTime<Utc>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(mut metadata) = self.get_thread(thread_id).await? else {
|
||||
return Ok(());
|
||||
};
|
||||
metadata.archived_at = Some(archived_at);
|
||||
metadata.rollout_path = rollout_path.to_path_buf();
|
||||
if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
|
||||
metadata.updated_at = updated_at;
|
||||
}
|
||||
if metadata.id != thread_id {
|
||||
warn!(
|
||||
"thread id mismatch during archive: expected {thread_id}, got {}",
|
||||
metadata.id
|
||||
);
|
||||
}
|
||||
self.upsert_thread(&metadata).await
|
||||
}
|
||||
|
||||
/// Mark a thread as unarchived using the underlying database.
|
||||
pub async fn mark_unarchived(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
rollout_path: &Path,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(mut metadata) = self.get_thread(thread_id).await? else {
|
||||
return Ok(());
|
||||
};
|
||||
metadata.archived_at = None;
|
||||
metadata.rollout_path = rollout_path.to_path_buf();
|
||||
if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
|
||||
metadata.updated_at = updated_at;
|
||||
}
|
||||
if metadata.id != thread_id {
|
||||
warn!(
|
||||
"thread id mismatch during unarchive: expected {thread_id}, got {}",
|
||||
metadata.id
|
||||
);
|
||||
}
|
||||
self.upsert_thread(&metadata).await
|
||||
}
|
||||
|
||||
/// Delete a thread metadata row by id.
|
||||
pub async fn delete_thread(&self, thread_id: ThreadId) -> anyhow::Result<u64> {
|
||||
let result = sqlx::query("DELETE FROM threads WHERE id = ?")
|
||||
.bind(thread_id.to_string())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn extract_dynamic_tools(items: &[RolloutItem]) -> Option<Option<Vec<DynamicToolSpec>>> {
|
||||
items.iter().find_map(|item| match item {
|
||||
RolloutItem::SessionMeta(meta_line) => Some(meta_line.meta.dynamic_tools.clone()),
|
||||
RolloutItem::ResponseItem(_)
|
||||
| RolloutItem::Compacted(_)
|
||||
| RolloutItem::TurnContext(_)
|
||||
| RolloutItem::EventMsg(_) => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn push_thread_filters<'a>(
|
||||
builder: &mut QueryBuilder<'a, Sqlite>,
|
||||
archived_only: bool,
|
||||
allowed_sources: &'a [String],
|
||||
model_providers: Option<&'a [String]>,
|
||||
anchor: Option<&crate::Anchor>,
|
||||
sort_key: SortKey,
|
||||
search_term: Option<&'a str>,
|
||||
) {
|
||||
builder.push(" WHERE 1 = 1");
|
||||
if archived_only {
|
||||
builder.push(" AND archived = 1");
|
||||
} else {
|
||||
builder.push(" AND archived = 0");
|
||||
}
|
||||
builder.push(" AND first_user_message <> ''");
|
||||
if !allowed_sources.is_empty() {
|
||||
builder.push(" AND source IN (");
|
||||
let mut separated = builder.separated(", ");
|
||||
for source in allowed_sources {
|
||||
separated.push_bind(source);
|
||||
}
|
||||
separated.push_unseparated(")");
|
||||
}
|
||||
if let Some(model_providers) = model_providers
|
||||
&& !model_providers.is_empty()
|
||||
{
|
||||
builder.push(" AND model_provider IN (");
|
||||
let mut separated = builder.separated(", ");
|
||||
for provider in model_providers {
|
||||
separated.push_bind(provider);
|
||||
}
|
||||
separated.push_unseparated(")");
|
||||
}
|
||||
if let Some(search_term) = search_term {
|
||||
builder.push(" AND instr(title, ");
|
||||
builder.push_bind(search_term);
|
||||
builder.push(") > 0");
|
||||
}
|
||||
if let Some(anchor) = anchor {
|
||||
let anchor_ts = datetime_to_epoch_seconds(anchor.ts);
|
||||
let column = match sort_key {
|
||||
SortKey::CreatedAt => "created_at",
|
||||
SortKey::UpdatedAt => "updated_at",
|
||||
};
|
||||
builder.push(" AND (");
|
||||
builder.push(column);
|
||||
builder.push(" < ");
|
||||
builder.push_bind(anchor_ts);
|
||||
builder.push(" OR (");
|
||||
builder.push(column);
|
||||
builder.push(" = ");
|
||||
builder.push_bind(anchor_ts);
|
||||
builder.push(" AND id < ");
|
||||
builder.push_bind(anchor.id.to_string());
|
||||
builder.push("))");
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn push_thread_order_and_limit(
|
||||
builder: &mut QueryBuilder<'_, Sqlite>,
|
||||
sort_key: SortKey,
|
||||
limit: usize,
|
||||
) {
|
||||
let order_column = match sort_key {
|
||||
SortKey::CreatedAt => "created_at",
|
||||
SortKey::UpdatedAt => "updated_at",
|
||||
};
|
||||
builder.push(" ORDER BY ");
|
||||
builder.push(order_column);
|
||||
builder.push(" DESC, id DESC");
|
||||
builder.push(" LIMIT ");
|
||||
builder.push_bind(limit as i64);
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue