feat: gen memories config (#12999)

This commit is contained in:
jif-oai 2026-02-27 12:38:47 +01:00 committed by GitHub
parent a63d8bd569
commit bbd237348d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 98 additions and 7 deletions

View file

@ -613,6 +613,10 @@
"additionalProperties": false,
"description": "Memories settings loaded from config.toml.",
"properties": {
"generate_memories": {
"description": "When `false`, newly created threads are stored with `memory_mode = \"disabled\"` in the state DB.",
"type": "boolean"
},
"max_raw_memories_for_global": {
"description": "Maximum number of recent raw memories retained for global consolidation.",
"format": "uint",

View file

@ -2490,6 +2490,7 @@ persistence = "none"
let memories = r#"
[memories]
generate_memories = false
use_memories = false
max_raw_memories_for_global = 512
max_unused_days = 21
@ -2503,6 +2504,7 @@ phase_2_model = "gpt-5"
toml::from_str::<ConfigToml>(memories).expect("TOML deserialization should succeed");
assert_eq!(
Some(MemoriesToml {
generate_memories: Some(false),
use_memories: Some(false),
max_raw_memories_for_global: Some(512),
max_unused_days: Some(21),
@ -2524,6 +2526,7 @@ phase_2_model = "gpt-5"
assert_eq!(
config.memories,
MemoriesConfig {
generate_memories: false,
use_memories: false,
max_raw_memories_for_global: 512,
max_unused_days: 21,

View file

@ -371,6 +371,8 @@ pub struct FeedbackConfigToml {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct MemoriesToml {
/// When `false`, newly created threads are stored with `memory_mode = "disabled"` in the state DB.
pub generate_memories: Option<bool>,
/// When `false`, skip injecting memory usage instructions into developer prompts.
pub use_memories: Option<bool>,
/// Maximum number of recent raw memories retained for global consolidation.
@ -392,6 +394,7 @@ pub struct MemoriesToml {
/// Effective memories settings after defaults are applied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoriesConfig {
pub generate_memories: bool,
pub use_memories: bool,
pub max_raw_memories_for_global: usize,
pub max_unused_days: i64,
@ -405,6 +408,7 @@ pub struct MemoriesConfig {
impl Default for MemoriesConfig {
fn default() -> Self {
Self {
generate_memories: true,
use_memories: true,
max_raw_memories_for_global: DEFAULT_MEMORIES_MAX_RAW_MEMORIES_FOR_GLOBAL,
max_unused_days: DEFAULT_MEMORIES_MAX_UNUSED_DAYS,
@ -421,6 +425,7 @@ impl From<MemoriesToml> for MemoriesConfig {
fn from(toml: MemoriesToml) -> Self {
let defaults = Self::default();
Self {
generate_memories: toml.generate_memories.unwrap_or(defaults.generate_memories),
use_memories: toml.use_memories.unwrap_or(defaults.use_memories),
max_raw_memories_for_global: toml
.max_raw_memories_for_global

View file

@ -460,6 +460,7 @@ impl RolloutRecorder {
state_db_ctx.clone(),
state_builder,
config.model_provider_id.clone(),
config.memories.generate_memories,
));
Ok(Self {
@ -711,6 +712,7 @@ async fn rollout_writer(
state_db_ctx: Option<StateDbHandle>,
mut state_builder: Option<ThreadMetadataBuilder>,
default_provider: String,
generate_memories: bool,
) -> std::io::Result<()> {
let mut writer = file.map(|file| JsonlWriter { file });
let mut buffered_items = Vec::<RolloutItem>::new();
@ -731,6 +733,7 @@ async fn rollout_writer(
state_db_ctx.as_deref(),
&mut state_builder,
default_provider.as_str(),
generate_memories,
)
.await?;
}
@ -784,6 +787,7 @@ async fn rollout_writer(
state_db_ctx.as_deref(),
&mut state_builder,
default_provider.as_str(),
generate_memories,
)
.await?;
}
@ -831,6 +835,7 @@ async fn rollout_writer(
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn write_session_meta(
mut writer: Option<&mut JsonlWriter>,
session_meta: SessionMeta,
@ -839,6 +844,7 @@ async fn write_session_meta(
state_db_ctx: Option<&StateRuntime>,
state_builder: &mut Option<ThreadMetadataBuilder>,
default_provider: &str,
generate_memories: bool,
) -> std::io::Result<()> {
let git_info = collect_git_info(cwd).await;
let session_meta_line = SessionMetaLine {
@ -860,6 +866,7 @@ async fn write_session_meta(
state_builder.as_ref(),
std::slice::from_ref(&rollout_item),
None,
(!generate_memories).then_some("disabled"),
)
.await;
Ok(())
@ -888,6 +895,7 @@ async fn write_and_reconcile_items(
state_builder.as_ref(),
items,
"rollout_writer",
None,
)
.await;
Ok(())

View file

@ -345,6 +345,7 @@ pub async fn reconcile_rollout(
builder: Option<&ThreadMetadataBuilder>,
items: &[RolloutItem],
archived_only: Option<bool>,
new_thread_memory_mode: Option<&str>,
) {
let Some(ctx) = context else {
return;
@ -357,6 +358,7 @@ pub async fn reconcile_rollout(
builder,
items,
"reconcile_rollout",
new_thread_memory_mode,
)
.await;
return;
@ -467,6 +469,7 @@ pub async fn read_repair_rollout_path(
None,
&[],
archived_only,
None,
)
.await;
}
@ -479,6 +482,7 @@ pub async fn apply_rollout_items(
builder: Option<&ThreadMetadataBuilder>,
items: &[RolloutItem],
stage: &str,
new_thread_memory_mode: Option<&str>,
) {
let Some(ctx) = context else {
return;
@ -499,7 +503,10 @@ pub async fn apply_rollout_items(
};
builder.rollout_path = rollout_path.to_path_buf();
builder.cwd = normalize_cwd_for_state_db(&builder.cwd);
if let Err(err) = ctx.apply_rollout_items(&builder, items, None).await {
if let Err(err) = ctx
.apply_rollout_items(&builder, items, None, new_thread_memory_mode)
.await
{
warn!(
"state db apply_rollout_items failed during {stage} for {}: {err}",
rollout_path.display()

View file

@ -195,6 +195,15 @@ FROM threads
/// Insert or replace thread metadata directly.
pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
self.upsert_thread_with_creation_memory_mode(metadata, None)
.await
}
async fn upsert_thread_with_creation_memory_mode(
&self,
metadata: &crate::ThreadMetadata,
creation_memory_mode: Option<&str>,
) -> anyhow::Result<()> {
sqlx::query(
r#"
INSERT INTO threads (
@ -217,8 +226,9 @@ INSERT INTO threads (
archived_at,
git_sha,
git_branch,
git_origin_url
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
git_origin_url,
memory_mode
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
rollout_path = excluded.rollout_path,
created_at = excluded.created_at,
@ -261,6 +271,7 @@ ON CONFLICT(id) DO UPDATE SET
.bind(metadata.git_sha.as_deref())
.bind(metadata.git_branch.as_deref())
.bind(metadata.git_origin_url.as_deref())
.bind(creation_memory_mode.unwrap_or("enabled"))
.execute(self.pool.as_ref())
.await?;
Ok(())
@ -316,13 +327,14 @@ ON CONFLICT(thread_id, position) DO NOTHING
builder: &ThreadMetadataBuilder,
items: &[RolloutItem],
otel: Option<&OtelManager>,
new_thread_memory_mode: Option<&str>,
) -> anyhow::Result<()> {
if items.is_empty() {
return Ok(());
}
let mut metadata = self
.get_thread(builder.id)
.await?
let existing_metadata = self.get_thread(builder.id).await?;
let mut metadata = existing_metadata
.clone()
.unwrap_or_else(|| builder.build(&self.default_provider));
metadata.rollout_path = builder.rollout_path.clone();
for item in items {
@ -333,7 +345,13 @@ ON CONFLICT(thread_id, position) DO NOTHING
}
// Keep the thread upsert before dynamic tools to satisfy the foreign key constraint:
// thread_dynamic_tools.thread_id -> threads.id.
if let Err(err) = self.upsert_thread(&metadata).await {
let upsert_result = if existing_metadata.is_none() {
self.upsert_thread_with_creation_memory_mode(&metadata, new_thread_memory_mode)
.await
} else {
self.upsert_thread(&metadata).await
};
if let Err(err) = upsert_result {
if let Some(otel) = otel {
otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
}
@ -494,3 +512,49 @@ pub(super) fn push_thread_order_and_limit(
builder.push(" LIMIT ");
builder.push_bind(limit as i64);
}
#[cfg(test)]
mod tests {
use super::*;
use crate::runtime::test_support::test_thread_metadata;
use crate::runtime::test_support::unique_temp_dir;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn upsert_thread_keeps_creation_memory_mode_for_existing_rows() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("state db should initialize");
let thread_id =
ThreadId::from_string("00000000-0000-0000-0000-000000000123").expect("valid thread id");
let mut metadata = test_thread_metadata(&codex_home, thread_id, codex_home.clone());
runtime
.upsert_thread_with_creation_memory_mode(&metadata, Some("disabled"))
.await
.expect("initial insert should succeed");
let memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("memory mode should be readable");
assert_eq!(memory_mode, "disabled");
metadata.title = "updated title".to_string();
runtime
.upsert_thread(&metadata)
.await
.expect("upsert should succeed");
let memory_mode: String =
sqlx::query_scalar("SELECT memory_mode FROM threads WHERE id = ?")
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("memory mode should remain readable");
assert_eq!(memory_mode, "disabled");
}
}