feat: add memory tool (#10637)
Add a tool for memory to retrieve a full memory based on the memory ID
This commit is contained in:
parent
fe1cbd0f38
commit
41f3b1ba0b
7 changed files with 232 additions and 0 deletions
|
|
@ -190,6 +190,9 @@
|
|||
"include_apply_patch_tool": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"memory_tool": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"personality": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
|
@ -1212,6 +1215,9 @@
|
|||
"include_apply_patch_tool": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"memory_tool": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"personality": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -107,6 +107,8 @@ pub enum Feature {
|
|||
RuntimeMetrics,
|
||||
/// Persist rollout metadata to a local SQLite database.
|
||||
Sqlite,
|
||||
/// Enable the get_memory tool backed by SQLite thread memories.
|
||||
MemoryTool,
|
||||
/// Append additional AGENTS.md guidance to user instructions.
|
||||
ChildAgentsMd,
|
||||
/// Enforce UTF8 output in Powershell.
|
||||
|
|
@ -449,6 +451,12 @@ pub const FEATURES: &[FeatureSpec] = &[
|
|||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::MemoryTool,
|
||||
key: "memory_tool",
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::ChildAgentsMd,
|
||||
key: "child_agents_md",
|
||||
|
|
|
|||
72
codex-rs/core/src/tools/handlers/get_memory.rs
Normal file
72
codex-rs/core/src/tools/handlers/get_memory.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use crate::function_tool::FunctionCallError;
|
||||
use crate::state_db;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::handlers::parse_arguments;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
use async_trait::async_trait;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
pub struct GetMemoryHandler;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GetMemoryArgs {
|
||||
memory_id: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for GetMemoryHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session, payload, ..
|
||||
} = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"get_memory handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: GetMemoryArgs = parse_arguments(&arguments)?;
|
||||
let thread_id = ThreadId::from_string(args.memory_id.as_str()).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!("memory_id must be a valid thread id: {err}"))
|
||||
})?;
|
||||
|
||||
let state_db_ctx = session.state_db();
|
||||
let memory =
|
||||
state_db::get_thread_memory(state_db_ctx.as_deref(), thread_id, "get_memory_tool")
|
||||
.await
|
||||
.ok_or_else(|| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"memory not found for memory_id={}",
|
||||
args.memory_id
|
||||
))
|
||||
})?;
|
||||
|
||||
let content = serde_json::to_string_pretty(&json!({
|
||||
"memory_id": args.memory_id,
|
||||
"trace_summary": memory.trace_summary,
|
||||
"memory_summary": memory.memory_summary,
|
||||
}))
|
||||
.map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("failed to serialize memory payload: {err}"))
|
||||
})?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
body: FunctionCallOutputBody::Text(content),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod apply_patch;
|
||||
pub(crate) mod collab;
|
||||
mod dynamic;
|
||||
mod get_memory;
|
||||
mod grep_files;
|
||||
mod list_dir;
|
||||
mod mcp;
|
||||
|
|
@ -20,6 +21,7 @@ use crate::function_tool::FunctionCallError;
|
|||
pub use apply_patch::ApplyPatchHandler;
|
||||
pub use collab::CollabHandler;
|
||||
pub use dynamic::DynamicToolHandler;
|
||||
pub use get_memory::GetMemoryHandler;
|
||||
pub use grep_files::GrepFilesHandler;
|
||||
pub use list_dir::ListDirHandler;
|
||||
pub use mcp::McpHandler;
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ pub(crate) struct ToolsConfig {
|
|||
pub web_search_mode: Option<WebSearchMode>,
|
||||
pub collab_tools: bool,
|
||||
pub collaboration_modes_tools: bool,
|
||||
pub memory_tools: bool,
|
||||
pub request_rule_enabled: bool,
|
||||
pub experimental_supported_tools: Vec<String>,
|
||||
}
|
||||
|
|
@ -51,6 +52,7 @@ impl ToolsConfig {
|
|||
let include_apply_patch_tool = features.enabled(Feature::ApplyPatchFreeform);
|
||||
let include_collab_tools = features.enabled(Feature::Collab);
|
||||
let include_collaboration_modes_tools = features.enabled(Feature::CollaborationModes);
|
||||
let include_memory_tools = features.enabled(Feature::MemoryTool);
|
||||
let request_rule_enabled = features.enabled(Feature::RequestRule);
|
||||
|
||||
let shell_type = if !features.enabled(Feature::ShellTool) {
|
||||
|
|
@ -84,6 +86,7 @@ impl ToolsConfig {
|
|||
web_search_mode: *web_search_mode,
|
||||
collab_tools: include_collab_tools,
|
||||
collaboration_modes_tools: include_collaboration_modes_tools,
|
||||
memory_tools: include_memory_tools,
|
||||
request_rule_enabled,
|
||||
experimental_supported_tools: model_info.experimental_supported_tools.clone(),
|
||||
}
|
||||
|
|
@ -634,6 +637,28 @@ fn create_request_user_input_tool() -> ToolSpec {
|
|||
})
|
||||
}
|
||||
|
||||
fn create_get_memory_tool() -> ToolSpec {
|
||||
let properties = BTreeMap::from([(
|
||||
"memory_id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Memory ID to fetch. Uses the thread ID as the memory identifier.".to_string(),
|
||||
),
|
||||
},
|
||||
)]);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "get_memory".to_string(),
|
||||
description: "Loads the full stored memory payload for a memory_id.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: Some(vec!["memory_id".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_close_agent_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
|
|
@ -1228,6 +1253,7 @@ pub(crate) fn build_specs(
|
|||
use crate::tools::handlers::ApplyPatchHandler;
|
||||
use crate::tools::handlers::CollabHandler;
|
||||
use crate::tools::handlers::DynamicToolHandler;
|
||||
use crate::tools::handlers::GetMemoryHandler;
|
||||
use crate::tools::handlers::GrepFilesHandler;
|
||||
use crate::tools::handlers::ListDirHandler;
|
||||
use crate::tools::handlers::McpHandler;
|
||||
|
|
@ -1249,6 +1275,7 @@ pub(crate) fn build_specs(
|
|||
let plan_handler = Arc::new(PlanHandler);
|
||||
let apply_patch_handler = Arc::new(ApplyPatchHandler);
|
||||
let dynamic_tool_handler = Arc::new(DynamicToolHandler);
|
||||
let get_memory_handler = Arc::new(GetMemoryHandler);
|
||||
let view_image_handler = Arc::new(ViewImageHandler);
|
||||
let mcp_handler = Arc::new(McpHandler);
|
||||
let mcp_resource_handler = Arc::new(McpResourceHandler);
|
||||
|
|
@ -1308,6 +1335,11 @@ pub(crate) fn build_specs(
|
|||
builder.register_handler("request_user_input", request_user_input_handler);
|
||||
}
|
||||
|
||||
if config.memory_tools {
|
||||
builder.push_spec(create_get_memory_tool());
|
||||
builder.register_handler("get_memory", get_memory_handler);
|
||||
}
|
||||
|
||||
if let Some(apply_patch_tool_type) = &config.apply_patch_tool_type {
|
||||
match apply_patch_tool_type {
|
||||
ApplyPatchToolType::Freeform => {
|
||||
|
|
@ -1669,6 +1701,33 @@ mod tests {
|
|||
assert_contains_tool_names(&tools, &["request_user_input"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_memory_requires_memory_tool_feature() {
|
||||
let config = test_config();
|
||||
let model_info = ModelsManager::construct_model_info_offline("gpt-5-codex", &config);
|
||||
let mut features = Features::with_defaults();
|
||||
features.disable(Feature::MemoryTool);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_info: &model_info,
|
||||
features: &features,
|
||||
web_search_mode: Some(WebSearchMode::Cached),
|
||||
});
|
||||
let (tools, _) = build_specs(&tools_config, None, &[]).build();
|
||||
assert!(
|
||||
!tools.iter().any(|t| t.spec.name() == "get_memory"),
|
||||
"get_memory should be disabled when memory_tool feature is off"
|
||||
);
|
||||
|
||||
features.enable(Feature::MemoryTool);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_info: &model_info,
|
||||
features: &features,
|
||||
web_search_mode: Some(WebSearchMode::Cached),
|
||||
});
|
||||
let (tools, _) = build_specs(&tools_config, None, &[]).build();
|
||||
assert_contains_tool_names(&tools, &["get_memory"]);
|
||||
}
|
||||
|
||||
fn assert_model_tools(
|
||||
model_slug: &str,
|
||||
features: &Features,
|
||||
|
|
|
|||
84
codex-rs/core/tests/suite/memory_tool.rs
Normal file
84
codex-rs/core/tests/suite/memory_tool.rs
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use anyhow::Result;
|
||||
use codex_core::features::Feature;
|
||||
use core_test_support::responses::mount_function_call_agent_response;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use tokio::time::Duration;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn get_memory_tool_returns_persisted_thread_memory() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::Sqlite);
|
||||
config.features.enable(Feature::MemoryTool);
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let db = test.codex.state_db().expect("state db enabled");
|
||||
let thread_id = test.session_configured.session_id;
|
||||
let thread_id_string = thread_id.to_string();
|
||||
|
||||
let mut thread_exists = false;
|
||||
// Wait for DB creation.
|
||||
for _ in 0..100 {
|
||||
if db.get_thread(thread_id).await?.is_some() {
|
||||
thread_exists = true;
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
assert!(thread_exists, "thread should exist in state db");
|
||||
|
||||
let trace_summary = "trace summary from sqlite";
|
||||
let memory_summary = "memory summary from sqlite";
|
||||
db.upsert_thread_memory(thread_id, trace_summary, memory_summary)
|
||||
.await?;
|
||||
|
||||
let call_id = "memory-call-1";
|
||||
let arguments = json!({
|
||||
"memory_id": thread_id_string,
|
||||
})
|
||||
.to_string();
|
||||
let mocks =
|
||||
mount_function_call_agent_response(&server, call_id, &arguments, "get_memory").await;
|
||||
|
||||
test.submit_turn("load the saved memory").await?;
|
||||
|
||||
let initial_request = mocks.function_call.single_request().body_json();
|
||||
assert!(
|
||||
initial_request["tools"]
|
||||
.as_array()
|
||||
.expect("tools array")
|
||||
.iter()
|
||||
.filter_map(|tool| tool.get("name").and_then(Value::as_str))
|
||||
.any(|name| name == "get_memory"),
|
||||
"get_memory tool should be exposed when memory_tool feature is enabled"
|
||||
);
|
||||
|
||||
let completion_request = mocks.completion.single_request();
|
||||
let (content_opt, success_opt) = completion_request
|
||||
.function_call_output_content_and_success(call_id)
|
||||
.expect("function_call_output should be present");
|
||||
let success = success_opt.unwrap_or(true);
|
||||
assert!(success, "expected successful get_memory tool call output");
|
||||
let content = content_opt.expect("function_call_output content should be present");
|
||||
let payload: Value = serde_json::from_str(&content)?;
|
||||
assert_eq!(
|
||||
payload,
|
||||
json!({
|
||||
"memory_id": thread_id_string,
|
||||
"trace_summary": trace_summary,
|
||||
"memory_summary": memory_summary,
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -82,6 +82,7 @@ mod list_dir;
|
|||
mod list_models;
|
||||
mod live_cli;
|
||||
mod live_reload;
|
||||
mod memory_tool;
|
||||
mod model_info_overrides;
|
||||
mod model_overrides;
|
||||
mod model_switching;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue