From 41f3b1ba0bc3f800ea0dba3d3dc72515fc07c666 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Thu, 5 Feb 2026 16:16:31 +0000 Subject: [PATCH] feat: add memory tool (#10637) Add a tool for memory to retrieve a full memory based on the memory ID --- codex-rs/core/config.schema.json | 6 ++ codex-rs/core/src/features.rs | 8 ++ .../core/src/tools/handlers/get_memory.rs | 72 ++++++++++++++++ codex-rs/core/src/tools/handlers/mod.rs | 2 + codex-rs/core/src/tools/spec.rs | 59 +++++++++++++ codex-rs/core/tests/suite/memory_tool.rs | 84 +++++++++++++++++++ codex-rs/core/tests/suite/mod.rs | 1 + 7 files changed, 232 insertions(+) create mode 100644 codex-rs/core/src/tools/handlers/get_memory.rs create mode 100644 codex-rs/core/tests/suite/memory_tool.rs diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index bf2439723..a872d1b91 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -190,6 +190,9 @@ "include_apply_patch_tool": { "type": "boolean" }, + "memory_tool": { + "type": "boolean" + }, "personality": { "type": "boolean" }, @@ -1212,6 +1215,9 @@ "include_apply_patch_tool": { "type": "boolean" }, + "memory_tool": { + "type": "boolean" + }, "personality": { "type": "boolean" }, diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index e14373f85..ee091979e 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -107,6 +107,8 @@ pub enum Feature { RuntimeMetrics, /// Persist rollout metadata to a local SQLite database. Sqlite, + /// Enable the get_memory tool backed by SQLite thread memories. + MemoryTool, /// Append additional AGENTS.md guidance to user instructions. ChildAgentsMd, /// Enforce UTF8 output in Powershell. 
@@ -449,6 +451,12 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::UnderDevelopment, default_enabled: false, }, + FeatureSpec { + id: Feature::MemoryTool, + key: "memory_tool", + stage: Stage::UnderDevelopment, + default_enabled: false, + }, FeatureSpec { id: Feature::ChildAgentsMd, key: "child_agents_md", diff --git a/codex-rs/core/src/tools/handlers/get_memory.rs b/codex-rs/core/src/tools/handlers/get_memory.rs new file mode 100644 index 000000000..df2929b88 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/get_memory.rs @@ -0,0 +1,72 @@ +use crate::function_tool::FunctionCallError; +use crate::state_db; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use async_trait::async_trait; +use codex_protocol::ThreadId; +use codex_protocol::models::FunctionCallOutputBody; +use serde::Deserialize; +use serde_json::json; + +pub struct GetMemoryHandler; + +#[derive(Deserialize)] +struct GetMemoryArgs { + memory_id: String, +} + +#[async_trait] +impl ToolHandler for GetMemoryHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> { + let ToolInvocation { + session, payload, .. 
+ } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "get_memory handler received unsupported payload".to_string(), + )); + } + }; + + let args: GetMemoryArgs = parse_arguments(&arguments)?; + let thread_id = ThreadId::from_string(args.memory_id.as_str()).map_err(|err| { + FunctionCallError::RespondToModel(format!("memory_id must be a valid thread id: {err}")) + })?; + + let state_db_ctx = session.state_db(); + let memory = + state_db::get_thread_memory(state_db_ctx.as_deref(), thread_id, "get_memory_tool") + .await + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "memory not found for memory_id={}", + args.memory_id + )) + })?; + + let content = serde_json::to_string_pretty(&json!({ + "memory_id": args.memory_id, + "trace_summary": memory.trace_summary, + "memory_summary": memory.memory_summary, + })) + .map_err(|err| { + FunctionCallError::Fatal(format!("failed to serialize memory payload: {err}")) + })?; + + Ok(ToolOutput::Function { + body: FunctionCallOutputBody::Text(content), + success: Some(true), + }) + } +} diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index dda4760bd..d8ec88716 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -1,6 +1,7 @@ pub mod apply_patch; pub(crate) mod collab; mod dynamic; +mod get_memory; mod grep_files; mod list_dir; mod mcp; @@ -20,6 +21,7 @@ use crate::function_tool::FunctionCallError; pub use apply_patch::ApplyPatchHandler; pub use collab::CollabHandler; pub use dynamic::DynamicToolHandler; +pub use get_memory::GetMemoryHandler; pub use grep_files::GrepFilesHandler; pub use list_dir::ListDirHandler; pub use mcp::McpHandler; diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index ddac191a9..26fccf318 100644 --- a/codex-rs/core/src/tools/spec.rs +++ 
b/codex-rs/core/src/tools/spec.rs @@ -31,6 +31,7 @@ pub(crate) struct ToolsConfig { pub web_search_mode: Option<WebSearchMode>, pub collab_tools: bool, pub collaboration_modes_tools: bool, + pub memory_tools: bool, pub request_rule_enabled: bool, pub experimental_supported_tools: Vec<String>, } @@ -51,6 +52,7 @@ impl ToolsConfig { let include_apply_patch_tool = features.enabled(Feature::ApplyPatchFreeform); let include_collab_tools = features.enabled(Feature::Collab); let include_collaboration_modes_tools = features.enabled(Feature::CollaborationModes); + let include_memory_tools = features.enabled(Feature::MemoryTool); let request_rule_enabled = features.enabled(Feature::RequestRule); let shell_type = if !features.enabled(Feature::ShellTool) { @@ -84,6 +86,7 @@ impl ToolsConfig { web_search_mode: *web_search_mode, collab_tools: include_collab_tools, collaboration_modes_tools: include_collaboration_modes_tools, + memory_tools: include_memory_tools, request_rule_enabled, experimental_supported_tools: model_info.experimental_supported_tools.clone(), } @@ -634,6 +637,28 @@ fn create_request_user_input_tool() -> ToolSpec { }) } +fn create_get_memory_tool() -> ToolSpec { + let properties = BTreeMap::from([( + "memory_id".to_string(), + JsonSchema::String { + description: Some( + "Memory ID to fetch. 
Uses the thread ID as the memory identifier.".to_string(), + ), + }, + )]); + + ToolSpec::Function(ResponsesApiTool { + name: "get_memory".to_string(), + description: "Loads the full stored memory payload for a memory_id.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["memory_id".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + fn create_close_agent_tool() -> ToolSpec { let mut properties = BTreeMap::new(); properties.insert( @@ -1228,6 +1253,7 @@ pub(crate) fn build_specs( use crate::tools::handlers::ApplyPatchHandler; use crate::tools::handlers::CollabHandler; use crate::tools::handlers::DynamicToolHandler; + use crate::tools::handlers::GetMemoryHandler; use crate::tools::handlers::GrepFilesHandler; use crate::tools::handlers::ListDirHandler; use crate::tools::handlers::McpHandler; @@ -1249,6 +1275,7 @@ pub(crate) fn build_specs( let plan_handler = Arc::new(PlanHandler); let apply_patch_handler = Arc::new(ApplyPatchHandler); let dynamic_tool_handler = Arc::new(DynamicToolHandler); + let get_memory_handler = Arc::new(GetMemoryHandler); let view_image_handler = Arc::new(ViewImageHandler); let mcp_handler = Arc::new(McpHandler); let mcp_resource_handler = Arc::new(McpResourceHandler); @@ -1308,6 +1335,11 @@ pub(crate) fn build_specs( builder.register_handler("request_user_input", request_user_input_handler); } + if config.memory_tools { + builder.push_spec(create_get_memory_tool()); + builder.register_handler("get_memory", get_memory_handler); + } + if let Some(apply_patch_tool_type) = &config.apply_patch_tool_type { match apply_patch_tool_type { ApplyPatchToolType::Freeform => { @@ -1669,6 +1701,33 @@ mod tests { assert_contains_tool_names(&tools, &["request_user_input"]); } + #[test] + fn get_memory_requires_memory_tool_feature() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline("gpt-5-codex", &config); + let mut features = 
Features::with_defaults(); + features.disable(Feature::MemoryTool); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + }); + let (tools, _) = build_specs(&tools_config, None, &[]).build(); + assert!( + !tools.iter().any(|t| t.spec.name() == "get_memory"), + "get_memory should be disabled when memory_tool feature is off" + ); + + features.enable(Feature::MemoryTool); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + }); + let (tools, _) = build_specs(&tools_config, None, &[]).build(); + assert_contains_tool_names(&tools, &["get_memory"]); + } + fn assert_model_tools( model_slug: &str, features: &Features, diff --git a/codex-rs/core/tests/suite/memory_tool.rs b/codex-rs/core/tests/suite/memory_tool.rs new file mode 100644 index 000000000..09d1ee3ce --- /dev/null +++ b/codex-rs/core/tests/suite/memory_tool.rs @@ -0,0 +1,84 @@ +#![allow(clippy::expect_used, clippy::unwrap_used)] + +use anyhow::Result; +use codex_core::features::Feature; +use core_test_support::responses::mount_function_call_agent_response; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::test_codex; +use pretty_assertions::assert_eq; +use serde_json::Value; +use serde_json::json; +use tokio::time::Duration; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn get_memory_tool_returns_persisted_thread_memory() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Sqlite); + config.features.enable(Feature::MemoryTool); + }); + let test = builder.build(&server).await?; + + let db = test.codex.state_db().expect("state db enabled"); + let thread_id = 
test.session_configured.session_id; + let thread_id_string = thread_id.to_string(); + + let mut thread_exists = false; + // Wait for DB creation. + for _ in 0..100 { + if db.get_thread(thread_id).await?.is_some() { + thread_exists = true; + break; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + assert!(thread_exists, "thread should exist in state db"); + + let trace_summary = "trace summary from sqlite"; + let memory_summary = "memory summary from sqlite"; + db.upsert_thread_memory(thread_id, trace_summary, memory_summary) + .await?; + + let call_id = "memory-call-1"; + let arguments = json!({ + "memory_id": thread_id_string, + }) + .to_string(); + let mocks = + mount_function_call_agent_response(&server, call_id, &arguments, "get_memory").await; + + test.submit_turn("load the saved memory").await?; + + let initial_request = mocks.function_call.single_request().body_json(); + assert!( + initial_request["tools"] + .as_array() + .expect("tools array") + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .any(|name| name == "get_memory"), + "get_memory tool should be exposed when memory_tool feature is enabled" + ); + + let completion_request = mocks.completion.single_request(); + let (content_opt, success_opt) = completion_request + .function_call_output_content_and_success(call_id) + .expect("function_call_output should be present"); + let success = success_opt.unwrap_or(true); + assert!(success, "expected successful get_memory tool call output"); + let content = content_opt.expect("function_call_output content should be present"); + let payload: Value = serde_json::from_str(&content)?; + assert_eq!( + payload, + json!({ + "memory_id": thread_id_string, + "trace_summary": trace_summary, + "memory_summary": memory_summary, + }) + ); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index b903f4e6c..379f52168 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ 
b/codex-rs/core/tests/suite/mod.rs @@ -82,6 +82,7 @@ mod list_dir; mod list_models; mod live_cli; mod live_reload; +mod memory_tool; mod model_info_overrides; mod model_overrides; mod model_switching;