Add hooks implementation and wire up to notify (#9691)
This introduces a `Hooks` service. It registers hooks from config and dispatches hook events at runtime. N.B. The hook config is not wired up to this yet. But for legacy reasons, we wire up `notify` from config and power it using hooks now. Nothing about the `notify` interface has changed. I'd start by reviewing `hooks/types.rs` Some things to note: - hook names subject to change - no hook result yet - stopping semantics yet to be introduced - additional hooks yet to be introduced
This commit is contained in:
parent
9ee746afd6
commit
3b54fd7336
8 changed files with 608 additions and 105 deletions
|
|
@ -24,6 +24,9 @@ use crate::features::FEATURES;
|
|||
use crate::features::Feature;
|
||||
use crate::features::Features;
|
||||
use crate::features::maybe_push_unstable_features_warning;
|
||||
use crate::hooks::HookEvent;
|
||||
use crate::hooks::HookEventAfterAgent;
|
||||
use crate::hooks::Hooks;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::parse_command::parse_command;
|
||||
use crate::parse_turn_item;
|
||||
|
|
@ -35,7 +38,6 @@ use crate::stream_events_utils::last_assistant_message_from_item;
|
|||
use crate::terminal;
|
||||
use crate::truncate::TruncationPolicy;
|
||||
use crate::turn_metadata::build_turn_metadata_header;
|
||||
use crate::user_notification::UserNotifier;
|
||||
use crate::util::error_or_panic;
|
||||
use async_channel::Receiver;
|
||||
use async_channel::Sender;
|
||||
|
|
@ -201,7 +203,6 @@ use crate::tools::spec::ToolsConfig;
|
|||
use crate::tools::spec::ToolsConfigParams;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
use crate::windows_sandbox::WindowsSandboxLevelExt;
|
||||
use codex_async_utils::OrCancelExt;
|
||||
|
|
@ -1015,7 +1016,7 @@ impl Session {
|
|||
Arc::clone(&config),
|
||||
Arc::clone(&auth_manager),
|
||||
),
|
||||
notifier: UserNotifier::new(config.notify.clone()),
|
||||
hooks: Hooks::new(config.as_ref()),
|
||||
rollout: Mutex::new(rollout_recorder),
|
||||
user_shell: Arc::new(default_shell),
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
|
|
@ -2450,8 +2451,8 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn notifier(&self) -> &UserNotifier {
|
||||
&self.services.notifier
|
||||
pub(crate) fn hooks(&self) -> &Hooks {
|
||||
&self.services.hooks
|
||||
}
|
||||
|
||||
pub(crate) fn user_shell(&self) -> Arc<shell::Shell> {
|
||||
|
|
@ -3776,14 +3777,21 @@ pub(crate) async fn run_turn(
|
|||
|
||||
if !needs_follow_up {
|
||||
last_agent_message = sampling_request_last_agent_message;
|
||||
sess.notifier()
|
||||
.notify(&UserNotification::AgentTurnComplete {
|
||||
thread_id: sess.conversation_id.to_string(),
|
||||
turn_id: turn_context.sub_id.clone(),
|
||||
cwd: turn_context.cwd.display().to_string(),
|
||||
input_messages: sampling_request_input_messages,
|
||||
last_assistant_message: last_agent_message.clone(),
|
||||
});
|
||||
sess.hooks()
|
||||
.dispatch(crate::hooks::HookPayload {
|
||||
session_id: sess.conversation_id,
|
||||
cwd: turn_context.cwd.clone(),
|
||||
triggered_at: chrono::Utc::now(),
|
||||
hook_event: HookEvent::AfterAgent {
|
||||
event: HookEventAfterAgent {
|
||||
thread_id: sess.conversation_id,
|
||||
turn_id: turn_context.sub_id.clone(),
|
||||
input_messages: sampling_request_input_messages,
|
||||
last_assistant_message: last_agent_message.clone(),
|
||||
},
|
||||
},
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
|
|
@ -5724,7 +5732,7 @@ mod tests {
|
|||
Arc::clone(&config),
|
||||
Arc::clone(&auth_manager),
|
||||
),
|
||||
notifier: UserNotifier::new(None),
|
||||
hooks: Hooks::new(&config),
|
||||
rollout: Mutex::new(None),
|
||||
user_shell: Arc::new(default_user_shell()),
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
|
|
@ -5854,7 +5862,7 @@ mod tests {
|
|||
Arc::clone(&config),
|
||||
Arc::clone(&auth_manager),
|
||||
),
|
||||
notifier: UserNotifier::new(None),
|
||||
hooks: Hooks::new(&config),
|
||||
rollout: Mutex::new(None),
|
||||
user_shell: Arc::new(default_user_shell()),
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
|
|
|
|||
8
codex-rs/core/src/hooks/mod.rs
Normal file
8
codex-rs/core/src/hooks/mod.rs
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
mod registry;
|
||||
mod types;
|
||||
mod user_notification;
|
||||
|
||||
pub(crate) use registry::Hooks;
|
||||
pub(crate) use types::HookEvent;
|
||||
pub(crate) use types::HookEventAfterAgent;
|
||||
pub(crate) use types::HookPayload;
|
||||
315
codex-rs/core/src/hooks/registry.rs
Normal file
315
codex-rs/core/src/hooks/registry.rs
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
use tokio::process::Command;
|
||||
|
||||
use super::types::Hook;
|
||||
use super::types::HookEvent;
|
||||
use super::types::HookOutcome;
|
||||
use super::types::HookPayload;
|
||||
use super::user_notification::notify_hook;
|
||||
use crate::config::Config;
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub(crate) struct Hooks {
|
||||
after_agent: Vec<Hook>,
|
||||
}
|
||||
|
||||
fn get_notify_hook(config: &Config) -> Option<Hook> {
|
||||
config
|
||||
.notify
|
||||
.as_ref()
|
||||
.filter(|argv| !argv.is_empty() && !argv[0].is_empty())
|
||||
.map(|argv| notify_hook(argv.clone()))
|
||||
}
|
||||
|
||||
// Hooks are arbitrary, user-specified functions that are deterministically
|
||||
// executed after specific events in the Codex lifecycle.
|
||||
impl Hooks {
|
||||
// new creates a new Hooks instance from config.
|
||||
// For legacy compatibility, if config.notify is set, it will be added to
|
||||
// the after_agent hooks.
|
||||
pub(crate) fn new(config: &Config) -> Self {
|
||||
let after_agent = get_notify_hook(config).into_iter().collect();
|
||||
Self { after_agent }
|
||||
}
|
||||
|
||||
fn hooks_for_event(&self, hook_event: &HookEvent) -> &[Hook] {
|
||||
match hook_event {
|
||||
HookEvent::AfterAgent { .. } => &self.after_agent,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn dispatch(&self, hook_payload: HookPayload) {
|
||||
// TODO(gt): support interrupting program execution by returning a result here.
|
||||
for hook in self.hooks_for_event(&hook_payload.hook_event) {
|
||||
let outcome = hook.execute(&hook_payload).await;
|
||||
if matches!(outcome, HookOutcome::Stop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn command_from_argv(argv: &[String]) -> Option<Command> {
|
||||
let (program, args) = argv.split_first()?;
|
||||
if program.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut command = Command::new(program);
|
||||
command.args(args);
|
||||
Some(command)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use chrono::TimeZone;
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::to_string;
|
||||
use tempfile::tempdir;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::config::test_config;
|
||||
|
||||
use super::super::types::Hook;
|
||||
use super::super::types::HookEvent;
|
||||
use super::super::types::HookEventAfterAgent;
|
||||
use super::super::types::HookOutcome;
|
||||
use super::super::types::HookPayload;
|
||||
use super::Hooks;
|
||||
use super::command_from_argv;
|
||||
use super::get_notify_hook;
|
||||
|
||||
const CWD: &str = "/tmp";
|
||||
const INPUT_MESSAGE: &str = "hello";
|
||||
|
||||
fn hook_payload(label: &str) -> HookPayload {
|
||||
HookPayload {
|
||||
session_id: ThreadId::new(),
|
||||
cwd: PathBuf::from(CWD),
|
||||
triggered_at: Utc
|
||||
.with_ymd_and_hms(2025, 1, 1, 0, 0, 0)
|
||||
.single()
|
||||
.expect("valid timestamp"),
|
||||
hook_event: HookEvent::AfterAgent {
|
||||
event: HookEventAfterAgent {
|
||||
thread_id: ThreadId::new(),
|
||||
turn_id: format!("turn-{label}"),
|
||||
input_messages: vec![INPUT_MESSAGE.to_string()],
|
||||
last_assistant_message: Some("hi".to_string()),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn counting_hook(calls: &Arc<AtomicUsize>, outcome: HookOutcome) -> Hook {
|
||||
let calls = Arc::clone(calls);
|
||||
Hook {
|
||||
func: Arc::new(move |_| {
|
||||
let calls = Arc::clone(&calls);
|
||||
Box::pin(async move {
|
||||
calls.fetch_add(1, Ordering::SeqCst);
|
||||
outcome
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn hooks_for_after_agent(hooks: Vec<Hook>) -> Hooks {
|
||||
Hooks { after_agent: hooks }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_from_argv_returns_none_for_empty_args() {
|
||||
assert!(command_from_argv(&[]).is_none());
|
||||
assert!(command_from_argv(&["".to_string()]).is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn command_from_argv_builds_command() -> Result<()> {
|
||||
let argv = if cfg!(windows) {
|
||||
vec![
|
||||
"cmd".to_string(),
|
||||
"/C".to_string(),
|
||||
"echo hello world".to_string(),
|
||||
]
|
||||
} else {
|
||||
vec!["echo".to_string(), "hello".to_string(), "world".to_string()]
|
||||
};
|
||||
let mut command = command_from_argv(&argv).ok_or_else(|| anyhow::anyhow!("command"))?;
|
||||
let output = command.stdout(Stdio::piped()).output().await?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let trimmed = stdout.trim_end_matches(['\r', '\n']);
|
||||
assert_eq!(trimmed, "hello world");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_notify_hook_requires_program_name() {
|
||||
let mut config = test_config();
|
||||
|
||||
config.notify = Some(vec![]);
|
||||
assert!(get_notify_hook(&config).is_none());
|
||||
|
||||
config.notify = Some(vec!["".to_string()]);
|
||||
assert!(get_notify_hook(&config).is_none());
|
||||
|
||||
config.notify = Some(vec!["notify-send".to_string()]);
|
||||
assert!(get_notify_hook(&config).is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_executes_hook() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = hooks_for_after_agent(vec![counting_hook(&calls, HookOutcome::Continue)]);
|
||||
|
||||
hooks.dispatch(hook_payload("1")).await;
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_hook_is_noop_and_continues() {
|
||||
let payload = hook_payload("d");
|
||||
let outcome = Hook::default().execute(&payload).await;
|
||||
assert_eq!(outcome, HookOutcome::Continue);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_executes_multiple_hooks_for_same_event() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = hooks_for_after_agent(vec![
|
||||
counting_hook(&calls, HookOutcome::Continue),
|
||||
counting_hook(&calls, HookOutcome::Continue),
|
||||
]);
|
||||
|
||||
hooks.dispatch(hook_payload("2")).await;
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_stops_when_hook_returns_stop() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = hooks_for_after_agent(vec![
|
||||
counting_hook(&calls, HookOutcome::Stop),
|
||||
counting_hook(&calls, HookOutcome::Continue),
|
||||
]);
|
||||
|
||||
hooks.dispatch(hook_payload("3")).await;
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn hook_executes_program_with_payload_argument_unix() -> Result<()> {
|
||||
let temp_dir = tempdir()?;
|
||||
let payload_path = temp_dir.path().join("payload.json");
|
||||
let payload_path_arg = payload_path.to_string_lossy().into_owned();
|
||||
let hook = Hook {
|
||||
func: Arc::new(move |payload: &HookPayload| {
|
||||
let payload_path_arg = payload_path_arg.clone();
|
||||
Box::pin(async move {
|
||||
let json = to_string(payload).expect("serialize hook payload");
|
||||
let mut command = command_from_argv(&[
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"printf '%s' \"$2\" > \"$1\"".to_string(),
|
||||
"sh".to_string(),
|
||||
payload_path_arg,
|
||||
json,
|
||||
])
|
||||
.expect("build command");
|
||||
command.status().await.expect("run hook command");
|
||||
HookOutcome::Continue
|
||||
})
|
||||
}),
|
||||
};
|
||||
|
||||
let payload = hook_payload("4");
|
||||
let expected = to_string(&payload)?;
|
||||
|
||||
let hooks = hooks_for_after_agent(vec![hook]);
|
||||
hooks.dispatch(payload).await;
|
||||
|
||||
let contents = timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if let Ok(contents) = fs::read_to_string(&payload_path)
|
||||
&& !contents.is_empty()
|
||||
{
|
||||
return contents;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert_eq!(contents, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
#[tokio::test]
|
||||
async fn hook_executes_program_with_payload_argument_windows() -> Result<()> {
|
||||
let temp_dir = tempdir()?;
|
||||
let payload_path = temp_dir.path().join("payload.json");
|
||||
let payload_path_arg = payload_path.to_string_lossy().into_owned();
|
||||
let script_path = temp_dir.path().join("write_payload.ps1");
|
||||
fs::write(&script_path, "[IO.File]::WriteAllText($args[0], $args[1])")?;
|
||||
let script_path_arg = script_path.to_string_lossy().into_owned();
|
||||
let hook = Hook {
|
||||
func: Arc::new(move |payload: &HookPayload| {
|
||||
let payload_path_arg = payload_path_arg.clone();
|
||||
let script_path_arg = script_path_arg.clone();
|
||||
Box::pin(async move {
|
||||
let json = to_string(payload).expect("serialize hook payload");
|
||||
let powershell = crate::powershell::try_find_powershell_executable_blocking()
|
||||
.map(|path| path.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| "powershell.exe".to_string());
|
||||
let mut command = command_from_argv(&[
|
||||
powershell,
|
||||
"-NoLogo".to_string(),
|
||||
"-NoProfile".to_string(),
|
||||
"-ExecutionPolicy".to_string(),
|
||||
"Bypass".to_string(),
|
||||
"-File".to_string(),
|
||||
script_path_arg,
|
||||
payload_path_arg,
|
||||
json,
|
||||
])
|
||||
.expect("build command");
|
||||
command.status().await.expect("run hook command");
|
||||
HookOutcome::Continue
|
||||
})
|
||||
}),
|
||||
};
|
||||
|
||||
let payload = hook_payload("4");
|
||||
let expected = to_string(&payload)?;
|
||||
|
||||
let hooks = hooks_for_after_agent(vec![hook]);
|
||||
hooks.dispatch(payload).await;
|
||||
|
||||
let contents = timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if let Ok(contents) = fs::read_to_string(&payload_path)
|
||||
&& !contents.is_empty()
|
||||
{
|
||||
return contents;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert_eq!(contents, expected);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
127
codex-rs/core/src/hooks/types.rs
Normal file
127
codex-rs/core/src/hooks/types.rs
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::DateTime;
|
||||
use chrono::SecondsFormat;
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use futures::future::BoxFuture;
|
||||
use serde::Serialize;
|
||||
use serde::Serializer;
|
||||
|
||||
pub(crate) type HookFn =
|
||||
Arc<dyn for<'a> Fn(&'a HookPayload) -> BoxFuture<'a, HookOutcome> + Send + Sync>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Hook {
|
||||
pub(crate) func: HookFn,
|
||||
}
|
||||
|
||||
impl Default for Hook {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
func: Arc::new(|_| Box::pin(async { HookOutcome::Continue })),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Hook {
|
||||
pub(super) async fn execute(&self, payload: &HookPayload) -> HookOutcome {
|
||||
(self.func)(payload).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct HookPayload {
|
||||
pub(crate) session_id: ThreadId,
|
||||
pub(crate) cwd: PathBuf,
|
||||
#[serde(serialize_with = "serialize_triggered_at")]
|
||||
pub(crate) triggered_at: DateTime<Utc>,
|
||||
pub(crate) hook_event: HookEvent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct HookEventAfterAgent {
|
||||
pub thread_id: ThreadId,
|
||||
pub turn_id: String,
|
||||
pub input_messages: Vec<String>,
|
||||
pub last_assistant_message: Option<String>,
|
||||
}
|
||||
|
||||
fn serialize_triggered_at<S>(value: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&value.to_rfc3339_opts(SecondsFormat::Secs, true))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "event_type", rename_all = "snake_case")]
|
||||
pub(crate) enum HookEvent {
|
||||
AfterAgent {
|
||||
#[serde(flatten)]
|
||||
event: HookEventAfterAgent,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum HookOutcome {
|
||||
Continue,
|
||||
#[allow(dead_code)]
|
||||
Stop,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use chrono::TimeZone;
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
use super::HookEvent;
|
||||
use super::HookEventAfterAgent;
|
||||
use super::HookPayload;
|
||||
|
||||
#[test]
|
||||
fn hook_payload_serializes_stable_wire_shape() {
|
||||
let session_id = ThreadId::new();
|
||||
let thread_id = ThreadId::new();
|
||||
let payload = HookPayload {
|
||||
session_id,
|
||||
cwd: PathBuf::from("tmp"),
|
||||
triggered_at: Utc
|
||||
.with_ymd_and_hms(2025, 1, 1, 0, 0, 0)
|
||||
.single()
|
||||
.expect("valid timestamp"),
|
||||
hook_event: HookEvent::AfterAgent {
|
||||
event: HookEventAfterAgent {
|
||||
thread_id,
|
||||
turn_id: "turn-1".to_string(),
|
||||
input_messages: vec!["hello".to_string()],
|
||||
last_assistant_message: Some("hi".to_string()),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let actual = serde_json::to_value(payload).expect("serialize hook payload");
|
||||
let expected = json!({
|
||||
"session_id": session_id.to_string(),
|
||||
"cwd": "tmp",
|
||||
"triggered_at": "2025-01-01T00:00:00Z",
|
||||
"hook_event": {
|
||||
"event_type": "after_agent",
|
||||
"thread_id": thread_id.to_string(),
|
||||
"turn_id": "turn-1",
|
||||
"input_messages": ["hello"],
|
||||
"last_assistant_message": "hi",
|
||||
},
|
||||
});
|
||||
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
}
|
||||
132
codex-rs/core/src/hooks/user_notification.rs
Normal file
132
codex-rs/core/src/hooks/user_notification.rs
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use serde::Serialize;
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
|
||||
use super::registry::command_from_argv;
|
||||
use super::types::Hook;
|
||||
use super::types::HookEvent;
|
||||
use super::types::HookOutcome;
|
||||
use super::types::HookPayload;
|
||||
|
||||
/// Legacy notify payload appended as the final argv argument for backward compatibility.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "kebab-case")]
|
||||
enum UserNotification {
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
AgentTurnComplete {
|
||||
thread_id: String,
|
||||
turn_id: String,
|
||||
cwd: String,
|
||||
|
||||
/// Messages that the user sent to the agent to initiate the turn.
|
||||
input_messages: Vec<String>,
|
||||
|
||||
/// The last message sent by the assistant in the turn.
|
||||
last_assistant_message: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
pub(super) fn legacy_notify_json(
|
||||
hook_event: &HookEvent,
|
||||
cwd: &Path,
|
||||
) -> Result<String, serde_json::Error> {
|
||||
serde_json::to_string(&match hook_event {
|
||||
HookEvent::AfterAgent { event } => UserNotification::AgentTurnComplete {
|
||||
thread_id: event.thread_id.to_string(),
|
||||
turn_id: event.turn_id.clone(),
|
||||
cwd: cwd.display().to_string(),
|
||||
input_messages: event.input_messages.clone(),
|
||||
last_assistant_message: event.last_assistant_message.clone(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn notify_hook(argv: Vec<String>) -> Hook {
|
||||
let argv = Arc::new(argv);
|
||||
Hook {
|
||||
func: Arc::new(move |payload: &HookPayload| {
|
||||
let argv = Arc::clone(&argv);
|
||||
Box::pin(async move {
|
||||
let mut command = match command_from_argv(&argv) {
|
||||
Some(command) => command,
|
||||
None => return HookOutcome::Continue,
|
||||
};
|
||||
if let Ok(notify_payload) = legacy_notify_json(&payload.hook_event, &payload.cwd) {
|
||||
command.arg(notify_payload);
|
||||
}
|
||||
|
||||
// Backwards-compat: match legacy notify behavior (argv + JSON arg, fire-and-forget).
|
||||
command
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null());
|
||||
|
||||
let _ = command.spawn();
|
||||
HookOutcome::Continue
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use super::*;
|
||||
use anyhow::Result;
|
||||
use codex_protocol::ThreadId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
|
||||
fn expected_notification_json() -> Value {
|
||||
json!({
|
||||
"type": "agent-turn-complete",
|
||||
"thread-id": "b5f6c1c2-1111-2222-3333-444455556666",
|
||||
"turn-id": "12345",
|
||||
"cwd": "/Users/example/project",
|
||||
"input-messages": ["Rename `foo` to `bar` and update the callsites."],
|
||||
"last-assistant-message": "Rename complete and verified `cargo build` succeeds.",
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_notification() -> Result<()> {
|
||||
let notification = UserNotification::AgentTurnComplete {
|
||||
thread_id: "b5f6c1c2-1111-2222-3333-444455556666".to_string(),
|
||||
turn_id: "12345".to_string(),
|
||||
cwd: "/Users/example/project".to_string(),
|
||||
input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()],
|
||||
last_assistant_message: Some(
|
||||
"Rename complete and verified `cargo build` succeeds.".to_string(),
|
||||
),
|
||||
};
|
||||
let serialized = serde_json::to_string(¬ification)?;
|
||||
let actual: Value = serde_json::from_str(&serialized)?;
|
||||
assert_eq!(actual, expected_notification_json());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn legacy_notify_json_matches_historical_wire_shape() -> Result<()> {
|
||||
let hook_event = HookEvent::AfterAgent {
|
||||
event: super::super::types::HookEventAfterAgent {
|
||||
thread_id: ThreadId::from_string("b5f6c1c2-1111-2222-3333-444455556666")
|
||||
.expect("valid thread id"),
|
||||
turn_id: "12345".to_string(),
|
||||
input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()],
|
||||
last_assistant_message: Some(
|
||||
"Rename complete and verified `cargo build` succeeds.".to_string(),
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
let serialized = legacy_notify_json(&hook_event, Path::new("/Users/example/project"))?;
|
||||
let actual: Value = serde_json::from_str(&serialized)?;
|
||||
assert_eq!(actual, expected_notification_json());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -35,6 +35,7 @@ pub mod features;
|
|||
mod file_watcher;
|
||||
mod flags;
|
||||
pub mod git_info;
|
||||
pub mod hooks;
|
||||
pub mod instructions;
|
||||
pub mod landlock;
|
||||
pub mod mcp;
|
||||
|
|
@ -125,7 +126,6 @@ pub use rollout::session_index::find_thread_names_by_ids;
|
|||
mod function_tool;
|
||||
mod state;
|
||||
mod tasks;
|
||||
mod user_notification;
|
||||
mod user_shell_command;
|
||||
pub mod util;
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ use crate::analytics_client::AnalyticsEventsClient;
|
|||
use crate::client::ModelClient;
|
||||
use crate::exec_policy::ExecPolicyManager;
|
||||
use crate::file_watcher::FileWatcher;
|
||||
use crate::hooks::Hooks;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::skills::SkillsManager;
|
||||
use crate::state_db::StateDbHandle;
|
||||
use crate::tools::sandboxing::ApprovalStore;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
use crate::user_notification::UserNotifier;
|
||||
use codex_otel::OtelManager;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
|
|
@ -24,7 +24,7 @@ pub(crate) struct SessionServices {
|
|||
pub(crate) mcp_startup_cancellation_token: Mutex<CancellationToken>,
|
||||
pub(crate) unified_exec_manager: UnifiedExecProcessManager,
|
||||
pub(crate) analytics_events_client: AnalyticsEventsClient,
|
||||
pub(crate) notifier: UserNotifier,
|
||||
pub(crate) hooks: Hooks,
|
||||
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
|
||||
pub(crate) user_shell: Arc<crate::shell::Shell>,
|
||||
pub(crate) show_raw_agent_reasoning: bool,
|
||||
|
|
|
|||
|
|
@ -1,87 +0,0 @@
|
|||
use serde::Serialize;
|
||||
use tracing::error;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct UserNotifier {
|
||||
notify_command: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl UserNotifier {
|
||||
pub(crate) fn notify(&self, notification: &UserNotification) {
|
||||
if let Some(notify_command) = &self.notify_command
|
||||
&& !notify_command.is_empty()
|
||||
{
|
||||
self.invoke_notify(notify_command, notification)
|
||||
}
|
||||
}
|
||||
|
||||
fn invoke_notify(&self, notify_command: &[String], notification: &UserNotification) {
|
||||
let Ok(json) = serde_json::to_string(¬ification) else {
|
||||
error!("failed to serialise notification payload");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut command = std::process::Command::new(¬ify_command[0]);
|
||||
if notify_command.len() > 1 {
|
||||
command.args(¬ify_command[1..]);
|
||||
}
|
||||
command.arg(json);
|
||||
|
||||
// Fire-and-forget – we do not wait for completion.
|
||||
if let Err(e) = command.spawn() {
|
||||
warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(notify: Option<Vec<String>>) -> Self {
|
||||
Self {
|
||||
notify_command: notify,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// User can configure a program that will receive notifications. Each
|
||||
/// notification is serialized as JSON and passed as an argument to the
|
||||
/// program.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "kebab-case")]
|
||||
pub(crate) enum UserNotification {
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
AgentTurnComplete {
|
||||
thread_id: String,
|
||||
turn_id: String,
|
||||
cwd: String,
|
||||
|
||||
/// Messages that the user sent to the agent to initiate the turn.
|
||||
input_messages: Vec<String>,
|
||||
|
||||
/// The last message sent by the assistant in the turn.
|
||||
last_assistant_message: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::Result;
|
||||
|
||||
#[test]
|
||||
fn test_user_notification() -> Result<()> {
|
||||
let notification = UserNotification::AgentTurnComplete {
|
||||
thread_id: "b5f6c1c2-1111-2222-3333-444455556666".to_string(),
|
||||
turn_id: "12345".to_string(),
|
||||
cwd: "/Users/example/project".to_string(),
|
||||
input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()],
|
||||
last_assistant_message: Some(
|
||||
"Rename complete and verified `cargo build` succeeds.".to_string(),
|
||||
),
|
||||
};
|
||||
let serialized = serde_json::to_string(¬ification)?;
|
||||
assert_eq!(
|
||||
serialized,
|
||||
r#"{"type":"agent-turn-complete","thread-id":"b5f6c1c2-1111-2222-3333-444455556666","turn-id":"12345","cwd":"/Users/example/project","input-messages":["Rename `foo` to `bar` and update the callsites."],"last-assistant-message":"Rename complete and verified `cargo build` succeeds."}"#
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue