core-agent-ide/codex-rs/hooks/src/registry.rs
gt-oai b3095679ed
Allow hooks to error (#11615)
Allow hooks to return errors. 

We should do this before introducing more hook types, or we'll have to
migrate them all.
2026-02-16 14:11:05 +00:00

482 lines
16 KiB
Rust

use tokio::process::Command;
use crate::types::Hook;
use crate::types::HookEvent;
use crate::types::HookPayload;
use crate::types::HookResponse;
/// Configuration used to build a `Hooks` registry.
///
/// Derives `Debug` so the configuration can be logged and shown in assertion
/// failures.
#[derive(Debug, Default, Clone)]
pub struct HooksConfig {
    /// Optional argv (program name followed by its arguments) for the legacy
    /// notify hook. `Hooks::new` registers no hook when this is `None`, empty,
    /// or starts with an empty program name.
    pub legacy_notify_argv: Option<Vec<String>>,
}
/// Registry of hooks, grouped by the lifecycle event that triggers them.
#[derive(Clone)]
pub struct Hooks {
/// Hooks selected by `hooks_for_event` for `HookEvent::AfterAgent`.
/// `Hooks::new` fills this with at most one hook built from
/// `HooksConfig::legacy_notify_argv`.
after_agent: Vec<Hook>,
/// Hooks selected by `hooks_for_event` for `HookEvent::AfterToolUse`.
/// `Hooks::new` always leaves this empty.
after_tool_use: Vec<Hook>,
}
impl Default for Hooks {
    /// Builds the registry that `Hooks::new` produces for a default
    /// `HooksConfig`.
    fn default() -> Self {
        let config = HooksConfig::default();
        Self::new(config)
    }
}
// Hooks are arbitrary, user-specified functions that are deterministically
// executed after specific events in the Codex lifecycle.
impl Hooks {
pub fn new(config: HooksConfig) -> Self {
let after_agent = config
.legacy_notify_argv
.filter(|argv| !argv.is_empty() && !argv[0].is_empty())
.map(crate::notify_hook)
.into_iter()
.collect();
Self {
after_agent,
after_tool_use: Vec::new(),
}
}
fn hooks_for_event(&self, hook_event: &HookEvent) -> &[Hook] {
match hook_event {
HookEvent::AfterAgent { .. } => &self.after_agent,
HookEvent::AfterToolUse { .. } => &self.after_tool_use,
}
}
pub async fn dispatch(&self, hook_payload: HookPayload) -> Vec<HookResponse> {
let hooks = self.hooks_for_event(&hook_payload.hook_event);
let mut outcomes = Vec::with_capacity(hooks.len());
for hook in hooks {
let outcome = hook.execute(&hook_payload).await;
let should_abort_operation = outcome.result.should_abort_operation();
outcomes.push(outcome);
if should_abort_operation {
break;
}
}
outcomes
}
}
/// Builds a `Command` from an argv-style slice.
///
/// The first element is the program and the remaining elements are passed as
/// its arguments. Returns `None` when `argv` is empty or the program name is
/// an empty string.
pub fn command_from_argv(argv: &[String]) -> Option<Command> {
    match argv {
        [program, args @ ..] if !program.is_empty() => {
            let mut command = Command::new(program);
            command.args(args);
            Some(command)
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
use std::fs;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use anyhow::Result;
use chrono::TimeZone;
use chrono::Utc;
use codex_protocol::ThreadId;
use pretty_assertions::assert_eq;
use serde_json::to_string;
use tempfile::tempdir;
use tokio::time::timeout;
use super::*;
use crate::types::HookEventAfterAgent;
use crate::types::HookEventAfterToolUse;
use crate::types::HookResult;
use crate::types::HookToolInput;
use crate::types::HookToolKind;
const CWD: &str = "/tmp";
const INPUT_MESSAGE: &str = "hello";
/// Builds an after-agent payload with a fixed timestamp and working
/// directory; `label` is embedded in the turn id as `turn-{label}`.
fn hook_payload(label: &str) -> HookPayload {
    let session_id = ThreadId::new();
    let cwd = PathBuf::from(CWD);
    let triggered_at = Utc
        .with_ymd_and_hms(2025, 1, 1, 0, 0, 0)
        .single()
        .expect("valid timestamp");
    let event = HookEventAfterAgent {
        thread_id: ThreadId::new(),
        turn_id: format!("turn-{label}"),
        input_messages: vec![INPUT_MESSAGE.to_string()],
        last_assistant_message: Some("hi".to_string()),
    };

    HookPayload {
        session_id,
        cwd,
        triggered_at,
        hook_event: HookEvent::AfterAgent { event },
    }
}
/// Builds a hook named `name` that increments `calls` on every execution and
/// reports `HookResult::Success`.
fn counting_success_hook(calls: &Arc<AtomicUsize>, name: &str) -> Hook {
    let counter = Arc::clone(calls);
    Hook {
        name: name.to_string(),
        func: Arc::new(move |_| {
            // Each invocation gets its own handle to the shared counter so the
            // returned future owns everything it touches.
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                HookResult::Success
            })
        }),
    }
}
fn failing_continue_hook(calls: &Arc<AtomicUsize>, name: &str, message: &str) -> Hook {
let hook_name = name.to_string();
let message = message.to_string();
let calls = Arc::clone(calls);
Hook {
name: hook_name,
func: Arc::new(move |_| {
let calls = Arc::clone(&calls);
let message = message.clone();
Box::pin(async move {
calls.fetch_add(1, Ordering::SeqCst);
HookResult::FailedContinue(std::io::Error::other(message).into())
})
}),
}
}
fn failing_abort_hook(calls: &Arc<AtomicUsize>, name: &str, message: &str) -> Hook {
let hook_name = name.to_string();
let message = message.to_string();
let calls = Arc::clone(calls);
Hook {
name: hook_name,
func: Arc::new(move |_| {
let calls = Arc::clone(&calls);
let message = message.clone();
Box::pin(async move {
calls.fetch_add(1, Ordering::SeqCst);
HookResult::FailedAbort(std::io::Error::other(message).into())
})
}),
}
}
/// Builds an after-tool-use payload describing a successful `apply_patch`
/// call; `label` is embedded in the turn id and the call id.
fn after_tool_use_payload(label: &str) -> HookPayload {
    let session_id = ThreadId::new();
    let cwd = PathBuf::from(CWD);
    let triggered_at = Utc
        .with_ymd_and_hms(2025, 1, 1, 0, 0, 0)
        .single()
        .expect("valid timestamp");
    let tool_input = HookToolInput::Custom {
        input: "*** Begin Patch".to_string(),
    };
    let event = HookEventAfterToolUse {
        turn_id: format!("turn-{label}"),
        call_id: format!("call-{label}"),
        tool_name: "apply_patch".to_string(),
        tool_kind: HookToolKind::Custom,
        tool_input,
        executed: true,
        success: true,
        duration_ms: 1,
        mutating: true,
        sandbox: "none".to_string(),
        sandbox_policy: "danger-full-access".to_string(),
        output_preview: "ok".to_string(),
    };

    HookPayload {
        session_id,
        cwd,
        triggered_at,
        hook_event: HookEvent::AfterToolUse { event },
    }
}
// An empty argv and an argv whose program name is empty both yield no command.
#[test]
fn command_from_argv_returns_none_for_empty_args() {
    let no_args: [String; 0] = [];
    assert!(command_from_argv(&no_args).is_none());

    let blank_program = [String::new()];
    assert!(command_from_argv(&blank_program).is_none());
}
// Runs a platform-appropriate echo command built by `command_from_argv` and
// checks that its stdout, minus trailing line endings, is "hello world".
#[tokio::test]
async fn command_from_argv_builds_command() -> Result<()> {
    let parts: Vec<&str> = if cfg!(windows) {
        vec!["cmd", "/C", "echo hello world"]
    } else {
        vec!["echo", "hello", "world"]
    };
    let argv: Vec<String> = parts.into_iter().map(str::to_string).collect();

    let Some(mut command) = command_from_argv(&argv) else {
        return Err(anyhow::anyhow!("command"));
    };
    let output = command.stdout(Stdio::piped()).output().await?;

    let stdout = String::from_utf8_lossy(&output.stdout);
    assert_eq!(stdout.trim_end_matches(['\r', '\n']), "hello world");
    Ok(())
}
// `Hooks::new` registers the legacy notify hook only when the configured argv
// starts with a non-empty program name.
#[test]
fn hooks_new_requires_program_name() {
    let hooks_from = |argv: Vec<String>| {
        Hooks::new(HooksConfig {
            legacy_notify_argv: Some(argv),
        })
    };

    assert!(Hooks::new(HooksConfig::default()).after_agent.is_empty());
    assert!(hooks_from(Vec::new()).after_agent.is_empty());
    assert!(hooks_from(vec!["".to_string()]).after_agent.is_empty());
    assert_eq!(
        hooks_from(vec!["notify-send".to_string()]).after_agent.len(),
        1
    );
}
// A single registered after-agent hook runs exactly once and its response is
// returned by `dispatch`.
#[tokio::test]
async fn dispatch_executes_hook() {
    let invocations = Arc::new(AtomicUsize::new(0));
    let registry = Hooks {
        after_agent: vec![counting_success_hook(&invocations, "counting")],
        ..Hooks::default()
    };

    let responses = registry.dispatch(hook_payload("1")).await;

    assert_eq!(responses.len(), 1);
    let response = &responses[0];
    assert_eq!(response.hook_name, "counting");
    assert!(matches!(response.result, HookResult::Success));
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}
// The default hook is named "default" and reports success when executed.
#[tokio::test]
async fn default_hook_is_noop_and_continues() {
    let payload = hook_payload("d");
    let noop_hook = Hook::default();

    let response = noop_hook.execute(&payload).await;

    assert_eq!(response.hook_name, "default");
    assert!(matches!(response.result, HookResult::Success));
}
// Two hooks registered for the same event both run, and their responses come
// back in registration order.
#[tokio::test]
async fn dispatch_executes_multiple_hooks_for_same_event() {
    let invocations = Arc::new(AtomicUsize::new(0));
    let first_hook = counting_success_hook(&invocations, "counting-1");
    let second_hook = counting_success_hook(&invocations, "counting-2");
    let registry = Hooks {
        after_agent: vec![first_hook, second_hook],
        ..Hooks::default()
    };

    let responses = registry.dispatch(hook_payload("2")).await;

    assert_eq!(responses.len(), 2);
    for (response, expected_name) in responses.iter().zip(["counting-1", "counting-2"]) {
        assert_eq!(response.hook_name, expected_name);
        assert!(matches!(response.result, HookResult::Success));
    }
    assert_eq!(invocations.load(Ordering::SeqCst), 2);
}
// A hook that reports `FailedAbort` prevents the hooks registered after it
// from running; only the aborting hook's response is returned.
#[tokio::test]
async fn dispatch_stops_when_hook_requests_abort() {
    let invocations = Arc::new(AtomicUsize::new(0));
    let aborting_hook = failing_abort_hook(&invocations, "abort", "hook failed");
    let skipped_hook = counting_success_hook(&invocations, "counting");
    let registry = Hooks {
        after_agent: vec![aborting_hook, skipped_hook],
        ..Hooks::default()
    };

    let responses = registry.dispatch(hook_payload("3")).await;

    assert_eq!(responses.len(), 1);
    let response = &responses[0];
    assert_eq!(response.hook_name, "abort");
    assert!(matches!(response.result, HookResult::FailedAbort(_)));
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}
// An after-tool-use payload is routed to the hooks registered under
// `after_tool_use`.
#[tokio::test]
async fn dispatch_executes_after_tool_use_hooks() {
    let invocations = Arc::new(AtomicUsize::new(0));
    let registry = Hooks {
        after_tool_use: vec![counting_success_hook(&invocations, "counting")],
        ..Hooks::default()
    };

    let responses = registry.dispatch(after_tool_use_payload("p")).await;

    assert_eq!(responses.len(), 1);
    let response = &responses[0];
    assert_eq!(response.hook_name, "counting");
    assert!(matches!(response.result, HookResult::Success));
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}
// A hook that reports `FailedContinue` does not stop dispatch: the next hook
// still runs and both responses are returned in order.
#[tokio::test]
async fn dispatch_continues_after_continueable_failure() {
    let invocations = Arc::new(AtomicUsize::new(0));
    let failing_hook = failing_continue_hook(&invocations, "failing", "hook failed");
    let following_hook = counting_success_hook(&invocations, "counting");
    let registry = Hooks {
        after_agent: vec![failing_hook, following_hook],
        ..Hooks::default()
    };

    let responses = registry.dispatch(hook_payload("err")).await;

    assert_eq!(responses.len(), 2);
    let (failed, succeeded) = (&responses[0], &responses[1]);
    assert_eq!(failed.hook_name, "failing");
    assert!(matches!(failed.result, HookResult::FailedContinue(_)));
    assert_eq!(succeeded.hook_name, "counting");
    assert!(matches!(succeeded.result, HookResult::Success));
    assert_eq!(invocations.load(Ordering::SeqCst), 2);
}
// A failing after-tool-use hook still has its response returned by
// `dispatch`.
#[tokio::test]
async fn dispatch_returns_after_tool_use_failure_outcome() {
    let invocations = Arc::new(AtomicUsize::new(0));
    let failing_hook =
        failing_continue_hook(&invocations, "failing", "after_tool_use hook failed");
    let registry = Hooks {
        after_tool_use: vec![failing_hook],
        ..Hooks::default()
    };

    let responses = registry
        .dispatch(after_tool_use_payload("err-tool"))
        .await;

    assert_eq!(responses.len(), 1);
    let response = &responses[0];
    assert_eq!(response.hook_name, "failing");
    assert!(matches!(response.result, HookResult::FailedContinue(_)));
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}
// Dispatches a hook that shells out to `/bin/sh` to write the serialized
// payload into a temp file, then checks the file holds exactly that JSON.
#[cfg(not(windows))]
#[tokio::test]
async fn hook_executes_program_with_payload_argument_unix() -> Result<()> {
    let temp_dir = tempdir()?;
    let payload_path = temp_dir.path().join("payload.json");
    let payload_path_arg = payload_path.to_string_lossy().into_owned();

    let registry_hook = Hook {
        name: "write_payload".to_string(),
        func: Arc::new(move |payload: &HookPayload| {
            let payload_path_arg = payload_path_arg.clone();
            Box::pin(async move {
                let json = to_string(payload).expect("serialize hook payload");
                // The script sees the output path as `$1` and the JSON as `$2`;
                // the literal "sh" fills `$0`.
                let argv = [
                    "/bin/sh".to_string(),
                    "-c".to_string(),
                    "printf '%s' \"$2\" > \"$1\"".to_string(),
                    "sh".to_string(),
                    payload_path_arg,
                    json,
                ];
                let mut command = command_from_argv(&argv).expect("build command");
                command.status().await.expect("run hook command");
                HookResult::Success
            })
        }),
    };

    let payload = hook_payload("4");
    let expected = to_string(&payload)?;
    let registry = Hooks {
        after_agent: vec![registry_hook],
        ..Hooks::default()
    };

    let responses = registry.dispatch(payload).await;
    assert_eq!(responses.len(), 1);
    assert!(matches!(responses[0].result, HookResult::Success));

    // Poll until the file exists with content, giving up after two seconds.
    let contents = timeout(Duration::from_secs(2), async {
        loop {
            match fs::read_to_string(&payload_path) {
                Ok(written) if !written.is_empty() => break written,
                _ => tokio::time::sleep(Duration::from_millis(10)).await,
            }
        }
    })
    .await?;

    assert_eq!(contents, expected);
    Ok(())
}
// Dispatches a hook that runs a PowerShell script to write the serialized
// payload into a temp file, then checks the file holds exactly that JSON.
#[cfg(windows)]
#[tokio::test]
async fn hook_executes_program_with_payload_argument_windows() -> Result<()> {
    let temp_dir = tempdir()?;
    let payload_path = temp_dir.path().join("payload.json");
    let payload_path_arg = payload_path.to_string_lossy().into_owned();

    // The script receives the output path and the JSON as its two arguments.
    let script_path = temp_dir.path().join("write_payload.ps1");
    fs::write(&script_path, "[IO.File]::WriteAllText($args[0], $args[1])")?;
    let script_path_arg = script_path.to_string_lossy().into_owned();

    let registry_hook = Hook {
        name: "write_payload".to_string(),
        func: Arc::new(move |payload: &HookPayload| {
            let payload_path_arg = payload_path_arg.clone();
            let script_path_arg = script_path_arg.clone();
            Box::pin(async move {
                let json = to_string(payload).expect("serialize hook payload");
                let argv = [
                    "powershell.exe".to_string(),
                    "-NoLogo".to_string(),
                    "-NoProfile".to_string(),
                    "-ExecutionPolicy".to_string(),
                    "Bypass".to_string(),
                    "-File".to_string(),
                    script_path_arg,
                    payload_path_arg,
                    json,
                ];
                let mut command = command_from_argv(&argv).expect("build command");
                command.status().await.expect("run hook command");
                HookResult::Success
            })
        }),
    };

    let payload = hook_payload("4");
    let expected = to_string(&payload)?;
    let registry = Hooks {
        after_agent: vec![registry_hook],
        ..Hooks::default()
    };

    let responses = registry.dispatch(payload).await;
    assert_eq!(responses.len(), 1);
    assert!(matches!(responses[0].result, HookResult::Success));

    // Poll until the file exists with content, giving up after two seconds.
    let contents = timeout(Duration::from_secs(2), async {
        loop {
            match fs::read_to_string(&payload_path) {
                Ok(written) if !written.is_empty() => break written,
                _ => tokio::time::sleep(Duration::from_millis(10)).await,
            }
        }
    })
    .await?;

    assert_eq!(contents, expected);
    Ok(())
}
}