feat: gen memories config (#12999)
This commit is contained in:
parent
a63d8bd569
commit
bbd237348d
6 changed files with 98 additions and 7 deletions
|
|
@ -613,6 +613,10 @@
|
|||
"additionalProperties": false,
|
||||
"description": "Memories settings loaded from config.toml.",
|
||||
"properties": {
|
||||
"generate_memories": {
|
||||
"description": "When `false`, newly created threads are stored with `memory_mode = \"disabled\"` in the state DB.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"max_raw_memories_for_global": {
|
||||
"description": "Maximum number of recent raw memories retained for global consolidation.",
|
||||
"format": "uint",
|
||||
|
|
|
|||
|
|
@ -2490,6 +2490,7 @@ persistence = "none"
|
|||
|
||||
let memories = r#"
|
||||
[memories]
|
||||
generate_memories = false
|
||||
use_memories = false
|
||||
max_raw_memories_for_global = 512
|
||||
max_unused_days = 21
|
||||
|
|
@ -2503,6 +2504,7 @@ phase_2_model = "gpt-5"
|
|||
toml::from_str::<ConfigToml>(memories).expect("TOML deserialization should succeed");
|
||||
assert_eq!(
|
||||
Some(MemoriesToml {
|
||||
generate_memories: Some(false),
|
||||
use_memories: Some(false),
|
||||
max_raw_memories_for_global: Some(512),
|
||||
max_unused_days: Some(21),
|
||||
|
|
@ -2524,6 +2526,7 @@ phase_2_model = "gpt-5"
|
|||
assert_eq!(
|
||||
config.memories,
|
||||
MemoriesConfig {
|
||||
generate_memories: false,
|
||||
use_memories: false,
|
||||
max_raw_memories_for_global: 512,
|
||||
max_unused_days: 21,
|
||||
|
|
|
|||
|
|
@ -371,6 +371,8 @@ pub struct FeedbackConfigToml {
|
|||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct MemoriesToml {
|
||||
/// When `false`, newly created threads are stored with `memory_mode = "disabled"` in the state DB.
|
||||
pub generate_memories: Option<bool>,
|
||||
/// When `false`, skip injecting memory usage instructions into developer prompts.
|
||||
pub use_memories: Option<bool>,
|
||||
/// Maximum number of recent raw memories retained for global consolidation.
|
||||
|
|
@ -392,6 +394,7 @@ pub struct MemoriesToml {
|
|||
/// Effective memories settings after defaults are applied.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct MemoriesConfig {
|
||||
pub generate_memories: bool,
|
||||
pub use_memories: bool,
|
||||
pub max_raw_memories_for_global: usize,
|
||||
pub max_unused_days: i64,
|
||||
|
|
@ -405,6 +408,7 @@ pub struct MemoriesConfig {
|
|||
impl Default for MemoriesConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
generate_memories: true,
|
||||
use_memories: true,
|
||||
max_raw_memories_for_global: DEFAULT_MEMORIES_MAX_RAW_MEMORIES_FOR_GLOBAL,
|
||||
max_unused_days: DEFAULT_MEMORIES_MAX_UNUSED_DAYS,
|
||||
|
|
@ -421,6 +425,7 @@ impl From<MemoriesToml> for MemoriesConfig {
|
|||
fn from(toml: MemoriesToml) -> Self {
|
||||
let defaults = Self::default();
|
||||
Self {
|
||||
generate_memories: toml.generate_memories.unwrap_or(defaults.generate_memories),
|
||||
use_memories: toml.use_memories.unwrap_or(defaults.use_memories),
|
||||
max_raw_memories_for_global: toml
|
||||
.max_raw_memories_for_global
|
||||
|
|
|
|||
|
|
@ -460,6 +460,7 @@ impl RolloutRecorder {
|
|||
state_db_ctx.clone(),
|
||||
state_builder,
|
||||
config.model_provider_id.clone(),
|
||||
config.memories.generate_memories,
|
||||
));
|
||||
|
||||
Ok(Self {
|
||||
|
|
@ -711,6 +712,7 @@ async fn rollout_writer(
|
|||
state_db_ctx: Option<StateDbHandle>,
|
||||
mut state_builder: Option<ThreadMetadataBuilder>,
|
||||
default_provider: String,
|
||||
generate_memories: bool,
|
||||
) -> std::io::Result<()> {
|
||||
let mut writer = file.map(|file| JsonlWriter { file });
|
||||
let mut buffered_items = Vec::<RolloutItem>::new();
|
||||
|
|
@ -731,6 +733,7 @@ async fn rollout_writer(
|
|||
state_db_ctx.as_deref(),
|
||||
&mut state_builder,
|
||||
default_provider.as_str(),
|
||||
generate_memories,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
|
@ -784,6 +787,7 @@ async fn rollout_writer(
|
|||
state_db_ctx.as_deref(),
|
||||
&mut state_builder,
|
||||
default_provider.as_str(),
|
||||
generate_memories,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
|
@ -831,6 +835,7 @@ async fn rollout_writer(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn write_session_meta(
|
||||
mut writer: Option<&mut JsonlWriter>,
|
||||
session_meta: SessionMeta,
|
||||
|
|
@ -839,6 +844,7 @@ async fn write_session_meta(
|
|||
state_db_ctx: Option<&StateRuntime>,
|
||||
state_builder: &mut Option<ThreadMetadataBuilder>,
|
||||
default_provider: &str,
|
||||
generate_memories: bool,
|
||||
) -> std::io::Result<()> {
|
||||
let git_info = collect_git_info(cwd).await;
|
||||
let session_meta_line = SessionMetaLine {
|
||||
|
|
@ -860,6 +866,7 @@ async fn write_session_meta(
|
|||
state_builder.as_ref(),
|
||||
std::slice::from_ref(&rollout_item),
|
||||
None,
|
||||
(!generate_memories).then_some("disabled"),
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
|
|
@ -888,6 +895,7 @@ async fn write_and_reconcile_items(
|
|||
state_builder.as_ref(),
|
||||
items,
|
||||
"rollout_writer",
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
|
|
|
|||
|
|
@ -345,6 +345,7 @@ pub async fn reconcile_rollout(
|
|||
builder: Option<&ThreadMetadataBuilder>,
|
||||
items: &[RolloutItem],
|
||||
archived_only: Option<bool>,
|
||||
new_thread_memory_mode: Option<&str>,
|
||||
) {
|
||||
let Some(ctx) = context else {
|
||||
return;
|
||||
|
|
@ -357,6 +358,7 @@ pub async fn reconcile_rollout(
|
|||
builder,
|
||||
items,
|
||||
"reconcile_rollout",
|
||||
new_thread_memory_mode,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
|
|
@ -467,6 +469,7 @@ pub async fn read_repair_rollout_path(
|
|||
None,
|
||||
&[],
|
||||
archived_only,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -479,6 +482,7 @@ pub async fn apply_rollout_items(
|
|||
builder: Option<&ThreadMetadataBuilder>,
|
||||
items: &[RolloutItem],
|
||||
stage: &str,
|
||||
new_thread_memory_mode: Option<&str>,
|
||||
) {
|
||||
let Some(ctx) = context else {
|
||||
return;
|
||||
|
|
@ -499,7 +503,10 @@ pub async fn apply_rollout_items(
|
|||
};
|
||||
builder.rollout_path = rollout_path.to_path_buf();
|
||||
builder.cwd = normalize_cwd_for_state_db(&builder.cwd);
|
||||
if let Err(err) = ctx.apply_rollout_items(&builder, items, None).await {
|
||||
if let Err(err) = ctx
|
||||
.apply_rollout_items(&builder, items, None, new_thread_memory_mode)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
"state db apply_rollout_items failed during {stage} for {}: {err}",
|
||||
rollout_path.display()
|
||||
|
|
|
|||
|
|
@ -195,6 +195,15 @@ FROM threads
|
|||
|
||||
/// Insert or replace thread metadata directly.
|
||||
pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
|
||||
self.upsert_thread_with_creation_memory_mode(metadata, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn upsert_thread_with_creation_memory_mode(
|
||||
&self,
|
||||
metadata: &crate::ThreadMetadata,
|
||||
creation_memory_mode: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO threads (
|
||||
|
|
@ -217,8 +226,9 @@ INSERT INTO threads (
|
|||
archived_at,
|
||||
git_sha,
|
||||
git_branch,
|
||||
git_origin_url
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
git_origin_url,
|
||||
memory_mode
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
rollout_path = excluded.rollout_path,
|
||||
created_at = excluded.created_at,
|
||||
|
|
@ -261,6 +271,7 @@ ON CONFLICT(id) DO UPDATE SET
|
|||
.bind(metadata.git_sha.as_deref())
|
||||
.bind(metadata.git_branch.as_deref())
|
||||
.bind(metadata.git_origin_url.as_deref())
|
||||
.bind(creation_memory_mode.unwrap_or("enabled"))
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
@ -316,13 +327,14 @@ ON CONFLICT(thread_id, position) DO NOTHING
|
|||
builder: &ThreadMetadataBuilder,
|
||||
items: &[RolloutItem],
|
||||
otel: Option<&OtelManager>,
|
||||
new_thread_memory_mode: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
if items.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let mut metadata = self
|
||||
.get_thread(builder.id)
|
||||
.await?
|
||||
let existing_metadata = self.get_thread(builder.id).await?;
|
||||
let mut metadata = existing_metadata
|
||||
.clone()
|
||||
.unwrap_or_else(|| builder.build(&self.default_provider));
|
||||
metadata.rollout_path = builder.rollout_path.clone();
|
||||
for item in items {
|
||||
|
|
@ -333,7 +345,13 @@ ON CONFLICT(thread_id, position) DO NOTHING
|
|||
}
|
||||
// Keep the thread upsert before dynamic tools to satisfy the foreign key constraint:
|
||||
// thread_dynamic_tools.thread_id -> threads.id.
|
||||
if let Err(err) = self.upsert_thread(&metadata).await {
|
||||
let upsert_result = if existing_metadata.is_none() {
|
||||
self.upsert_thread_with_creation_memory_mode(&metadata, new_thread_memory_mode)
|
||||
.await
|
||||
} else {
|
||||
self.upsert_thread(&metadata).await
|
||||
};
|
||||
if let Err(err) = upsert_result {
|
||||
if let Some(otel) = otel {
|
||||
otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
|
||||
}
|
||||
|
|
@ -494,3 +512,49 @@ pub(super) fn push_thread_order_and_limit(
|
|||
builder.push(" LIMIT ");
|
||||
builder.push_bind(limit as i64);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::runtime::test_support::test_thread_metadata;
|
||||
use crate::runtime::test_support::unique_temp_dir;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_thread_keeps_creation_memory_mode_for_existing_rows() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let thread_id =
|
||||
ThreadId::from_string("00000000-0000-0000-0000-000000000123").expect("valid thread id");
|
||||
let mut metadata = test_thread_metadata(&codex_home, thread_id, codex_home.clone());
|
||||
|
||||
runtime
|
||||
.upsert_thread_with_creation_memory_mode(&metadata, Some("disabled"))
|
||||
.await
|
||||
.expect("initial insert should succeed");
|
||||
|
||||
let memory_mode: String =
|
||||
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
|
||||
.bind(thread_id.to_string())
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("memory mode should be readable");
|
||||
assert_eq!(memory_mode, "disabled");
|
||||
|
||||
metadata.title = "updated title".to_string();
|
||||
runtime
|
||||
.upsert_thread(&metadata)
|
||||
.await
|
||||
.expect("upsert should succeed");
|
||||
|
||||
let memory_mode: String =
|
||||
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
|
||||
.bind(thread_id.to_string())
|
||||
.fetch_one(runtime.pool.as_ref())
|
||||
.await
|
||||
.expect("memory mode should remain readable");
|
||||
assert_eq!(memory_mode, "disabled");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue