feat: add memory tool (#10637)

Add a tool for memory to retrieve a full memory based on the memory ID
This commit is contained in:
jif-oai 2026-02-05 16:16:31 +00:00 committed by GitHub
parent fe1cbd0f38
commit 41f3b1ba0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 232 additions and 0 deletions

View file

@ -190,6 +190,9 @@
"include_apply_patch_tool": {
"type": "boolean"
},
"memory_tool": {
"type": "boolean"
},
"personality": {
"type": "boolean"
},
@ -1212,6 +1215,9 @@
"include_apply_patch_tool": {
"type": "boolean"
},
"memory_tool": {
"type": "boolean"
},
"personality": {
"type": "boolean"
},

View file

@ -107,6 +107,8 @@ pub enum Feature {
RuntimeMetrics,
/// Persist rollout metadata to a local SQLite database.
Sqlite,
/// Enable the get_memory tool backed by SQLite thread memories.
MemoryTool,
/// Append additional AGENTS.md guidance to user instructions.
ChildAgentsMd,
/// Enforce UTF8 output in Powershell.
@ -449,6 +451,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::MemoryTool,
key: "memory_tool",
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::ChildAgentsMd,
key: "child_agents_md",

View file

@ -0,0 +1,72 @@
use crate::function_tool::FunctionCallError;
use crate::state_db;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handlers::parse_arguments;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use async_trait::async_trait;
use codex_protocol::ThreadId;
use codex_protocol::models::FunctionCallOutputBody;
use serde::Deserialize;
use serde_json::json;
pub struct GetMemoryHandler;
#[derive(Deserialize)]
struct GetMemoryArgs {
memory_id: String,
}
#[async_trait]
impl ToolHandler for GetMemoryHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session, payload, ..
} = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"get_memory handler received unsupported payload".to_string(),
));
}
};
let args: GetMemoryArgs = parse_arguments(&arguments)?;
let thread_id = ThreadId::from_string(args.memory_id.as_str()).map_err(|err| {
FunctionCallError::RespondToModel(format!("memory_id must be a valid thread id: {err}"))
})?;
let state_db_ctx = session.state_db();
let memory =
state_db::get_thread_memory(state_db_ctx.as_deref(), thread_id, "get_memory_tool")
.await
.ok_or_else(|| {
FunctionCallError::RespondToModel(format!(
"memory not found for memory_id={}",
args.memory_id
))
})?;
let content = serde_json::to_string_pretty(&json!({
"memory_id": args.memory_id,
"trace_summary": memory.trace_summary,
"memory_summary": memory.memory_summary,
}))
.map_err(|err| {
FunctionCallError::Fatal(format!("failed to serialize memory payload: {err}"))
})?;
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
})
}
}

View file

@ -1,6 +1,7 @@
pub mod apply_patch;
pub(crate) mod collab;
mod dynamic;
mod get_memory;
mod grep_files;
mod list_dir;
mod mcp;
@ -20,6 +21,7 @@ use crate::function_tool::FunctionCallError;
pub use apply_patch::ApplyPatchHandler;
pub use collab::CollabHandler;
pub use dynamic::DynamicToolHandler;
pub use get_memory::GetMemoryHandler;
pub use grep_files::GrepFilesHandler;
pub use list_dir::ListDirHandler;
pub use mcp::McpHandler;

View file

@ -31,6 +31,7 @@ pub(crate) struct ToolsConfig {
pub web_search_mode: Option<WebSearchMode>,
pub collab_tools: bool,
pub collaboration_modes_tools: bool,
pub memory_tools: bool,
pub request_rule_enabled: bool,
pub experimental_supported_tools: Vec<String>,
}
@ -51,6 +52,7 @@ impl ToolsConfig {
let include_apply_patch_tool = features.enabled(Feature::ApplyPatchFreeform);
let include_collab_tools = features.enabled(Feature::Collab);
let include_collaboration_modes_tools = features.enabled(Feature::CollaborationModes);
let include_memory_tools = features.enabled(Feature::MemoryTool);
let request_rule_enabled = features.enabled(Feature::RequestRule);
let shell_type = if !features.enabled(Feature::ShellTool) {
@ -84,6 +86,7 @@ impl ToolsConfig {
web_search_mode: *web_search_mode,
collab_tools: include_collab_tools,
collaboration_modes_tools: include_collaboration_modes_tools,
memory_tools: include_memory_tools,
request_rule_enabled,
experimental_supported_tools: model_info.experimental_supported_tools.clone(),
}
@ -634,6 +637,28 @@ fn create_request_user_input_tool() -> ToolSpec {
})
}
fn create_get_memory_tool() -> ToolSpec {
let properties = BTreeMap::from([(
"memory_id".to_string(),
JsonSchema::String {
description: Some(
"Memory ID to fetch. Uses the thread ID as the memory identifier.".to_string(),
),
},
)]);
ToolSpec::Function(ResponsesApiTool {
name: "get_memory".to_string(),
description: "Loads the full stored memory payload for a memory_id.".to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: Some(vec!["memory_id".to_string()]),
additional_properties: Some(false.into()),
},
})
}
fn create_close_agent_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
@ -1228,6 +1253,7 @@ pub(crate) fn build_specs(
use crate::tools::handlers::ApplyPatchHandler;
use crate::tools::handlers::CollabHandler;
use crate::tools::handlers::DynamicToolHandler;
use crate::tools::handlers::GetMemoryHandler;
use crate::tools::handlers::GrepFilesHandler;
use crate::tools::handlers::ListDirHandler;
use crate::tools::handlers::McpHandler;
@ -1249,6 +1275,7 @@ pub(crate) fn build_specs(
let plan_handler = Arc::new(PlanHandler);
let apply_patch_handler = Arc::new(ApplyPatchHandler);
let dynamic_tool_handler = Arc::new(DynamicToolHandler);
let get_memory_handler = Arc::new(GetMemoryHandler);
let view_image_handler = Arc::new(ViewImageHandler);
let mcp_handler = Arc::new(McpHandler);
let mcp_resource_handler = Arc::new(McpResourceHandler);
@ -1308,6 +1335,11 @@ pub(crate) fn build_specs(
builder.register_handler("request_user_input", request_user_input_handler);
}
if config.memory_tools {
builder.push_spec(create_get_memory_tool());
builder.register_handler("get_memory", get_memory_handler);
}
if let Some(apply_patch_tool_type) = &config.apply_patch_tool_type {
match apply_patch_tool_type {
ApplyPatchToolType::Freeform => {
@ -1669,6 +1701,33 @@ mod tests {
assert_contains_tool_names(&tools, &["request_user_input"]);
}
#[test]
fn get_memory_requires_memory_tool_feature() {
let config = test_config();
let model_info = ModelsManager::construct_model_info_offline("gpt-5-codex", &config);
let mut features = Features::with_defaults();
features.disable(Feature::MemoryTool);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_info: &model_info,
features: &features,
web_search_mode: Some(WebSearchMode::Cached),
});
let (tools, _) = build_specs(&tools_config, None, &[]).build();
assert!(
!tools.iter().any(|t| t.spec.name() == "get_memory"),
"get_memory should be disabled when memory_tool feature is off"
);
features.enable(Feature::MemoryTool);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_info: &model_info,
features: &features,
web_search_mode: Some(WebSearchMode::Cached),
});
let (tools, _) = build_specs(&tools_config, None, &[]).build();
assert_contains_tool_names(&tools, &["get_memory"]);
}
fn assert_model_tools(
model_slug: &str,
features: &Features,

View file

@ -0,0 +1,84 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use anyhow::Result;
use codex_core::features::Feature;
use core_test_support::responses::mount_function_call_agent_response;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
use pretty_assertions::assert_eq;
use serde_json::Value;
use serde_json::json;
use tokio::time::Duration;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn get_memory_tool_returns_persisted_thread_memory() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let mut builder = test_codex().with_config(|config| {
config.features.enable(Feature::Sqlite);
config.features.enable(Feature::MemoryTool);
});
let test = builder.build(&server).await?;
let db = test.codex.state_db().expect("state db enabled");
let thread_id = test.session_configured.session_id;
let thread_id_string = thread_id.to_string();
let mut thread_exists = false;
// Wait for DB creation.
for _ in 0..100 {
if db.get_thread(thread_id).await?.is_some() {
thread_exists = true;
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
assert!(thread_exists, "thread should exist in state db");
let trace_summary = "trace summary from sqlite";
let memory_summary = "memory summary from sqlite";
db.upsert_thread_memory(thread_id, trace_summary, memory_summary)
.await?;
let call_id = "memory-call-1";
let arguments = json!({
"memory_id": thread_id_string,
})
.to_string();
let mocks =
mount_function_call_agent_response(&server, call_id, &arguments, "get_memory").await;
test.submit_turn("load the saved memory").await?;
let initial_request = mocks.function_call.single_request().body_json();
assert!(
initial_request["tools"]
.as_array()
.expect("tools array")
.iter()
.filter_map(|tool| tool.get("name").and_then(Value::as_str))
.any(|name| name == "get_memory"),
"get_memory tool should be exposed when memory_tool feature is enabled"
);
let completion_request = mocks.completion.single_request();
let (content_opt, success_opt) = completion_request
.function_call_output_content_and_success(call_id)
.expect("function_call_output should be present");
let success = success_opt.unwrap_or(true);
assert!(success, "expected successful get_memory tool call output");
let content = content_opt.expect("function_call_output content should be present");
let payload: Value = serde_json::from_str(&content)?;
assert_eq!(
payload,
json!({
"memory_id": thread_id_string,
"trace_summary": trace_summary,
"memory_summary": memory_summary,
})
);
Ok(())
}

View file

@ -82,6 +82,7 @@ mod list_dir;
mod list_models;
mod live_cli;
mod live_reload;
mod memory_tool;
mod model_info_overrides;
mod model_overrides;
mod model_switching;