refactor: remove 28K lines of dead/legacy code

Removed:
- pkg/loop/ — superseded by Claude native tool use
- pkg/lifecycle/ — 14K lines, old PHP API polling client
- pkg/jobrunner/ — old CodeRabbit orchestration (rebuilt in verify.go)
- pkg/orchestrator/ — old AgentCI config (replaced by agents.yaml)
- pkg/workspace/ — empty stub
- pkg/plugin/ — empty stub
- cmd/agent/ — old fleet management CLI
- cmd/dispatch/ — old polling dispatcher
- cmd/workspace/ — unused CLI
- cmd/tasks/ — unused CLI
- cmd/taskgit/ — unused CLI

120 files deleted, 28,780 lines removed.
Remaining: 31 Go files, 6,666 lines — cmd/core-agent + pkg/agentic + pkg/brain + pkg/monitor.

All functionality preserved in the new MCP-native architecture.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-03-17 19:06:03 +00:00
parent 742ca0799f
commit 7248928545
120 changed files with 0 additions and 28780 deletions

View file

@ -1,437 +0,0 @@
package agent
import (
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strings"
	"time"

	agentic "forge.lthn.ai/core/agent/pkg/lifecycle"
	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/config"
	coreerr "forge.lthn.ai/core/go-log"
	"forge.lthn.ai/core/go-scm/agentci"
)
// init hooks the agent command group into the shared CLI registry at
// package load time.
func init() {
	cli.RegisterCommands(AddAgentCommands)
}
// Style aliases from shared package, so this file renders output consistently
// with the rest of the CLI.
var (
	successStyle = cli.SuccessStyle
	errorStyle   = cli.ErrorStyle
	dimStyle     = cli.DimStyle
	// taskPriorityMediumStyle renders "busy"/medium-priority values in amber.
	taskPriorityMediumStyle = cli.NewStyle().Foreground(cli.ColourAmber500)
)
// defaultWorkDir is the directory (under $HOME) holding agent work state,
// e.g. the fleet registry database used by agentFleetCmd.
const defaultWorkDir = "ai-work"
// AddAgentCommands registers the 'agent' subcommand group under 'ai'.
func AddAgentCommands(parent *cli.Command) {
	root := &cli.Command{
		Use:   "agent",
		Short: "Manage AgentCI dispatch targets",
	}
	// Attach every agent subcommand in one pass.
	for _, sub := range []*cli.Command{
		agentAddCmd(),
		agentListCmd(),
		agentStatusCmd(),
		agentLogsCmd(),
		agentSetupCmd(),
		agentRemoveCmd(),
		agentFleetCmd(),
	} {
		root.AddCommand(sub)
	}
	parent.AddCommand(root)
}
// loadConfig loads the shared core configuration used by all agent
// subcommands; it is a thin wrapper over config.New.
func loadConfig() (*config.Config, error) {
	return config.New()
}
// agentAddCmd builds the 'agent add' command. It pins the remote host key in
// ~/.ssh/known_hosts, verifies SSH connectivity with strict host key
// checking, then persists the agent entry in the shared config.
func agentAddCmd() *cli.Command {
	cmd := &cli.Command{
		Use:   "add <name> <user@host>",
		Short: "Add an agent to the config and verify SSH",
		Args:  cli.ExactArgs(2),
		RunE: func(cmd *cli.Command, args []string) error {
			name := args[0]
			host := args[1]
			forgejoUser, _ := cmd.Flags().GetString("forgejo-user")
			if forgejoUser == "" {
				forgejoUser = name // default the Forgejo user to the agent name
			}
			// An empty queue-dir is stored as-is and resolved later by the
			// orchestrator config.
			queueDir, _ := cmd.Flags().GetString("queue-dir")
			model, _ := cmd.Flags().GetString("model")
			dualRun, _ := cmd.Flags().GetBool("dual-run")

			// Scan the remote host key and pin it so the strict-host-key SSH
			// test below can succeed on a first connection.
			parts := strings.Split(host, "@")
			hostname := parts[len(parts)-1]
			fmt.Printf("Scanning host key for %s... ", hostname)
			keys, err := exec.Command("ssh-keyscan", "-H", hostname).Output()
			if err != nil {
				fmt.Println(errorStyle.Render("FAILED"))
				return coreerr.E("agent.add", "failed to scan host keys", err)
			}
			// Was silently ignored before: an empty home would have written
			// to a relative ".ssh/known_hosts".
			home, err := os.UserHomeDir()
			if err != nil {
				return coreerr.E("agent.add", "failed to resolve home directory", err)
			}
			knownHostsPath := filepath.Join(home, ".ssh", "known_hosts")
			f, err := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
			if err != nil {
				return coreerr.E("agent.add", "failed to open known_hosts", err)
			}
			_, writeErr := f.Write(keys)
			closeErr := f.Close()
			if writeErr == nil {
				// A failed close can also lose appended data; surface it.
				writeErr = closeErr
			}
			if writeErr != nil {
				return coreerr.E("agent.add", "failed to write known_hosts", writeErr)
			}
			fmt.Println(successStyle.Render("OK"))

			// Test SSH with strict host key checking.
			fmt.Printf("Testing SSH to %s... ", host)
			out, err := agentci.SecureSSHCommand(host, "echo ok").CombinedOutput()
			if err != nil {
				fmt.Println(errorStyle.Render("FAILED"))
				return coreerr.E("agent.add", "SSH failed: "+strings.TrimSpace(string(out)), nil)
			}
			fmt.Println(successStyle.Render("OK"))

			cfg, err := loadConfig()
			if err != nil {
				return err
			}
			ac := agentci.AgentConfig{
				Host:        host,
				QueueDir:    queueDir,
				ForgejoUser: forgejoUser,
				Model:       model,
				DualRun:     dualRun,
				Active:      true,
			}
			if err := agentci.SaveAgent(cfg, name, ac); err != nil {
				return err
			}
			fmt.Printf("Agent %s added (%s)\n", successStyle.Render(name), host)
			return nil
		},
	}
	cmd.Flags().String("forgejo-user", "", "Forgejo username (defaults to agent name)")
	cmd.Flags().String("queue-dir", "", "Remote queue directory (default: ~/.core/queue)")
	cmd.Flags().String("model", "sonnet", "Primary AI model")
	cmd.Flags().Bool("dual-run", false, "Enable Clotho dual-run verification")
	return cmd
}
// agentListCmd builds the 'agent list' command, rendering the configured
// agents as a table with a best-effort remote queue depth per agent.
func agentListCmd() *cli.Command {
	return &cli.Command{
		Use:   "list",
		Short: "List configured agents",
		RunE: func(cmd *cli.Command, args []string) error {
			cfg, err := loadConfig()
			if err != nil {
				return err
			}
			agents, err := agentci.ListAgents(cfg)
			if err != nil {
				return err
			}
			if len(agents) == 0 {
				fmt.Println(dimStyle.Render("No agents configured. Use 'core ai agent add' to add one."))
				return nil
			}
			// Sort names so the table renders in a deterministic order
			// (Go map iteration order is randomized).
			names := make([]string, 0, len(agents))
			for name := range agents {
				names = append(names, name)
			}
			sort.Strings(names)
			table := cli.NewTable("NAME", "HOST", "MODEL", "DUAL", "ACTIVE", "QUEUE")
			for _, name := range names {
				ac := agents[name]
				active := dimStyle.Render("no")
				if ac.Active {
					active = successStyle.Render("yes")
				}
				dual := dimStyle.Render("no")
				if ac.DualRun {
					dual = successStyle.Render("yes")
				}
				// Best-effort SSH check for queue depth; "-" when unreachable.
				// (The old `if n != "0"` branch was redundant — both arms
				// produced the same value.)
				queue := dimStyle.Render("-")
				checkCmd := agentci.SecureSSHCommand(ac.Host, fmt.Sprintf("ls %s/ticket-*.json 2>/dev/null | wc -l", ac.QueueDir))
				if out, err := checkCmd.Output(); err == nil {
					queue = strings.TrimSpace(string(out))
				}
				table.AddRow(name, ac.Host, ac.Model, dual, active, queue)
			}
			table.Render()
			return nil
		},
	}
}
// agentStatusCmd builds the 'agent status' command: it runs a small shell
// script on the agent host over SSH to report queue/active/done ticket counts
// and whether the runner lock is held by a live process.
func agentStatusCmd() *cli.Command {
	return &cli.Command{
		Use:   "status <name>",
		Short: "Check agent status via SSH",
		Args:  cli.ExactArgs(1),
		RunE: func(cmd *cli.Command, args []string) error {
			name := args[0]
			cfg, err := loadConfig()
			if err != nil {
				return err
			}
			agents, err := agentci.ListAgents(cfg)
			if err != nil {
				return err
			}
			ac, ok := agents[name]
			if !ok {
				return coreerr.E("agent.status", "agent not found: "+name, nil)
			}
			// NOTE(review): the script hard-codes ~/ai-work rather than the
			// agent's configured QueueDir — confirm this matches agent setup.
			script := `
echo "=== Queue ==="
ls ~/ai-work/queue/ticket-*.json 2>/dev/null | wc -l
echo "=== Active ==="
ls ~/ai-work/active/ticket-*.json 2>/dev/null || echo "none"
echo "=== Done ==="
ls ~/ai-work/done/ticket-*.json 2>/dev/null | wc -l
echo "=== Lock ==="
if [ -f ~/ai-work/.runner.lock ]; then
PID=$(cat ~/ai-work/.runner.lock)
if kill -0 "$PID" 2>/dev/null; then
echo "RUNNING (PID $PID)"
else
echo "STALE (PID $PID)"
fi
else
echo "IDLE"
fi
`
			// Stream the remote output directly to the user's terminal.
			sshCmd := agentci.SecureSSHCommand(ac.Host, script)
			sshCmd.Stdout = os.Stdout
			sshCmd.Stderr = os.Stderr
			return sshCmd.Run()
		},
	}
}
// agentLogsCmd builds the 'agent logs' command, which tails the remote
// runner log (~/ai-work/logs/runner.log) over SSH, optionally following it.
func agentLogsCmd() *cli.Command {
	cmd := &cli.Command{
		Use:   "logs <name>",
		Short: "Stream agent runner logs",
		Args:  cli.ExactArgs(1),
		RunE: func(cmd *cli.Command, args []string) error {
			name := args[0]
			follow, _ := cmd.Flags().GetBool("follow")
			lines, _ := cmd.Flags().GetInt("lines")
			cfg, err := loadConfig()
			if err != nil {
				return err
			}
			agents, err := agentci.ListAgents(cfg)
			if err != nil {
				return err
			}
			ac, ok := agents[name]
			if !ok {
				// Fixed op tag: this is the logs command, not status.
				return coreerr.E("agent.logs", "agent not found: "+name, nil)
			}
			remoteCmd := fmt.Sprintf("tail -n %d ~/ai-work/logs/runner.log", lines)
			if follow {
				remoteCmd = fmt.Sprintf("tail -f -n %d ~/ai-work/logs/runner.log", lines)
			}
			// Wire all three streams so Ctrl-C and terminal interaction work
			// naturally while following.
			sshCmd := agentci.SecureSSHCommand(ac.Host, remoteCmd)
			sshCmd.Stdout = os.Stdout
			sshCmd.Stderr = os.Stderr
			sshCmd.Stdin = os.Stdin
			return sshCmd.Run()
		},
	}
	cmd.Flags().BoolP("follow", "f", false, "Follow log output")
	cmd.Flags().IntP("lines", "n", 50, "Number of lines to show")
	return cmd
}
// agentSetupCmd builds the 'agent setup' command, which bootstraps a
// previously-added agent machine by running scripts/agent-setup.sh against it.
func agentSetupCmd() *cli.Command {
	return &cli.Command{
		Use:   "setup <name>",
		Short: "Bootstrap agent machine (create dirs, copy runner, install cron)",
		Args:  cli.ExactArgs(1),
		RunE: func(cmd *cli.Command, args []string) error {
			target := args[0]
			cfg, err := loadConfig()
			if err != nil {
				return err
			}
			known, err := agentci.ListAgents(cfg)
			if err != nil {
				return err
			}
			entry, found := known[target]
			if !found {
				return coreerr.E("agent.setup", "agent not found: "+target+" — use 'core ai agent add' first", nil)
			}
			// The setup script ships next to the binary or in the repo tree.
			script := findSetupScript()
			if script == "" {
				return coreerr.E("agent.setup", "agent-setup.sh not found — expected in scripts/ directory", nil)
			}
			fmt.Printf("Setting up %s on %s...\n", target, entry.Host)
			run := exec.Command("bash", script, entry.Host)
			run.Stdout = os.Stdout
			run.Stderr = os.Stderr
			if err := run.Run(); err != nil {
				return coreerr.E("agent.setup", "setup failed", err)
			}
			fmt.Println(successStyle.Render("Setup complete!"))
			return nil
		},
	}
}
// agentRemoveCmd builds the 'agent remove' command, deleting an agent entry
// from the shared configuration.
func agentRemoveCmd() *cli.Command {
	run := func(cmd *cli.Command, args []string) error {
		cfg, err := loadConfig()
		if err != nil {
			return err
		}
		name := args[0]
		if err := agentci.RemoveAgent(cfg, name); err != nil {
			return err
		}
		fmt.Printf("Agent %s removed.\n", name)
		return nil
	}
	return &cli.Command{
		Use:   "remove <name>",
		Short: "Remove an agent from config",
		Args:  cli.ExactArgs(1),
		RunE:  run,
	}
}
// agentFleetCmd builds the 'agent fleet' command. It opens the SQLite agent
// registry under the work dir, reaps agents with no heartbeat for 10 minutes,
// and renders the remaining fleet as a table.
func agentFleetCmd() *cli.Command {
	cmd := &cli.Command{
		Use:   "fleet",
		Short: "Show fleet status from the go-agentic registry",
		RunE: func(cmd *cli.Command, args []string) error {
			workDir, _ := cmd.Flags().GetString("work-dir")
			if workDir == "" {
				home, _ := os.UserHomeDir()
				workDir = filepath.Join(home, defaultWorkDir)
			}
			// The registry DB is created by the dispatch watcher; without it
			// there is nothing to show.
			dbPath := filepath.Join(workDir, "registry.db")
			if _, err := os.Stat(dbPath); os.IsNotExist(err) {
				fmt.Println(dimStyle.Render("No registry found. Start a dispatch watcher first: core ai dispatch watch"))
				return nil
			}
			registry, err := agentic.NewSQLiteRegistry(dbPath)
			if err != nil {
				return coreerr.E("agent.fleet", "failed to open registry", err)
			}
			defer registry.Close()
			// Reap stale agents (no heartbeat for 10 minutes).
			reaped := registry.Reap(10 * time.Minute)
			if len(reaped) > 0 {
				for _, id := range reaped {
					fmt.Printf(" Reaped stale agent: %s\n", dimStyle.Render(id))
				}
				fmt.Println()
			}
			agents := registry.List()
			if len(agents) == 0 {
				fmt.Println(dimStyle.Render("No agents registered."))
				return nil
			}
			table := cli.NewTable("ID", "STATUS", "LOAD", "LAST HEARTBEAT", "CAPABILITIES")
			for _, a := range agents {
				// Colour-code the status; unknown statuses render dimmed.
				status := dimStyle.Render(string(a.Status))
				switch a.Status {
				case agentic.AgentAvailable:
					status = successStyle.Render("available")
				case agentic.AgentBusy:
					status = taskPriorityMediumStyle.Render("busy")
				case agentic.AgentOffline:
					status = errorStyle.Render("offline")
				}
				load := fmt.Sprintf("%d/%d", a.CurrentLoad, a.MaxLoad)
				// Show both the wall-clock heartbeat time and its age.
				hb := a.LastHeartbeat.Format("15:04:05")
				ago := time.Since(a.LastHeartbeat).Truncate(time.Second)
				hbStr := fmt.Sprintf("%s (%s ago)", hb, ago)
				caps := "-"
				if len(a.Capabilities) > 0 {
					caps = strings.Join(a.Capabilities, ", ")
				}
				table.AddRow(a.ID, status, load, hbStr, caps)
			}
			table.Render()
			return nil
		},
	}
	cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
	return cmd
}
// findSetupScript looks for agent-setup.sh in common locations: scripts/
// beside the executable, scripts/ one level above it, and scripts/ under the
// current working directory. Returns "" when the script cannot be found.
func findSetupScript() string {
	var roots []string
	if exe, _ := os.Executable(); exe != "" {
		exeDir := filepath.Dir(exe)
		roots = append(roots, exeDir, filepath.Join(exeDir, ".."))
	}
	if cwd, _ := os.Getwd(); cwd != "" {
		roots = append(roots, cwd)
	}
	for _, root := range roots {
		candidate := filepath.Join(root, "scripts", "agent-setup.sh")
		if _, err := os.Stat(candidate); err == nil {
			return candidate
		}
	}
	return ""
}

View file

@ -1,877 +0,0 @@
package dispatch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
"forge.lthn.ai/core/cli/pkg/cli"
coreio "forge.lthn.ai/core/go-io"
"forge.lthn.ai/core/go-log"
agentic "forge.lthn.ai/core/agent/pkg/lifecycle"
)
// init hooks the dispatch command group into the shared CLI registry at
// package load time.
func init() {
	cli.RegisterCommands(AddDispatchCommands)
}
// AddDispatchCommands registers the 'dispatch' subcommand group under 'ai'.
// These commands run ON the agent machine to process the work queue.
func AddDispatchCommands(parent *cli.Command) {
	group := &cli.Command{
		Use:   "dispatch",
		Short: "Agent work queue processor (runs on agent machine)",
	}
	for _, sub := range []*cli.Command{dispatchRunCmd(), dispatchWatchCmd(), dispatchStatusCmd()} {
		group.AddCommand(sub)
	}
	parent.AddCommand(group)
}
// dispatchTicket represents the work item JSON structure read from the
// queue directory and synthesized from agentic plan metadata.
type dispatchTicket struct {
	ID           string `json:"id"`            // unique ticket identifier
	RepoOwner    string `json:"repo_owner"`    // Forge repository owner
	RepoName     string `json:"repo_name"`     // Forge repository name
	IssueNumber  int    `json:"issue_number"`  // issue the agent works on
	IssueTitle   string `json:"issue_title"`
	IssueBody    string `json:"issue_body"`
	TargetBranch string `json:"target_branch"` // base branch for the work
	EpicNumber   int    `json:"epic_number"`
	ForgeURL     string `json:"forge_url"`     // Forge base URL
	ForgeToken   string `json:"forge_token"`   // clone/API token — sensitive
	ForgeUser    string `json:"forgejo_user"`
	Model        string `json:"model"`         // AI model; runAgent defaults to "sonnet"
	Runner       string `json:"runner"`        // CLI runner: claude (default), codex, gemini
	Timeout      string `json:"timeout"`       // time.ParseDuration string; runAgent defaults to 30m
	CreatedAt    string `json:"created_at"`
}
const (
	// defaultWorkDir is the directory under $HOME used for runner state.
	defaultWorkDir = "ai-work"
	// lockFileName is the PID lock file guarding single-runner execution.
	lockFileName = ".runner.lock"
)
// runnerPaths holds the directory layout of the runner work tree.
type runnerPaths struct {
	root   string // base work directory (e.g. ~/ai-work)
	queue  string // incoming tickets
	active string // ticket currently being processed
	done   string // completed tickets
	logs   string // per-job agent output logs
	jobs   string // per-job repository checkouts
	lock   string // runner PID lock file
}
// getPaths derives the runner directory layout rooted at baseDir, defaulting
// to ~/ai-work when baseDir is empty.
func getPaths(baseDir string) runnerPaths {
	root := baseDir
	if root == "" {
		home, _ := os.UserHomeDir()
		root = filepath.Join(home, defaultWorkDir)
	}
	sub := func(name string) string { return filepath.Join(root, name) }
	return runnerPaths{
		root:   root,
		queue:  sub("queue"),
		active: sub("active"),
		done:   sub("done"),
		logs:   sub("logs"),
		jobs:   sub("jobs"),
		lock:   sub(lockFileName),
	}
}
// dispatchRunCmd builds the 'dispatch run' command: it takes the runner lock,
// pops the oldest queued ticket (if any), and processes it exactly once.
func dispatchRunCmd() *cli.Command {
	runOnce := func(cmd *cli.Command, _ []string) error {
		dir, _ := cmd.Flags().GetString("work-dir")
		p := getPaths(dir)
		if err := ensureDispatchDirs(p); err != nil {
			return err
		}
		if err := acquireLock(p.lock); err != nil {
			// Another runner holds the lock; not an error for a one-shot run.
			log.Info("Runner locked, skipping run", "lock", p.lock)
			return nil
		}
		defer releaseLock(p.lock)
		ticket, err := pickOldestTicket(p.queue)
		if err != nil || ticket == "" {
			// Propagate the pick error; an empty queue returns nil.
			return err
		}
		_, err = processTicket(p, ticket)
		return err
	}
	c := &cli.Command{
		Use:   "run",
		Short: "Process a single ticket from the queue",
		RunE:  runOnce,
	}
	c.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
	return c
}
// fastFailThreshold is how quickly a job must fail to be considered rate-limited.
// Real work always takes longer than 30 seconds; a 3-second exit means the CLI
// was rejected before it could start (rate limit, auth error, etc.).
const fastFailThreshold = 30 * time.Second

// maxBackoffMultiplier caps the exponential poll backoff at 8x the base
// interval (see dispatchWatchCmd).
const maxBackoffMultiplier = 8
// dispatchWatchCmd builds the 'dispatch watch' command: a long-running poller
// that asks the agentic API for workable plans and executes one phase per
// cycle, backing off exponentially after fast failures (likely rate limits).
// It shuts down cleanly on SIGINT/SIGTERM.
func dispatchWatchCmd() *cli.Command {
	cmd := &cli.Command{
		Use:   "watch",
		Short: "Poll the PHP agentic API for work",
		RunE: func(cmd *cli.Command, args []string) error {
			workDir, _ := cmd.Flags().GetString("work-dir")
			interval, _ := cmd.Flags().GetDuration("interval")
			agentID, _ := cmd.Flags().GetString("agent-id")
			agentType, _ := cmd.Flags().GetString("agent-type")
			apiURL, _ := cmd.Flags().GetString("api-url")
			apiKey, _ := cmd.Flags().GetString("api-key")
			paths := getPaths(workDir)
			if err := ensureDispatchDirs(paths); err != nil {
				return err
			}
			// Create the go-agentic API client.
			client := agentic.NewClient(apiURL, apiKey)
			client.AgentID = agentID
			// Verify connectivity before entering the poll loop.
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()
			if err := client.Ping(ctx); err != nil {
				return log.E("dispatch.watch", "API ping failed (url="+apiURL+")", err)
			}
			log.Info("Connected to agentic API", "url", apiURL, "agent", agentID)
			sigChan := make(chan os.Signal, 1)
			signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
			// Backoff state: the multiplier doubles on each fast failure (up
			// to maxBackoffMultiplier) and resets to 1 on any normal cycle.
			backoffMultiplier := 1
			currentInterval := interval
			ticker := time.NewTicker(currentInterval)
			defer ticker.Stop()
			adjustTicker := func(fastFail bool) {
				if fastFail {
					if backoffMultiplier < maxBackoffMultiplier {
						backoffMultiplier *= 2
					}
					currentInterval = interval * time.Duration(backoffMultiplier)
					log.Warn("Fast failure detected, backing off",
						"multiplier", backoffMultiplier, "next_poll", currentInterval)
				} else {
					if backoffMultiplier > 1 {
						log.Info("Job succeeded, resetting backoff")
					}
					backoffMultiplier = 1
					currentInterval = interval
				}
				ticker.Reset(currentInterval)
			}
			log.Info("Starting API poller", "interval", interval, "agent", agentID, "type", agentType)
			// Initial poll so the first unit of work isn't delayed by a full
			// ticker interval.
			ff := pollAndExecute(ctx, client, agentID, agentType, paths)
			adjustTicker(ff)
			for {
				select {
				case <-ticker.C:
					ff := pollAndExecute(ctx, client, agentID, agentType, paths)
					adjustTicker(ff)
				case <-sigChan:
					log.Info("Shutting down watcher...")
					return nil
				case <-ctx.Done():
					return nil
				}
			}
		},
	}
	cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
	cmd.Flags().Duration("interval", 2*time.Minute, "Polling interval")
	cmd.Flags().String("agent-id", defaultAgentID(), "Agent identifier")
	cmd.Flags().String("agent-type", "opus", "Agent type (opus, sonnet, gemini)")
	cmd.Flags().String("api-url", "https://api.lthn.sh", "Agentic API base URL")
	cmd.Flags().String("api-key", os.Getenv("AGENTIC_API_KEY"), "Agentic API key")
	return cmd
}
// pollAndExecute checks the API for workable plans and executes one phase per cycle.
// Returns true if a fast failure occurred (signals backoff).
//
// A phase is "workable" when it is already in progress, or pending with
// CanStart set. At most one phase is executed per call; a failure to fetch
// one plan falls through to the next plan.
func pollAndExecute(ctx context.Context, client *agentic.Client, agentID, agentType string, paths runnerPaths) bool {
	// List active plans.
	plans, err := client.ListPlans(ctx, agentic.ListPlanOptions{Status: agentic.PlanActive})
	if err != nil {
		log.Error("Failed to list plans", "error", err)
		return false
	}
	if len(plans) == 0 {
		log.Debug("No active plans")
		return false
	}
	// Find the first workable phase across all plans.
	for _, plan := range plans {
		// Fetch full plan with phases.
		fullPlan, err := client.GetPlan(ctx, plan.Slug)
		if err != nil {
			log.Error("Failed to get plan", "slug", plan.Slug, "error", err)
			continue
		}
		// Find first workable phase (in-progress, or pending + CanStart).
		var targetPhase *agentic.Phase
		for i := range fullPlan.Phases {
			p := &fullPlan.Phases[i]
			switch p.Status {
			case agentic.PhaseInProgress:
				targetPhase = p
			case agentic.PhasePending:
				if p.CanStart {
					targetPhase = p
				}
			}
			if targetPhase != nil {
				break
			}
		}
		if targetPhase == nil {
			continue
		}
		log.Info("Found workable phase",
			"plan", fullPlan.Slug, "phase", targetPhase.Name, "status", targetPhase.Status)
		// Start a session so the API can attribute the work to this agent.
		session, err := client.StartSession(ctx, agentic.StartSessionRequest{
			AgentType: agentType,
			PlanSlug:  fullPlan.Slug,
			Context: map[string]any{
				"agent_id": agentID,
				"phase":    targetPhase.Name,
			},
		})
		if err != nil {
			log.Error("Failed to start session", "error", err)
			return false
		}
		log.Info("Session started", "session_id", session.SessionID)
		// Mark phase in-progress if pending; a failure here is non-fatal
		// since the work itself can still proceed.
		if targetPhase.Status == agentic.PhasePending {
			if err := client.UpdatePhaseStatus(ctx, fullPlan.Slug, targetPhase.Name, agentic.PhaseInProgress, ""); err != nil {
				log.Warn("Failed to mark phase in-progress", "error", err)
			}
		}
		// Extract repo info from plan metadata and do the actual work.
		fastFail := executePhaseWork(ctx, client, fullPlan, targetPhase, session.SessionID, paths)
		return fastFail
	}
	log.Debug("No workable phases found across active plans")
	return false
}
// executePhaseWork does the actual repo prep + agent run for a phase.
// Returns true if the execution was a fast failure.
//
// Repo coordinates and runner settings are pulled from the plan's Metadata
// map; the session is always ended (completed or failed) and the phase status
// is updated to completed/blocked based on the outcome.
func executePhaseWork(ctx context.Context, client *agentic.Client, plan *agentic.Plan, phase *agentic.Phase, sessionID string, paths runnerPaths) bool {
	// Extract repo metadata from the plan. Missing/mistyped keys yield zero
	// values; only repo_owner/repo_name are treated as mandatory below.
	meta, _ := plan.Metadata.(map[string]any)
	repoOwner, _ := meta["repo_owner"].(string)
	repoName, _ := meta["repo_name"].(string)
	issueNumFloat, _ := meta["issue_number"].(float64) // JSON numbers are float64
	issueNumber := int(issueNumFloat)
	forgeURL, _ := meta["forge_url"].(string)
	forgeToken, _ := meta["forge_token"].(string)
	forgeUser, _ := meta["forgejo_user"].(string)
	targetBranch, _ := meta["target_branch"].(string)
	runner, _ := meta["runner"].(string)
	model, _ := meta["model"].(string)
	timeout, _ := meta["timeout"].(string)
	if targetBranch == "" {
		targetBranch = "main"
	}
	if runner == "" {
		runner = "claude"
	}
	// Build a dispatchTicket from the metadata so existing functions work.
	t := dispatchTicket{
		ID:           fmt.Sprintf("%s-%s", plan.Slug, phase.Name),
		RepoOwner:    repoOwner,
		RepoName:     repoName,
		IssueNumber:  issueNumber,
		IssueTitle:   plan.Title,
		IssueBody:    phase.Description,
		TargetBranch: targetBranch,
		ForgeURL:     forgeURL,
		ForgeToken:   forgeToken,
		ForgeUser:    forgeUser,
		Model:        model,
		Runner:       runner,
		Timeout:      timeout,
	}
	if t.RepoOwner == "" || t.RepoName == "" {
		log.Error("Plan metadata missing repo_owner or repo_name", "plan", plan.Slug)
		_ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), "missing repo metadata")
		return false
	}
	// Prepare the repository checkout under the jobs directory.
	jobDir := filepath.Join(paths.jobs, fmt.Sprintf("%s-%s-%d", t.RepoOwner, t.RepoName, t.IssueNumber))
	repoDir := filepath.Join(jobDir, t.RepoName)
	if err := coreio.Local.EnsureDir(jobDir); err != nil {
		log.Error("Failed to create job dir", "error", err)
		_ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), fmt.Sprintf("mkdir failed: %v", err))
		return false
	}
	if err := prepareRepo(t, repoDir); err != nil {
		log.Error("Repo preparation failed", "error", err)
		_ = client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, agentic.PhaseBlocked, fmt.Sprintf("git setup failed: %v", err))
		_ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), fmt.Sprintf("repo prep failed: %v", err))
		return false
	}
	// Build prompt and run the AI agent, logging to a per-phase file.
	prompt := buildPrompt(t)
	logFile := filepath.Join(paths.logs, fmt.Sprintf("%s-%s.log", plan.Slug, phase.Name))
	start := time.Now()
	success, exitCode, runErr := runAgent(t, prompt, repoDir, logFile)
	elapsed := time.Since(start)
	// Detect fast failure (sub-threshold exit ⇒ likely rate-limited).
	if !success && elapsed < fastFailThreshold {
		log.Warn("Agent rejected fast, likely rate-limited",
			"elapsed", elapsed.Round(time.Second), "plan", plan.Slug, "phase", phase.Name)
		_ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), "fast failure — likely rate-limited")
		return true
	}
	// Report results back to the agentic API (best-effort; errors ignored).
	if success {
		_ = client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, agentic.PhaseCompleted,
			fmt.Sprintf("completed in %s", elapsed.Round(time.Second)))
		_ = client.EndSession(ctx, sessionID, string(agentic.SessionCompleted),
			fmt.Sprintf("Phase %q completed successfully (exit %d, %s)", phase.Name, exitCode, elapsed.Round(time.Second)))
	} else {
		note := fmt.Sprintf("failed with exit code %d after %s", exitCode, elapsed.Round(time.Second))
		if runErr != nil {
			note += fmt.Sprintf(": %v", runErr)
		}
		_ = client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, agentic.PhaseBlocked, note)
		_ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), note)
	}
	// Also report to the Forge issue if configured.
	msg := fmt.Sprintf("Agent completed phase %q of plan %q. Exit code: %d.", phase.Name, plan.Slug, exitCode)
	if !success {
		msg = fmt.Sprintf("Agent failed phase %q of plan %q (exit code: %d).", phase.Name, plan.Slug, exitCode)
	}
	reportToForge(t, success, msg)
	log.Info("Phase complete", "plan", plan.Slug, "phase", phase.Name, "success", success, "elapsed", elapsed.Round(time.Second))
	return false
}
// defaultAgentID returns a sensible agent ID derived from the hostname,
// falling back to "unknown" when the hostname is unavailable.
func defaultAgentID() string {
	if host, _ := os.Hostname(); host != "" {
		return host
	}
	return "unknown"
}
// --- Legacy registry/heartbeat functions (replaced by PHP API poller) ---
// registerAgent creates a SQLite registry and registers this agent.
// DEPRECATED: The watch command now uses the PHP agentic API instead.
// Kept for reference; remove once the API poller is proven stable.
/*
func registerAgent(agentID string, paths runnerPaths) (agentic.AgentRegistry, agentic.EventEmitter, func()) {
dbPath := filepath.Join(paths.root, "registry.db")
registry, err := agentic.NewSQLiteRegistry(dbPath)
if err != nil {
log.Warn("Failed to create agent registry", "error", err, "path", dbPath)
return nil, nil, nil
}
info := agentic.AgentInfo{
ID: agentID,
Name: agentID,
Status: agentic.AgentAvailable,
LastHeartbeat: time.Now().UTC(),
MaxLoad: 1,
}
if err := registry.Register(info); err != nil {
log.Warn("Failed to register agent", "error", err)
} else {
log.Info("Agent registered", "id", agentID)
}
events := agentic.NewChannelEmitter(64)
// Drain events to log.
go func() {
for ev := range events.Events() {
log.Debug("Event", "type", string(ev.Type), "task", ev.TaskID, "agent", ev.AgentID)
}
}()
return registry, events, func() {
events.Close()
}
}
*/
// heartbeatLoop sends periodic heartbeats to keep the agent status fresh.
// DEPRECATED: Replaced by PHP API poller.
/*
func heartbeatLoop(ctx context.Context, registry agentic.AgentRegistry, agentID string, interval time.Duration) {
if interval < 30*time.Second {
interval = 30 * time.Second
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
_ = registry.Heartbeat(agentID)
}
}
}
*/
// runCycleWithEvents wraps runCycle with registry status updates and event emission.
// DEPRECATED: Replaced by pollAndExecute.
/*
func runCycleWithEvents(paths runnerPaths, registry agentic.AgentRegistry, events agentic.EventEmitter, agentID string) bool {
if registry != nil {
if agent, err := registry.Get(agentID); err == nil {
agent.Status = agentic.AgentBusy
_ = registry.Register(agent)
}
}
fastFail := runCycle(paths)
if registry != nil {
if agent, err := registry.Get(agentID); err == nil {
agent.Status = agentic.AgentAvailable
agent.LastHeartbeat = time.Now().UTC()
_ = registry.Register(agent)
}
}
return fastFail
}
*/
// dispatchStatusCmd builds the 'dispatch status' command, printing the lock
// state and per-directory ticket counts for the local work tree.
func dispatchStatusCmd() *cli.Command {
	// countTickets counts non-directory entries named "ticket-*" in dir.
	countTickets := func(dir string) int {
		entries, _ := os.ReadDir(dir)
		total := 0
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			if strings.HasPrefix(entry.Name(), "ticket-") {
				total++
			}
		}
		return total
	}
	c := &cli.Command{
		Use:   "status",
		Short: "Show runner status",
		RunE: func(cmd *cli.Command, args []string) error {
			workDir, _ := cmd.Flags().GetString("work-dir")
			paths := getPaths(workDir)
			// Decode the lock file: a live PID means a runner is active,
			// a dead PID means the lock is stale.
			lockStatus := "IDLE"
			if data, err := coreio.Local.Read(paths.lock); err == nil {
				pid, _ := strconv.Atoi(strings.TrimSpace(data))
				if isProcessAlive(pid) {
					lockStatus = fmt.Sprintf("RUNNING (PID %d)", pid)
				} else {
					lockStatus = fmt.Sprintf("STALE (PID %d)", pid)
				}
			}
			fmt.Println("=== Agent Dispatch Status ===")
			fmt.Printf("Work Dir: %s\n", paths.root)
			fmt.Printf("Status: %s\n", lockStatus)
			fmt.Printf("Queue: %d\n", countTickets(paths.queue))
			fmt.Printf("Active: %d\n", countTickets(paths.active))
			fmt.Printf("Done: %d\n", countTickets(paths.done))
			return nil
		},
	}
	c.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
	return c
}
// runCycle picks and processes one ticket. Returns true if the job fast-failed
// (likely rate-limited), signalling the caller to back off.
//
// A held lock, a pick error, or an empty queue all return false (no backoff).
func runCycle(paths runnerPaths) bool {
	if err := acquireLock(paths.lock); err != nil {
		log.Debug("Runner locked, skipping cycle")
		return false
	}
	defer releaseLock(paths.lock)
	ticketFile, err := pickOldestTicket(paths.queue)
	if err != nil {
		log.Error("Failed to pick ticket", "error", err)
		return false
	}
	if ticketFile == "" {
		return false // empty queue, no backoff needed
	}
	start := time.Now()
	success, err := processTicket(paths, ticketFile)
	elapsed := time.Since(start)
	if err != nil {
		// Logged but not returned: the backoff decision below is what matters.
		log.Error("Failed to process ticket", "file", ticketFile, "error", err)
	}
	// Detect fast failure: job failed in under 30s → likely rate-limited.
	if !success && elapsed < fastFailThreshold {
		log.Warn("Job finished too fast, likely rate-limited",
			"elapsed", elapsed.Round(time.Second), "file", filepath.Base(ticketFile))
		return true
	}
	return false
}
// processTicket processes a single ticket. Returns (success, error).
// On fast failure the caller is responsible for detecting the timing and backing off.
// The ticket is moved active→done on completion, or active→queue on fast failure.
//
// Lifecycle: queue/<file> → active/<file> while running, then done/ (normal
// completion or repo-prep failure) or back to queue/ (fast failure, retried
// after backoff).
func processTicket(paths runnerPaths, ticketPath string) (bool, error) {
	fileName := filepath.Base(ticketPath)
	log.Info("Processing ticket", "file", fileName)
	// Claim the ticket by moving it into active/ before reading it.
	activePath := filepath.Join(paths.active, fileName)
	if err := os.Rename(ticketPath, activePath); err != nil {
		return false, log.E("processTicket", "failed to move ticket to active", err)
	}
	data, err := coreio.Local.Read(activePath)
	if err != nil {
		return false, log.E("processTicket", "failed to read ticket", err)
	}
	var t dispatchTicket
	if err := json.Unmarshal([]byte(data), &t); err != nil {
		return false, log.E("processTicket", "failed to unmarshal ticket", err)
	}
	jobDir := filepath.Join(paths.jobs, fmt.Sprintf("%s-%s-%d", t.RepoOwner, t.RepoName, t.IssueNumber))
	repoDir := filepath.Join(jobDir, t.RepoName)
	if err := coreio.Local.EnsureDir(jobDir); err != nil {
		return false, err
	}
	if err := prepareRepo(t, repoDir); err != nil {
		// Repo prep failure is terminal for this ticket: report and archive.
		reportToForge(t, false, fmt.Sprintf("Git setup failed: %v", err))
		moveToDone(paths, activePath, fileName)
		return false, err
	}
	prompt := buildPrompt(t)
	logFile := filepath.Join(paths.logs, fmt.Sprintf("%s-%s-%d.log", t.RepoOwner, t.RepoName, t.IssueNumber))
	start := time.Now()
	success, exitCode, runErr := runAgent(t, prompt, repoDir, logFile)
	elapsed := time.Since(start)
	// Fast failure: agent exited in <30s without success → likely rate-limited.
	// Requeue the ticket so it's retried after the backoff period.
	if !success && elapsed < fastFailThreshold {
		log.Warn("Agent rejected fast, requeuing ticket", "elapsed", elapsed.Round(time.Second), "file", fileName)
		requeuePath := filepath.Join(paths.queue, fileName)
		if err := os.Rename(activePath, requeuePath); err != nil {
			// Fallback: move to done if requeue fails.
			moveToDone(paths, activePath, fileName)
		}
		return false, runErr
	}
	msg := fmt.Sprintf("Agent completed work on #%d. Exit code: %d.", t.IssueNumber, exitCode)
	if !success {
		msg = fmt.Sprintf("Agent failed on #%d (exit code: %d). Check logs on agent machine.", t.IssueNumber, exitCode)
		if runErr != nil {
			msg += fmt.Sprintf(" Error: %v", runErr)
		}
	}
	reportToForge(t, success, msg)
	moveToDone(paths, activePath, fileName)
	log.Info("Ticket complete", "id", t.ID, "success", success, "elapsed", elapsed.Round(time.Second))
	return success, nil
}
// prepareRepo ensures repoDir contains a checkout of the ticket's repository
// on the target branch: it updates an existing clone (fetch/checkout/pull,
// creating the local branch from origin if checkout fails) or clones fresh.
//
// NOTE(review): the token is embedded in the clone URL and will persist in
// .git/config of the checkout — consider a git credential helper instead.
func prepareRepo(t dispatchTicket, repoDir string) error {
	user := t.ForgeUser
	if user == "" {
		// Fall back to a host-user derived identity for the clone URL.
		host, _ := os.Hostname()
		user = fmt.Sprintf("%s-%s", host, os.Getenv("USER"))
	}
	cleanURL := strings.TrimPrefix(t.ForgeURL, "https://")
	cleanURL = strings.TrimPrefix(cleanURL, "http://")
	cloneURL := fmt.Sprintf("https://%s:%s@%s/%s/%s.git", user, t.ForgeToken, cleanURL, t.RepoOwner, t.RepoName)
	if _, err := os.Stat(filepath.Join(repoDir, ".git")); err == nil {
		log.Info("Updating existing repo", "dir", repoDir)
		cmds := [][]string{
			{"git", "fetch", "origin"},
			{"git", "checkout", t.TargetBranch},
			{"git", "pull", "origin", t.TargetBranch},
		}
		for _, args := range cmds {
			cmd := exec.Command(args[0], args[1:]...)
			cmd.Dir = repoDir
			if out, err := cmd.CombinedOutput(); err != nil {
				// A failed checkout may mean no local branch yet: create it
				// tracking the remote branch and continue.
				if args[1] == "checkout" {
					createCmd := exec.Command("git", "checkout", "-b", t.TargetBranch, "origin/"+t.TargetBranch)
					createCmd.Dir = repoDir
					if _, err2 := createCmd.CombinedOutput(); err2 == nil {
						continue
					}
				}
				return log.E("prepareRepo", "git command failed: "+string(out), err)
			}
		}
	} else {
		log.Info("Cloning repo", "url", t.RepoOwner+"/"+t.RepoName)
		cmd := exec.Command("git", "clone", "-b", t.TargetBranch, cloneURL, repoDir)
		if out, err := cmd.CombinedOutput(); err != nil {
			return log.E("prepareRepo", "git clone failed: "+string(out), err)
		}
	}
	return nil
}
// buildPrompt renders the work instruction handed to the AI runner for a
// ticket: issue coordinates, description, and branch/PR expectations.
func buildPrompt(t dispatchTicket) string {
	const tmpl = `You are working on issue #%d in %s/%s.
Title: %s
Description:
%s
The repo is cloned at the current directory on branch '%s'.
Create a feature branch from '%s', make minimal targeted changes, commit referencing #%d, and push.
Then create a PR targeting '%s' using the forgejo MCP tools or git push.`
	return fmt.Sprintf(tmpl,
		t.IssueNumber, t.RepoOwner, t.RepoName,
		t.IssueTitle,
		t.IssueBody,
		t.TargetBranch,
		t.TargetBranch, t.IssueNumber,
		t.TargetBranch,
	)
}
// runAgent executes the ticket's configured runner under a timeout
// (default 30m, overridable via t.Timeout) with a defaulted model
// ("sonnet"), routing Gemini runs through the rate limiter.
func runAgent(t dispatchTicket, prompt, dir, logPath string) (bool, int, error) {
	limit := 30 * time.Minute
	if t.Timeout != "" {
		if parsed, err := time.ParseDuration(t.Timeout); err == nil {
			limit = parsed
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), limit)
	defer cancel()

	model := t.Model
	if model == "" {
		model = "sonnet"
	}
	log.Info("Running agent", "runner", t.Runner, "model", model)

	run := func() (bool, int, error) {
		return execAgent(ctx, t.Runner, model, prompt, dir, logPath)
	}
	if t.Runner != "gemini" {
		return run()
	}
	// Gemini calls are metered: wrap the execution with rate limiting.
	return executeWithRateLimit(ctx, model, prompt, run)
}
func execAgent(ctx context.Context, runner, model, prompt, dir, logPath string) (bool, int, error) {
var cmd *exec.Cmd
switch runner {
case "codex":
cmd = exec.CommandContext(ctx, "codex", "exec", "--full-auto", prompt)
case "gemini":
args := []string{"-p", "-", "-y", "-m", model}
cmd = exec.CommandContext(ctx, "gemini", args...)
cmd.Stdin = strings.NewReader(prompt)
default: // claude
cmd = exec.CommandContext(ctx, "claude", "-p", "--model", model, "--dangerously-skip-permissions", "--output-format", "text")
cmd.Stdin = strings.NewReader(prompt)
}
cmd.Dir = dir
f, err := os.Create(logPath)
if err != nil {
return false, -1, err
}
defer f.Close()
cmd.Stdout = f
cmd.Stderr = f
if err := cmd.Run(); err != nil {
exitCode := -1
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
}
return false, exitCode, err
}
return true, 0, nil
}
// reportToForge posts body as a comment on the ticket's issue via the
// Forgejo/Gitea-compatible REST API. Failures are logged, never returned:
// reporting is best-effort and must not fail the dispatch run.
//
// NOTE(review): the success parameter is not read here — callers appear to
// embed the outcome in body; confirm before removing it.
func reportToForge(t dispatchTicket, success bool, body string) {
	token := t.ForgeToken
	if token == "" {
		token = os.Getenv("FORGE_TOKEN")
	}
	if token == "" {
		log.Warn("No forge token available, skipping report")
		return
	}
	url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d/comments",
		strings.TrimSuffix(t.ForgeURL, "/"), t.RepoOwner, t.RepoName, t.IssueNumber)
	payload := map[string]string{"body": body}
	jsonBody, err := json.Marshal(payload)
	if err != nil {
		// Practically unreachable for map[string]string, but don't POST garbage.
		log.Error("Failed to marshal comment payload", "err", err)
		return
	}
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
	if err != nil {
		log.Error("Failed to create request", "err", err)
		return
	}
	req.Header.Set("Authorization", "token "+token)
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		log.Error("Failed to report to Forge", "err", err)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 300 {
		log.Warn("Forge reported error", "status", resp.Status)
	}
}
// moveToDone relocates a finished ticket file from the active directory into
// the done directory; a failed rename is logged but not fatal.
func moveToDone(paths runnerPaths, activePath, fileName string) {
	target := filepath.Join(paths.done, fileName)
	if err := os.Rename(activePath, target); err != nil {
		log.Error("Failed to move ticket to done", "err", err)
	}
}
// ensureDispatchDirs creates every directory the dispatcher relies on
// (queue, active, done, logs, jobs), returning the first failure.
func ensureDispatchDirs(p runnerPaths) error {
	for _, dir := range []string{p.queue, p.active, p.done, p.logs, p.jobs} {
		if err := coreio.Local.EnsureDir(dir); err != nil {
			return log.E("ensureDispatchDirs", "mkdir "+dir+" failed", err)
		}
	}
	return nil
}
// acquireLock takes a PID-file lock at lockPath. If an existing lock names a
// live process the acquire fails; a stale lock (dead PID or unparsable
// contents) is removed and replaced with the current process's PID.
//
// NOTE(review): the check-then-write sequence is not atomic — two dispatchers
// starting simultaneously could both acquire the lock. Confirm whether coreio
// offers an exclusive-create primitive.
func acquireLock(lockPath string) error {
	if data, err := coreio.Local.Read(lockPath); err == nil {
		pidStr := strings.TrimSpace(data)
		// Parse failures yield pid 0, which isProcessAlive treats as dead,
		// so a corrupt lock file is reclaimed as stale.
		pid, _ := strconv.Atoi(pidStr)
		if isProcessAlive(pid) {
			return log.E("acquireLock", fmt.Sprintf("locked by PID %d", pid), nil)
		}
		log.Info("Removing stale lock", "pid", pid)
		_ = coreio.Local.Delete(lockPath)
	}
	return coreio.Local.Write(lockPath, fmt.Sprintf("%d", os.Getpid()))
}
// releaseLock removes the PID lock file. Errors are deliberately ignored:
// the lock is advisory and a leftover file is reclaimed as stale on the
// next acquireLock.
func releaseLock(lockPath string) {
	_ = coreio.Local.Delete(lockPath)
}
func isProcessAlive(pid int) bool {
if pid <= 0 {
return false
}
process, err := os.FindProcess(pid)
if err != nil {
return false
}
return process.Signal(syscall.Signal(0)) == nil
}
func pickOldestTicket(queueDir string) (string, error) {
entries, err := os.ReadDir(queueDir)
if err != nil {
return "", err
}
var tickets []string
for _, e := range entries {
if !e.IsDir() && strings.HasPrefix(e.Name(), "ticket-") && strings.HasSuffix(e.Name(), ".json") {
tickets = append(tickets, filepath.Join(queueDir, e.Name()))
}
}
if len(tickets) == 0 {
return "", nil
}
slices.Sort(tickets)
return tickets[0], nil
}

View file

@ -1,46 +0,0 @@
package dispatch
import (
"context"
"forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-ratelimit"
)
// executeWithRateLimit wraps an agent execution with rate limiting: it
// estimates the token cost of prompt, blocks until the limiter grants
// capacity for model, runs the runner, then records and persists usage.
// When the limiter cannot even be constructed, execution proceeds unlimited.
func executeWithRateLimit(ctx context.Context, model, prompt string, runner func() (bool, int, error)) (bool, int, error) {
	rl, err := ratelimit.New()
	if err != nil {
		log.Warn("Failed to initialize rate limiter, proceeding without limits", "error", err)
		return runner()
	}
	if err := rl.Load(); err != nil {
		log.Warn("Failed to load rate limit state", "error", err)
	}
	// Rough heuristic: one token per four characters of prompt, minimum one.
	estTokens := max(len(prompt)/4, 1)
	log.Info("Checking rate limits", "model", model, "est_tokens", estTokens)
	if err := rl.WaitForCapacity(ctx, model, estTokens); err != nil {
		return false, -1, err
	}
	success, exitCode, runErr := runner()
	// The shell runner can't report real token counts; estimate output
	// conservatively at a tenth of the input, floored at 50.
	rl.RecordUsage(model, estTokens, max(estTokens/10, 50))
	if err := rl.Persist(); err != nil {
		log.Warn("Failed to persist rate limit state", "error", err)
	}
	return success, exitCode, runErr
}

View file

@ -1,256 +0,0 @@
// Package taskgit implements git integration commands for task commits and PRs.
package taskgit
import (
"bytes"
"context"
"os"
"os/exec"
"strings"
"time"
agentic "forge.lthn.ai/core/agent/pkg/lifecycle"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/go-i18n"
)
// init registers this package's task:commit/task:pr commands with the CLI
// framework at import time.
func init() {
	cli.RegisterCommands(AddTaskGitCommands)
}
// Style aliases from shared package.
var (
	successStyle = cli.SuccessStyle // green marker for completed actions
	dimStyle     = cli.DimStyle     // muted marker for in-progress actions
)

// task:commit command flags
var (
	taskCommitMessage string // -m/--message: commit subject (required)
	taskCommitScope   string // --scope: optional conventional-commit scope
	taskCommitPush    bool   // --push: push after committing
)

// task:pr command flags
var (
	taskPRTitle  string // --title: PR title override
	taskPRDraft  bool   // --draft: open the PR as a draft
	taskPRLabels string // --labels: comma-separated labels to apply
	taskPRBase   string // --base: target base branch for the PR
)
// taskCommitCmd implements `task:commit <task-id>`: builds a conventional
// commit message (type inferred from the task's labels, optional --scope),
// commits the working tree, and optionally pushes with --push.
var taskCommitCmd = &cli.Command{
	Use:   "task:commit [task-id]",
	Short: i18n.T("cmd.ai.task_commit.short"),
	Long:  i18n.T("cmd.ai.task_commit.long"),
	Args:  cli.ExactArgs(1),
	RunE: func(cmd *cli.Command, args []string) error {
		taskID := args[0]
		if taskCommitMessage == "" {
			return cli.Err("commit message required")
		}
		cfg, err := agentic.LoadConfig("")
		if err != nil {
			return cli.WrapVerb(err, "load", "config")
		}
		client := agentic.NewClientFromConfig(cfg)
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		// Get task details (labels drive the commit-type prefix).
		task, err := client.GetTask(ctx, taskID)
		if err != nil {
			return cli.WrapVerb(err, "get", "task")
		}
		// Build commit message with optional scope, e.g. "fix(api): msg".
		commitType := inferCommitType(task.Labels)
		var fullMessage string
		if taskCommitScope != "" {
			fullMessage = cli.Sprintf("%s(%s): %s", commitType, taskCommitScope, taskCommitMessage)
		} else {
			fullMessage = cli.Sprintf("%s: %s", commitType, taskCommitMessage)
		}
		// Get current directory
		cwd, err := os.Getwd()
		if err != nil {
			return cli.WrapVerb(err, "get", "working directory")
		}
		// Bail out quietly when the working tree is clean.
		hasChanges, err := agentic.HasUncommittedChanges(ctx, cwd)
		if err != nil {
			return cli.WrapVerb(err, "check", "git status")
		}
		if !hasChanges {
			cli.Println("No changes to commit")
			return nil
		}
		// Create commit
		cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.ProgressSubject("create", "commit for "+taskID))
		if err := agentic.AutoCommit(ctx, task, cwd, fullMessage); err != nil {
			return cli.WrapAction(err, "commit")
		}
		cli.Print("%s %s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.commit")+":", fullMessage)
		// Push only when --push was given.
		if taskCommitPush {
			cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.Progress("push"))
			if err := agentic.PushChanges(ctx, cwd); err != nil {
				return cli.WrapAction(err, "push")
			}
			cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.push", "changes"))
		}
		return nil
	},
}
// taskPRCmd implements `task:pr <task-id>`: pushes the current feature
// branch (refusing to run from main/master) and opens a pull request for
// the task, honoring --title/--draft/--labels/--base.
var taskPRCmd = &cli.Command{
	Use:   "task:pr [task-id]",
	Short: i18n.T("cmd.ai.task_pr.short"),
	Long:  i18n.T("cmd.ai.task_pr.long"),
	Args:  cli.ExactArgs(1),
	RunE: func(cmd *cli.Command, args []string) error {
		taskID := args[0]
		cfg, err := agentic.LoadConfig("")
		if err != nil {
			return cli.WrapVerb(err, "load", "config")
		}
		client := agentic.NewClientFromConfig(cfg)
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
		defer cancel()
		// Get task details
		task, err := client.GetTask(ctx, taskID)
		if err != nil {
			return cli.WrapVerb(err, "get", "task")
		}
		// Get current directory
		cwd, err := os.Getwd()
		if err != nil {
			return cli.WrapVerb(err, "get", "working directory")
		}
		// Refuse to open a PR from the mainline branch.
		branch, err := agentic.GetCurrentBranch(ctx, cwd)
		if err != nil {
			return cli.WrapVerb(err, "get", "branch")
		}
		if branch == "main" || branch == "master" {
			return cli.Err("cannot create PR from %s branch", branch)
		}
		// Push current branch
		cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.ProgressSubject("push", branch))
		if err := agentic.PushChanges(ctx, cwd); err != nil {
			// First push of a new branch may need an upstream; retry with -u.
			if _, err := runGitCommand(cwd, "push", "-u", "origin", branch); err != nil {
				return cli.WrapVerb(err, "push", "branch")
			}
		}
		// Build PR options
		opts := agentic.PROptions{
			Title: taskPRTitle,
			Draft: taskPRDraft,
			Base:  taskPRBase,
		}
		if taskPRLabels != "" {
			opts.Labels = strings.Split(taskPRLabels, ",")
		}
		// Create PR
		cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.ProgressSubject("create", "PR"))
		prURL, err := agentic.CreatePR(ctx, task, cwd, opts)
		if err != nil {
			return cli.WrapVerb(err, "create", "PR")
		}
		cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.create", "PR"))
		cli.Print("   %s %s\n", i18n.Label("url"), prURL)
		return nil
	},
}
// initGitFlags binds the package-level flag variables to the task:commit and
// task:pr commands; called once from AddTaskGitCommands.
func initGitFlags() {
	// task:commit command flags
	taskCommitCmd.Flags().StringVarP(&taskCommitMessage, "message", "m", "", i18n.T("cmd.ai.task_commit.flag.message"))
	taskCommitCmd.Flags().StringVar(&taskCommitScope, "scope", "", i18n.T("cmd.ai.task_commit.flag.scope"))
	taskCommitCmd.Flags().BoolVar(&taskCommitPush, "push", false, i18n.T("cmd.ai.task_commit.flag.push"))
	// task:pr command flags
	taskPRCmd.Flags().StringVar(&taskPRTitle, "title", "", i18n.T("cmd.ai.task_pr.flag.title"))
	taskPRCmd.Flags().BoolVar(&taskPRDraft, "draft", false, i18n.T("cmd.ai.task_pr.flag.draft"))
	taskPRCmd.Flags().StringVar(&taskPRLabels, "labels", "", i18n.T("cmd.ai.task_pr.flag.labels"))
	taskPRCmd.Flags().StringVar(&taskPRBase, "base", "", i18n.T("cmd.ai.task_pr.flag.base"))
}
// AddTaskGitCommands registers the task:commit and task:pr commands under a parent.
func AddTaskGitCommands(parent *cli.Command) {
	initGitFlags()
	parent.AddCommand(taskCommitCmd, taskPRCmd)
}
// inferCommitType infers the commit type from task labels.
func inferCommitType(labels []string) string {
for _, label := range labels {
switch strings.ToLower(label) {
case "bug", "bugfix", "fix":
return "fix"
case "docs", "documentation":
return "docs"
case "refactor", "refactoring":
return "refactor"
case "test", "tests", "testing":
return "test"
case "chore":
return "chore"
case "style":
return "style"
case "perf", "performance":
return "perf"
case "ci":
return "ci"
case "build":
return "build"
}
}
return "feat"
}
// runGitCommand runs a git subcommand in dir and returns stdout on success.
// On failure, captured stderr (when non-empty) is wrapped into the error to
// give callers git's own diagnostic.
func runGitCommand(dir string, args ...string) (string, error) {
	cmd := exec.Command("git", args...)
	cmd.Dir = dir
	var outBuf, errBuf bytes.Buffer
	cmd.Stdout = &outBuf
	cmd.Stderr = &errBuf
	runErr := cmd.Run()
	if runErr == nil {
		return outBuf.String(), nil
	}
	if errBuf.Len() > 0 {
		return "", cli.Wrap(runErr, errBuf.String())
	}
	return "", runErr
}

View file

@ -1,328 +0,0 @@
// Package tasks implements task listing, viewing, and claiming commands.
package tasks
import (
"context"
"os"
"slices"
"strings"
"time"
"forge.lthn.ai/core/cli/pkg/cli"
agentic "forge.lthn.ai/core/agent/pkg/lifecycle"
"forge.lthn.ai/core/go-ai/ai"
"forge.lthn.ai/core/go-i18n"
)
// Style aliases from shared package
var (
	successStyle = cli.SuccessStyle // green completion marker
	errorStyle   = cli.ErrorStyle   // red failure marker
	dimStyle     = cli.DimStyle     // muted secondary text
	truncate     = cli.Truncate     // ellipsize long strings to a max width
	formatAge    = cli.FormatAge    // human-readable "age" of a timestamp
)

// Task priority/status styles from shared
var (
	taskPriorityHighStyle     = cli.NewStyle().Foreground(cli.ColourRed500)   // critical + high
	taskPriorityMediumStyle   = cli.NewStyle().Foreground(cli.ColourAmber500) // medium
	taskPriorityLowStyle      = cli.NewStyle().Foreground(cli.ColourBlue400)  // low
	taskStatusPendingStyle    = cli.DimStyle
	taskStatusInProgressStyle = cli.NewStyle().Foreground(cli.ColourBlue500)
	taskStatusCompletedStyle  = cli.SuccessStyle
	taskStatusBlockedStyle    = cli.ErrorStyle
)

// Task-specific styles (aliases to shared where possible)
var (
	taskIDStyle    = cli.TitleStyle                                  // Bold + blue
	taskTitleStyle = cli.ValueStyle                                  // Light gray
	taskLabelStyle = cli.NewStyle().Foreground(cli.ColourViolet500) // Violet for labels
)

// tasks command flags
var (
	tasksStatus   string // --status: filter by task status
	tasksPriority string // --priority: filter by task priority
	tasksLabels   string // --labels: comma-separated label filter
	tasksLimit    int    // --limit: max tasks to list (default 20)
	tasksProject  string // --project: filter by project
)

// task command flags
var (
	taskAutoSelect  bool // --auto: pick the highest-priority pending task
	taskClaim       bool // --claim: claim the task after showing it
	taskShowContext bool // --context: render the built task context instead of details
)
// tasksCmd implements `tasks`: lists tasks from the agentic backend, applying
// optional status/priority/label/project filters and a result limit, then
// prints a one-line summary per task.
var tasksCmd = &cli.Command{
	Use:   "tasks",
	Short: i18n.T("cmd.ai.tasks.short"),
	Long:  i18n.T("cmd.ai.tasks.long"),
	RunE: func(cmd *cli.Command, args []string) error {
		// Defensive default: the flag defaults to 20, but guard against 0.
		limit := tasksLimit
		if limit == 0 {
			limit = 20
		}
		cfg, err := agentic.LoadConfig("")
		if err != nil {
			return cli.WrapVerb(err, "load", "config")
		}
		client := agentic.NewClientFromConfig(cfg)
		opts := agentic.ListOptions{
			Limit:   limit,
			Project: tasksProject,
		}
		// Only set filters the user actually supplied.
		if tasksStatus != "" {
			opts.Status = agentic.TaskStatus(tasksStatus)
		}
		if tasksPriority != "" {
			opts.Priority = agentic.TaskPriority(tasksPriority)
		}
		if tasksLabels != "" {
			opts.Labels = strings.Split(tasksLabels, ",")
		}
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		tasks, err := client.ListTasks(ctx, opts)
		if err != nil {
			return cli.WrapVerb(err, "list", "tasks")
		}
		if len(tasks) == 0 {
			cli.Text(i18n.T("cmd.ai.tasks.none_found"))
			return nil
		}
		printTaskList(tasks)
		return nil
	},
}
// taskCmd implements `task [task-id]`: shows a single task's details (or its
// built context with --context). With --auto it instead selects the
// highest-priority pending task and implicitly claims it; with --claim it
// claims the shown task when still pending.
var taskCmd = &cli.Command{
	Use:   "task [task-id]",
	Short: i18n.T("cmd.ai.task.short"),
	Long:  i18n.T("cmd.ai.task.long"),
	RunE: func(cmd *cli.Command, args []string) error {
		cfg, err := agentic.LoadConfig("")
		if err != nil {
			return cli.WrapVerb(err, "load", "config")
		}
		client := agentic.NewClientFromConfig(cfg)
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		var task *agentic.Task
		// Get the task ID from args
		var taskID string
		if len(args) > 0 {
			taskID = args[0]
		}
		if taskAutoSelect {
			// Auto-select: find highest priority pending task
			tasks, err := client.ListTasks(ctx, agentic.ListOptions{
				Status: agentic.StatusPending,
				Limit:  50,
			})
			if err != nil {
				return cli.WrapVerb(err, "list", "tasks")
			}
			if len(tasks) == 0 {
				cli.Text(i18n.T("cmd.ai.task.no_pending"))
				return nil
			}
			// Sort by priority (critical > high > medium > low)
			priorityOrder := map[agentic.TaskPriority]int{
				agentic.PriorityCritical: 0,
				agentic.PriorityHigh:     1,
				agentic.PriorityMedium:   2,
				agentic.PriorityLow:      3,
			}
			slices.SortFunc(tasks, func(a, b agentic.Task) int {
				return priorityOrder[a.Priority] - priorityOrder[b.Priority]
			})
			task = &tasks[0]
			// NOTE(review): mutates the package-level flag variable so the
			// claim branch below runs; auto-select implies claiming.
			taskClaim = true
		} else {
			if taskID == "" {
				return cli.Err("%s", i18n.T("cmd.ai.task.id_required"))
			}
			task, err = client.GetTask(ctx, taskID)
			if err != nil {
				return cli.WrapVerb(err, "get", "task")
			}
		}
		// Show context if requested
		if taskShowContext {
			cwd, _ := os.Getwd()
			taskCtx, err := agentic.BuildTaskContext(task, cwd)
			if err != nil {
				// Context build failure is non-fatal: report and continue.
				cli.Print("%s %s: %s\n", errorStyle.Render(">>"), i18n.T("i18n.fail.build", "context"), err)
			} else {
				cli.Text(taskCtx.FormatContext())
			}
		} else {
			printTaskDetails(task)
		}
		// Claim only pending tasks; already-claimed tasks are left alone.
		if taskClaim && task.Status == agentic.StatusPending {
			cli.Blank()
			cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.T("cmd.ai.task.claiming"))
			claimedTask, err := client.ClaimTask(ctx, task.ID)
			if err != nil {
				return cli.WrapVerb(err, "claim", "task")
			}
			// Record task claim event
			_ = ai.Record(ai.Event{
				Type:    "task.claimed",
				AgentID: cfg.AgentID,
				Data:    map[string]any{"task_id": task.ID, "title": task.Title},
			})
			cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.claim", "task"))
			cli.Print("   %s %s\n", i18n.Label("status"), formatTaskStatus(claimedTask.Status))
		}
		return nil
	},
}
// initTasksFlags binds the package-level flag variables to the tasks and
// task commands; called once from AddTaskCommands.
func initTasksFlags() {
	// tasks command flags
	tasksCmd.Flags().StringVar(&tasksStatus, "status", "", i18n.T("cmd.ai.tasks.flag.status"))
	tasksCmd.Flags().StringVar(&tasksPriority, "priority", "", i18n.T("cmd.ai.tasks.flag.priority"))
	tasksCmd.Flags().StringVar(&tasksLabels, "labels", "", i18n.T("cmd.ai.tasks.flag.labels"))
	tasksCmd.Flags().IntVar(&tasksLimit, "limit", 20, i18n.T("cmd.ai.tasks.flag.limit"))
	tasksCmd.Flags().StringVar(&tasksProject, "project", "", i18n.T("cmd.ai.tasks.flag.project"))
	// task command flags
	taskCmd.Flags().BoolVar(&taskAutoSelect, "auto", false, i18n.T("cmd.ai.task.flag.auto"))
	taskCmd.Flags().BoolVar(&taskClaim, "claim", false, i18n.T("cmd.ai.task.flag.claim"))
	taskCmd.Flags().BoolVar(&taskShowContext, "context", false, i18n.T("cmd.ai.task.flag.context"))
}
// AddTaskCommands adds the task management commands to a parent command:
// listing/viewing (tasks, task) and update/complete (task:update,
// task:complete), wiring each group's flags first.
func AddTaskCommands(parent *cli.Command) {
	// Task listing and viewing
	initTasksFlags()
	parent.AddCommand(tasksCmd)
	parent.AddCommand(taskCmd)
	// Task updates
	initUpdatesFlags()
	parent.AddCommand(taskUpdateCmd)
	parent.AddCommand(taskCompleteCmd)
}
// printTaskList renders one summary line per task — ID, priority, status,
// truncated title, plus labels when present — followed by a usage hint.
func printTaskList(tasks []agentic.Task) {
	cli.Print("\n%s\n\n", i18n.T("cmd.ai.tasks.found", map[string]any{"Count": len(tasks)}))
	for _, tk := range tasks {
		row := cli.Sprintf("  %s %s %s %s",
			taskIDStyle.Render(tk.ID),
			formatTaskPriority(tk.Priority),
			formatTaskStatus(tk.Status),
			taskTitleStyle.Render(truncate(tk.Title, 50)))
		if len(tk.Labels) > 0 {
			row += " " + taskLabelStyle.Render("["+strings.Join(tk.Labels, ", ")+"]")
		}
		cli.Text(row)
	}
	cli.Blank()
	cli.Print("%s\n", dimStyle.Render(i18n.T("cmd.ai.tasks.hint")))
}
// printTaskDetails renders a full, labelled view of one task: identity,
// priority/status, optional project/labels/claimer, age, description, and
// any related files or blocking dependencies.
func printTaskDetails(task *agentic.Task) {
	cli.Blank()
	cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.id")), taskIDStyle.Render(task.ID))
	cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.title")), taskTitleStyle.Render(task.Title))
	cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.priority")), formatTaskPriority(task.Priority))
	cli.Print("%s %s\n", dimStyle.Render(i18n.Label("status")), formatTaskStatus(task.Status))
	// Optional fields are printed only when populated.
	if task.Project != "" {
		cli.Print("%s %s\n", dimStyle.Render(i18n.Label("project")), task.Project)
	}
	if len(task.Labels) > 0 {
		cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.labels")), taskLabelStyle.Render(strings.Join(task.Labels, ", ")))
	}
	if task.ClaimedBy != "" {
		cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.claimed_by")), task.ClaimedBy)
	}
	cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.created")), formatAge(task.CreatedAt))
	cli.Blank()
	cli.Print("%s\n", dimStyle.Render(i18n.T("cmd.ai.label.description")))
	cli.Text(task.Description)
	if len(task.Files) > 0 {
		cli.Blank()
		cli.Print("%s\n", dimStyle.Render(i18n.T("cmd.ai.label.related_files")))
		for _, f := range task.Files {
			cli.Print("  - %s\n", f)
		}
	}
	if len(task.Dependencies) > 0 {
		cli.Blank()
		cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.blocked_by")), strings.Join(task.Dependencies, ", "))
	}
}
// formatTaskPriority renders a bracketed, colour-coded, localized priority
// tag. Critical and high share the red style; unknown priorities fall back
// to the raw value in dim style.
func formatTaskPriority(p agentic.TaskPriority) string {
	switch p {
	case agentic.PriorityCritical:
		return taskPriorityHighStyle.Render("[" + i18n.T("cmd.ai.priority.critical") + "]")
	case agentic.PriorityHigh:
		return taskPriorityHighStyle.Render("[" + i18n.T("cmd.ai.priority.high") + "]")
	case agentic.PriorityMedium:
		return taskPriorityMediumStyle.Render("[" + i18n.T("cmd.ai.priority.medium") + "]")
	case agentic.PriorityLow:
		return taskPriorityLowStyle.Render("[" + i18n.T("cmd.ai.priority.low") + "]")
	default:
		return dimStyle.Render("[" + string(p) + "]")
	}
}
// formatTaskStatus renders a colour-coded, localized status label; unknown
// statuses are shown verbatim in dim style.
func formatTaskStatus(s agentic.TaskStatus) string {
	switch s {
	case agentic.StatusPending:
		return taskStatusPendingStyle.Render(i18n.T("cmd.ai.status.pending"))
	case agentic.StatusInProgress:
		return taskStatusInProgressStyle.Render(i18n.T("cmd.ai.status.in_progress"))
	case agentic.StatusCompleted:
		return taskStatusCompletedStyle.Render(i18n.T("cmd.ai.status.completed"))
	case agentic.StatusBlocked:
		return taskStatusBlockedStyle.Render(i18n.T("cmd.ai.status.blocked"))
	default:
		return dimStyle.Render(string(s))
	}
}

View file

@ -1,122 +0,0 @@
// updates.go implements task update and completion commands.
package tasks
import (
"context"
"time"
agentic "forge.lthn.ai/core/agent/pkg/lifecycle"
"forge.lthn.ai/core/go-ai/ai"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/go-i18n"
)
// task:update command flags
var (
	taskUpdateStatus   string // --status: new task status
	taskUpdateProgress int    // --progress: completion percentage
	taskUpdateNotes    string // --notes: free-form progress notes
)

// task:complete command flags
var (
	taskCompleteOutput   string // --output: result/output summary
	taskCompleteFailed   bool   // --failed: mark the task as failed
	taskCompleteErrorMsg string // --error: failure detail when --failed
)
// taskUpdateCmd implements `task:update <task-id>`: pushes a status,
// progress, and/or notes update for the task. At least one of the three
// flags must be supplied.
//
// NOTE(review): the guard treats progress 0 as "not provided", so an
// explicit --progress=0 alone is rejected — confirm this is intended.
var taskUpdateCmd = &cli.Command{
	Use:   "task:update [task-id]",
	Short: i18n.T("cmd.ai.task_update.short"),
	Long:  i18n.T("cmd.ai.task_update.long"),
	Args:  cli.ExactArgs(1),
	RunE: func(cmd *cli.Command, args []string) error {
		taskID := args[0]
		if taskUpdateStatus == "" && taskUpdateProgress == 0 && taskUpdateNotes == "" {
			return cli.Err("%s", i18n.T("cmd.ai.task_update.flag_required"))
		}
		cfg, err := agentic.LoadConfig("")
		if err != nil {
			return cli.WrapVerb(err, "load", "config")
		}
		client := agentic.NewClientFromConfig(cfg)
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		update := agentic.TaskUpdate{
			Progress: taskUpdateProgress,
			Notes:    taskUpdateNotes,
		}
		if taskUpdateStatus != "" {
			update.Status = agentic.TaskStatus(taskUpdateStatus)
		}
		if err := client.UpdateTask(ctx, taskID, update); err != nil {
			return cli.WrapVerb(err, "update", "task")
		}
		cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.update", "task"))
		return nil
	},
}
// taskCompleteCmd implements `task:complete <task-id>`: reports the task as
// finished (success by default, failure with --failed), records a
// "task.completed" telemetry event, and prints the outcome.
var taskCompleteCmd = &cli.Command{
	Use:   "task:complete [task-id]",
	Short: i18n.T("cmd.ai.task_complete.short"),
	Long:  i18n.T("cmd.ai.task_complete.long"),
	Args:  cli.ExactArgs(1),
	RunE: func(cmd *cli.Command, args []string) error {
		taskID := args[0]
		cfg, err := agentic.LoadConfig("")
		if err != nil {
			return cli.WrapVerb(err, "load", "config")
		}
		client := agentic.NewClientFromConfig(cfg)
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		result := agentic.TaskResult{
			Success:      !taskCompleteFailed,
			Output:       taskCompleteOutput,
			ErrorMessage: taskCompleteErrorMsg,
		}
		if err := client.CompleteTask(ctx, taskID, result); err != nil {
			return cli.WrapVerb(err, "complete", "task")
		}
		// Record task completion event (best-effort; errors ignored).
		_ = ai.Record(ai.Event{
			Type:    "task.completed",
			AgentID: cfg.AgentID,
			Data:    map[string]any{"task_id": taskID, "success": !taskCompleteFailed},
		})
		if taskCompleteFailed {
			cli.Print("%s %s\n", errorStyle.Render(">>"), i18n.T("cmd.ai.task_complete.failed", map[string]any{"ID": taskID}))
		} else {
			cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.complete", "task"))
		}
		return nil
	},
}
// initUpdatesFlags binds the package-level flag variables to the task:update
// and task:complete commands; called once from AddTaskCommands.
func initUpdatesFlags() {
	// task:update command flags
	taskUpdateCmd.Flags().StringVar(&taskUpdateStatus, "status", "", i18n.T("cmd.ai.task_update.flag.status"))
	taskUpdateCmd.Flags().IntVar(&taskUpdateProgress, "progress", 0, i18n.T("cmd.ai.task_update.flag.progress"))
	taskUpdateCmd.Flags().StringVar(&taskUpdateNotes, "notes", "", i18n.T("cmd.ai.task_update.flag.notes"))
	// task:complete command flags
	taskCompleteCmd.Flags().StringVar(&taskCompleteOutput, "output", "", i18n.T("cmd.ai.task_complete.flag.output"))
	taskCompleteCmd.Flags().BoolVar(&taskCompleteFailed, "failed", false, i18n.T("cmd.ai.task_complete.flag.failed"))
	taskCompleteCmd.Flags().StringVar(&taskCompleteErrorMsg, "error", "", i18n.T("cmd.ai.task_complete.flag.error"))
}

View file

@ -1 +0,0 @@
package workspace

View file

@ -1,289 +0,0 @@
// cmd_agent.go manages persistent agent context within task workspaces.
//
// Each agent gets a directory at:
//
// .core/workspace/p{epic}/i{issue}/agents/{provider}/{agent-name}/
//
// This directory persists across invocations, allowing agents to build
// understanding over time — QA agents accumulate findings, reviewers
// track patterns, implementors record decisions.
//
// Layout:
//
// agents/
// ├── claude-opus/implementor/
// │ ├── memory.md # Persistent notes, decisions, context
// │ └── artifacts/ # Generated artifacts (reports, diffs, etc.)
// ├── claude-opus/qa/
// │ ├── memory.md
// │ └── artifacts/
// └── gemini/reviewer/
// └── memory.md
package workspace
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
"time"
"forge.lthn.ai/core/cli/pkg/cli"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
)
// Agent identity fields split from a "provider/agent-name" ID.
// NOTE(review): neither variable is referenced in this file's visible code —
// likely vestigial flag wiring; confirm before removing.
var (
	agentProvider string
	agentName     string
)
// addAgentCommands registers the `agent` subcommand group (init, list, path)
// under parent. Each subcommand requires --epic and --issue to locate the
// task workspace; the flag values land in the package-level taskEpic and
// taskIssue variables (declared elsewhere in this package).
func addAgentCommands(parent *cli.Command) {
	agentCmd := &cli.Command{
		Use:   "agent",
		Short: "Manage persistent agent context within task workspaces",
	}
	initCmd := &cli.Command{
		Use:   "init <provider/agent-name>",
		Short: "Initialize an agent's context directory in the task workspace",
		Long: `Creates agents/{provider}/{agent-name}/ with memory.md and artifacts/
directory. The agent can read/write memory.md across invocations to
build understanding over time.`,
		Args: cli.ExactArgs(1),
		RunE: runAgentInit,
	}
	initCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number")
	initCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number")
	_ = initCmd.MarkFlagRequired("epic")
	_ = initCmd.MarkFlagRequired("issue")
	agentListCmd := &cli.Command{
		Use:   "list",
		Short: "List agents in a task workspace",
		RunE:  runAgentList,
	}
	agentListCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number")
	agentListCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number")
	_ = agentListCmd.MarkFlagRequired("epic")
	_ = agentListCmd.MarkFlagRequired("issue")
	pathCmd := &cli.Command{
		Use:   "path <provider/agent-name>",
		Short: "Print the agent's context directory path",
		Args:  cli.ExactArgs(1),
		RunE:  runAgentPath,
	}
	pathCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number")
	pathCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number")
	_ = pathCmd.MarkFlagRequired("epic")
	_ = pathCmd.MarkFlagRequired("issue")
	agentCmd.AddCommand(initCmd, agentListCmd, pathCmd)
	parent.AddCommand(agentCmd)
}
// agentContextPath returns the per-agent context directory inside a task
// workspace: {wsPath}/agents/{provider}/{name}.
func agentContextPath(wsPath, provider, name string) string {
	segments := []string{wsPath, "agents", provider, name}
	return filepath.Join(segments...)
}
// parseAgentID splits a "provider/agent-name" identifier into its two halves,
// rejecting values missing the slash or either part. Anything after the first
// slash (including further slashes) belongs to the name.
func parseAgentID(id string) (provider, name string, err error) {
	before, after, found := strings.Cut(id, "/")
	if !found || before == "" || after == "" {
		return "", "", coreerr.E("parseAgentID", "agent ID must be provider/agent-name (e.g. claude-opus/qa)", nil)
	}
	return before, after, nil
}
// AgentManifest tracks agent metadata for a task workspace, persisted as
// manifest.json inside the agent's context directory.
type AgentManifest struct {
	Provider  string    `json:"provider"`   // runner provider, e.g. "claude-opus"
	Name      string    `json:"name"`       // agent role name, e.g. "qa"
	CreatedAt time.Time `json:"created_at"` // first initialization time; preserved across rewrites
	LastSeen  time.Time `json:"last_seen"`  // refreshed on every manifest write
}
// runAgentInit handles `agent init <provider/agent-name>`: creates the
// agent's context directory (memory.md + artifacts/) inside the task
// workspace identified by --epic/--issue, or just refreshes last_seen when
// the directory already exists. The workspace itself must already exist.
func runAgentInit(cmd *cli.Command, args []string) error {
	provider, name, err := parseAgentID(args[0])
	if err != nil {
		return err
	}
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	wsPath := taskWorkspacePath(root, taskEpic, taskIssue)
	if !coreio.Local.IsDir(wsPath) {
		return cli.Err("task workspace does not exist: p%d/i%d — create it first with `core workspace task create`", taskEpic, taskIssue)
	}
	agentDir := agentContextPath(wsPath, provider, name)
	if coreio.Local.IsDir(agentDir) {
		// Already initialized: bump last_seen and report the existing path.
		updateAgentManifest(agentDir, provider, name)
		cli.Print("Agent %s/%s already initialized at p%d/i%d\n",
			cli.ValueStyle.Render(provider), cli.ValueStyle.Render(name), taskEpic, taskIssue)
		cli.Print("Path: %s\n", cli.DimStyle.Render(agentDir))
		return nil
	}
	// Create directory structure
	if err := coreio.Local.EnsureDir(agentDir); err != nil {
		return coreerr.E("agentInit", "failed to create agent directory", err)
	}
	if err := coreio.Local.EnsureDir(filepath.Join(agentDir, "artifacts")); err != nil {
		return coreerr.E("agentInit", "failed to create artifacts directory", err)
	}
	// Seed memory.md with a header the agent appends notes to over time.
	memoryContent := fmt.Sprintf(`# %s/%s Issue #%d (EPIC #%d)
## Context
- **Task workspace:** p%d/i%d
- **Initialized:** %s
## Notes
<!-- Add observations, decisions, and findings below -->
`, provider, name, taskIssue, taskEpic, taskEpic, taskIssue, time.Now().Format(time.RFC3339))
	if err := coreio.Local.Write(filepath.Join(agentDir, "memory.md"), memoryContent); err != nil {
		return coreerr.E("agentInit", "failed to create memory.md", err)
	}
	// Write manifest
	updateAgentManifest(agentDir, provider, name)
	cli.Print("%s Agent %s/%s initialized at p%d/i%d\n",
		cli.SuccessStyle.Render("Done:"),
		cli.ValueStyle.Render(provider), cli.ValueStyle.Render(name),
		taskEpic, taskIssue)
	cli.Print("Memory: %s\n", cli.DimStyle.Render(filepath.Join(agentDir, "memory.md")))
	return nil
}
// runAgentList handles `agent list`: walks agents/{provider}/{name} inside
// the --epic/--issue task workspace and prints one line per agent with its
// memory.md line count and, when a manifest exists, its last-seen time.
// Unreadable provider directories are skipped silently.
func runAgentList(cmd *cli.Command, args []string) error {
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	wsPath := taskWorkspacePath(root, taskEpic, taskIssue)
	agentsDir := filepath.Join(wsPath, "agents")
	if !coreio.Local.IsDir(agentsDir) {
		cli.Println("No agents in this workspace.")
		return nil
	}
	providers, err := coreio.Local.List(agentsDir)
	if err != nil {
		return coreerr.E("agentList", "failed to list agents", err)
	}
	found := false
	for _, providerEntry := range providers {
		if !providerEntry.IsDir() {
			continue
		}
		providerDir := filepath.Join(agentsDir, providerEntry.Name())
		agents, err := coreio.Local.List(providerDir)
		if err != nil {
			continue
		}
		for _, agentEntry := range agents {
			if !agentEntry.IsDir() {
				continue
			}
			found = true
			agentDir := filepath.Join(providerDir, agentEntry.Name())
			// Read manifest for last_seen
			lastSeen := ""
			manifestPath := filepath.Join(agentDir, "manifest.json")
			if data, err := coreio.Local.Read(manifestPath); err == nil {
				var m AgentManifest
				if json.Unmarshal([]byte(data), &m) == nil {
					lastSeen = m.LastSeen.Format("2006-01-02 15:04")
				}
			}
			// Check if memory has content beyond the template
			memorySize := ""
			if content, err := coreio.Local.Read(filepath.Join(agentDir, "memory.md")); err == nil {
				lines := len(strings.Split(content, "\n"))
				memorySize = fmt.Sprintf("%d lines", lines)
			}
			cli.Print("  %s/%s  %s",
				cli.ValueStyle.Render(providerEntry.Name()),
				cli.ValueStyle.Render(agentEntry.Name()),
				cli.DimStyle.Render(memorySize))
			if lastSeen != "" {
				cli.Print("  last: %s", cli.DimStyle.Render(lastSeen))
			}
			cli.Print("\n")
		}
	}
	if !found {
		cli.Println("No agents in this workspace.")
	}
	return nil
}
// runAgentPath handles `agent path <provider/agent-name>`: prints the bare
// path of an already-initialized agent context directory, failing when the
// agent has not been initialized in the --epic/--issue workspace.
func runAgentPath(cmd *cli.Command, args []string) error {
	provider, name, err := parseAgentID(args[0])
	if err != nil {
		return err
	}
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	wsPath := taskWorkspacePath(root, taskEpic, taskIssue)
	agentDir := agentContextPath(wsPath, provider, name)
	if !coreio.Local.IsDir(agentDir) {
		return cli.Err("agent %s/%s not initialized — run `core workspace agent init %s/%s`", provider, name, provider, name)
	}
	// Print just the path (useful for scripting: cd $(core workspace agent path ...))
	cli.Text(agentDir)
	return nil
}
// updateAgentManifest writes manifest.json in agentDir with LastSeen set to
// now, preserving CreatedAt from any existing, parseable manifest. It is
// best-effort: marshal and write failures are silently ignored.
func updateAgentManifest(agentDir, provider, name string) {
	now := time.Now()
	m := AgentManifest{
		Provider:  provider,
		Name:      name,
		CreatedAt: now,
		LastSeen:  now,
	}
	manifestPath := filepath.Join(agentDir, "manifest.json")
	// Preserve the original creation time across rewrites.
	if raw, err := coreio.Local.Read(manifestPath); err == nil {
		var prev AgentManifest
		if json.Unmarshal([]byte(raw), &prev) == nil {
			m.CreatedAt = prev.CreatedAt
		}
	}
	encoded, err := json.MarshalIndent(m, "", " ")
	if err != nil {
		return
	}
	_ = coreio.Local.Write(manifestPath, string(encoded))
}

View file

@ -1,79 +0,0 @@
package workspace
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestParseAgentID_Good verifies that a well-formed "provider/name" ID is
// split into its two components.
func TestParseAgentID_Good(t *testing.T) {
	gotProvider, gotName, err := parseAgentID("claude-opus/qa")
	require.NoError(t, err)
	assert.Equal(t, "claude-opus", gotProvider)
	assert.Equal(t, "qa", gotName)
}
// TestParseAgentID_Bad verifies that malformed agent IDs (missing slash,
// missing provider, missing name, empty) are all rejected.
func TestParseAgentID_Bad(t *testing.T) {
	for _, id := range []string{"noslash", "/missing-provider", "missing-name/", ""} {
		_, _, err := parseAgentID(id)
		assert.Error(t, err, "expected error for: %q", id)
	}
}
// TestAgentContextPath verifies the layout of an agent's context directory
// inside a task workspace.
func TestAgentContextPath(t *testing.T) {
	got := agentContextPath("/ws/p101/i343", "claude-opus", "qa")
	assert.Equal(t, "/ws/p101/i343/agents/claude-opus/qa", got)
}
// TestUpdateAgentManifest_Good verifies that a fresh manifest.json is
// written with the provider, the name, and non-zero timestamps.
func TestUpdateAgentManifest_Good(t *testing.T) {
	agentDir := filepath.Join(t.TempDir(), "agents", "test-provider", "test-agent")
	require.NoError(t, os.MkdirAll(agentDir, 0755))

	updateAgentManifest(agentDir, "test-provider", "test-agent")

	raw, err := os.ReadFile(filepath.Join(agentDir, "manifest.json"))
	require.NoError(t, err)
	var got AgentManifest
	require.NoError(t, json.Unmarshal(raw, &got))
	assert.Equal(t, "test-provider", got.Provider)
	assert.Equal(t, "test-agent", got.Name)
	assert.False(t, got.CreatedAt.IsZero())
	assert.False(t, got.LastSeen.IsZero())
}
// TestUpdateAgentManifest_PreservesCreatedAt verifies that rewriting an
// existing manifest keeps the original CreatedAt while LastSeen never moves
// backwards.
func TestUpdateAgentManifest_PreservesCreatedAt(t *testing.T) {
	agentDir := filepath.Join(t.TempDir(), "agents", "p", "a")
	require.NoError(t, os.MkdirAll(agentDir, 0755))

	readManifest := func() AgentManifest {
		raw, err := os.ReadFile(filepath.Join(agentDir, "manifest.json"))
		require.NoError(t, err)
		var m AgentManifest
		require.NoError(t, json.Unmarshal(raw, &m))
		return m
	}

	// First call sets created_at.
	updateAgentManifest(agentDir, "p", "a")
	first := readManifest()

	// Second call should preserve created_at.
	updateAgentManifest(agentDir, "p", "a")
	second := readManifest()

	assert.Equal(t, first.CreatedAt, second.CreatedAt)
	assert.True(t, second.LastSeen.After(first.CreatedAt) || second.LastSeen.Equal(first.CreatedAt))
}

View file

@ -1,543 +0,0 @@
// cmd_prep.go implements the `workspace prep` command.
//
// Prepares an agent workspace with wiki KB, protocol specs, a TODO from a
// Forge issue, and vector-recalled context from OpenBrain. All output goes
// to .core/ in the current directory, matching the convention used by
// KBConfig (go-scm) and build/release config.
package workspace
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"forge.lthn.ai/core/agent/pkg/lifecycle"
"forge.lthn.ai/core/cli/pkg/cli"
coreio "forge.lthn.ai/core/go-io"
"forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-scm/forge"
)
// Flags for `workspace prep`, bound in addPrepCommands.
var (
	prepRepo      string // Forge repo name (required)
	prepIssue     int    // issue number to build the TODO from; 0 = skeleton
	prepOrg       string // Forge organisation (defaults to "core")
	prepOutput    string // output directory; empty = ./.core
	prepSpecsPath string // directory holding protocol spec files
	prepDryRun    bool   // preview mode: print actions, write nothing
)
// addPrepCommands registers the `prep` subcommand on parent and binds its
// flags to the package-level prep* variables. --repo is mandatory.
func addPrepCommands(parent *cli.Command) {
	prepCmd := &cli.Command{
		Use:   "prep",
		Short: "Prepare agent workspace with wiki KB, specs, TODO, and vector context",
		Long: `Fetches wiki pages from Forge, copies protocol specs, generates a task
file from a Forge issue, and queries OpenBrain for relevant context.
All output is written to .core/ in the current directory.`,
		RunE: runPrep,
	}
	prepCmd.Flags().StringVar(&prepRepo, "repo", "", "Forge repo name (e.g. go-ai)")
	prepCmd.Flags().IntVar(&prepIssue, "issue", 0, "Issue number to build TODO from")
	prepCmd.Flags().StringVar(&prepOrg, "org", "core", "Forge organisation")
	prepCmd.Flags().StringVar(&prepOutput, "output", "", "Output directory (default: ./.core)")
	prepCmd.Flags().StringVar(&prepSpecsPath, "specs-path", "", "Path to specs dir")
	prepCmd.Flags().BoolVar(&prepDryRun, "dry-run", false, "Preview without writing files")
	// MarkFlagRequired can only fail for an unknown flag name, which would
	// be a programmer error here, so its result is deliberately discarded.
	_ = prepCmd.MarkFlagRequired("repo")
	parent.AddCommand(prepCmd)
}
// runPrep executes the prep pipeline: (1) pull wiki pages into kb/, (2) copy
// protocol specs into specs/, (3) generate todo.md from a Forge issue (or a
// skeleton when none is given or the fetch fails), and (4) write context.md
// from OpenBrain vector recall. Individual step failures are downgraded to
// warnings; only configuration/setup errors abort the run.
func runPrep(cmd *cli.Command, args []string) error {
	ctx := context.Background()
	// Resolve output directory (defaults to ./.core)
	outputDir := prepOutput
	if outputDir == "" {
		cwd, err := os.Getwd()
		if err != nil {
			return cli.Err("failed to get working directory")
		}
		outputDir = filepath.Join(cwd, ".core")
	}
	// Resolve specs path (defaults to ~/Code/specs when home is known;
	// left empty otherwise so prepCopySpecs reports the files as missing)
	specsPath := prepSpecsPath
	if specsPath == "" {
		home, err := os.UserHomeDir()
		if err == nil {
			specsPath = filepath.Join(home, "Code", "specs")
		}
	}
	// Resolve Forge connection
	forgeURL, forgeToken, err := forge.ResolveConfig("", "")
	if err != nil {
		return log.E("workspace.prep", "failed to resolve Forge config", err)
	}
	if forgeToken == "" {
		return log.E("workspace.prep", "no Forge token configured — set FORGE_TOKEN or run: core forge login", nil)
	}
	cli.Print("Preparing workspace for %s/%s\n", cli.ValueStyle.Render(prepOrg), cli.ValueStyle.Render(prepRepo))
	cli.Print("Output: %s\n", cli.DimStyle.Render(outputDir))
	if prepDryRun {
		cli.Print("%s No files will be written.\n", cli.WarningStyle.Render("[DRY RUN]"))
	}
	fmt.Println()
	// Create output directory structure
	if !prepDryRun {
		if err := coreio.Local.EnsureDir(filepath.Join(outputDir, "kb")); err != nil {
			return log.E("workspace.prep", "failed to create kb directory", err)
		}
		if err := coreio.Local.EnsureDir(filepath.Join(outputDir, "specs")); err != nil {
			return log.E("workspace.prep", "failed to create specs directory", err)
		}
	}
	// Step 1: Pull wiki pages (non-fatal: a repo may simply have no wiki)
	wikiCount, err := prepPullWiki(ctx, forgeURL, forgeToken, prepOrg, prepRepo, outputDir, prepDryRun)
	if err != nil {
		cli.Print("%s wiki: %v\n", cli.WarningStyle.Render("warn"), err)
	}
	// Step 2: Copy spec files
	specsCount := prepCopySpecs(specsPath, outputDir, prepDryRun)
	// Step 3: Generate TODO from issue; fall back to a skeleton when no
	// issue was given or fetching it failed
	var issueTitle, issueBody string
	if prepIssue > 0 {
		issueTitle, issueBody, err = prepGenerateTodo(ctx, forgeURL, forgeToken, prepOrg, prepRepo, prepIssue, outputDir, prepDryRun)
		if err != nil {
			cli.Print("%s todo: %v\n", cli.WarningStyle.Render("warn"), err)
			prepGenerateTodoSkeleton(prepOrg, prepRepo, outputDir, prepDryRun)
		}
	} else {
		prepGenerateTodoSkeleton(prepOrg, prepRepo, outputDir, prepDryRun)
	}
	// Step 4: Generate context from OpenBrain
	contextCount := prepGenerateContext(ctx, prepRepo, issueTitle, issueBody, outputDir, prepDryRun)
	// Summary
	fmt.Println()
	prefix := ""
	if prepDryRun {
		prefix = "[DRY RUN] "
	}
	cli.Print("%s%s\n", prefix, cli.SuccessStyle.Render("Workspace prep complete:"))
	cli.Print(" Wiki pages: %s\n", cli.ValueStyle.Render(fmt.Sprintf("%d", wikiCount)))
	cli.Print(" Spec files: %s\n", cli.ValueStyle.Render(fmt.Sprintf("%d", specsCount)))
	if issueTitle != "" {
		cli.Print(" TODO: %s\n", cli.ValueStyle.Render(fmt.Sprintf("from issue #%d", prepIssue)))
	} else {
		cli.Print(" TODO: %s\n", cli.DimStyle.Render("skeleton"))
	}
	cli.Print(" Context: %s\n", cli.ValueStyle.Render(fmt.Sprintf("%d memories", contextCount)))
	return nil
}
// --- Step 1: Pull wiki pages from Forge API ---
// wikiPageRef is the subset of a Forge wiki page-list entry we consume.
type wikiPageRef struct {
	Title  string `json:"title"`
	SubURL string `json:"sub_url"` // URL key used to fetch the page body
}

// wikiPageContent is the subset of a Forge wiki page response we consume;
// page bodies arrive base64-encoded.
type wikiPageContent struct {
	ContentBase64 string `json:"content_base64"`
}
// prepPullWiki downloads every wiki page for org/repo via the Forge API and
// writes each as a markdown file under {outputDir}/kb, returning the number
// of pages saved. A missing wiki (404) is not an error: a placeholder
// README is written and (0, nil) returned. Individual page failures are
// warned about and skipped.
func prepPullWiki(ctx context.Context, forgeURL, token, org, repo, outputDir string, dryRun bool) (int, error) {
	cli.Print("Fetching wiki pages for %s/%s...\n", org, repo)
	endpoint := fmt.Sprintf("%s/api/v1/repos/%s/%s/wiki/pages", forgeURL, org, repo)
	resp, err := forgeGet(ctx, endpoint, token)
	if err != nil {
		return 0, log.E("workspace.prep.wiki", "API request failed", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusNotFound {
		cli.Print(" %s No wiki found for %s\n", cli.WarningStyle.Render("warn"), repo)
		if !dryRun {
			content := fmt.Sprintf("# No wiki found for %s\n\nThis repo has no wiki pages on Forge.\n", repo)
			_ = coreio.Local.Write(filepath.Join(outputDir, "kb", "README.md"), content)
		}
		return 0, nil
	}
	if resp.StatusCode != http.StatusOK {
		return 0, log.E("workspace.prep.wiki", fmt.Sprintf("API error: %d", resp.StatusCode), nil)
	}
	var pages []wikiPageRef
	if err := json.NewDecoder(resp.Body).Decode(&pages); err != nil {
		return 0, log.E("workspace.prep.wiki", "failed to decode pages", err)
	}
	if len(pages) == 0 {
		cli.Print(" %s Wiki exists but has no pages.\n", cli.WarningStyle.Render("warn"))
		return 0, nil
	}
	count := 0
	for _, page := range pages {
		title := page.Title
		if title == "" {
			title = "Untitled"
		}
		// The per-page endpoint is keyed by sub_url; fall back to the
		// title when the API omits it.
		subURL := page.SubURL
		if subURL == "" {
			subURL = title
		}
		if dryRun {
			cli.Print(" [would fetch] %s\n", title)
			count++
			continue
		}
		pageEndpoint := fmt.Sprintf("%s/api/v1/repos/%s/%s/wiki/page/%s",
			forgeURL, org, repo, url.PathEscape(subURL))
		pageResp, err := forgeGet(ctx, pageEndpoint, token)
		if err != nil || pageResp.StatusCode != http.StatusOK {
			cli.Print(" %s Failed to fetch: %s\n", cli.WarningStyle.Render("warn"), title)
			// pageResp is nil when the request itself errored.
			if pageResp != nil {
				pageResp.Body.Close()
			}
			continue
		}
		var pageData wikiPageContent
		if err := json.NewDecoder(pageResp.Body).Decode(&pageData); err != nil {
			pageResp.Body.Close()
			continue
		}
		pageResp.Body.Close()
		if pageData.ContentBase64 == "" {
			continue
		}
		// Page bodies are base64-encoded; undecodable pages are skipped.
		decoded, err := base64.StdEncoding.DecodeString(pageData.ContentBase64)
		if err != nil {
			continue
		}
		filename := sanitiseFilename(title) + ".md"
		_ = coreio.Local.Write(filepath.Join(outputDir, "kb", filename), string(decoded))
		cli.Print(" %s\n", title)
		count++
	}
	cli.Print(" %d wiki page(s) saved to kb/\n", count)
	return count, nil
}
// --- Step 2: Copy protocol spec files ---
// prepCopySpecs copies the protocol spec files (AGENT_CONTEXT.md and
// TASK_PROTOCOL.md) from specsPath into {outputDir}/specs and returns the
// number copied. Missing or unreadable files produce warnings, not errors.
func prepCopySpecs(specsPath, outputDir string, dryRun bool) int {
	cli.Print("Copying spec files...\n")
	copied := 0
	for _, name := range []string{"AGENT_CONTEXT.md", "TASK_PROTOCOL.md"} {
		src := filepath.Join(specsPath, name)
		switch {
		case !coreio.Local.IsFile(src):
			cli.Print(" %s Not found: %s\n", cli.WarningStyle.Render("warn"), src)
		case dryRun:
			cli.Print(" [would copy] %s\n", name)
			copied++
		default:
			content, err := coreio.Local.Read(src)
			if err != nil {
				cli.Print(" %s Failed to read: %s\n", cli.WarningStyle.Render("warn"), name)
				continue
			}
			if err := coreio.Local.Write(filepath.Join(outputDir, "specs", name), content); err != nil {
				cli.Print(" %s Failed to write: %s\n", cli.WarningStyle.Render("warn"), name)
				continue
			}
			cli.Print(" %s\n", name)
			copied++
		}
	}
	cli.Print(" %d spec file(s) copied.\n", copied)
	return copied
}
// --- Step 3: Generate TODO from Forge issue ---
// forgeIssue is the subset of a Forge issue API response we consume.
type forgeIssue struct {
	Title string `json:"title"`
	Body  string `json:"body"`
}
// prepGenerateTodo fetches issue #issueNum from Forge and renders it as
// todo.md in outputDir: the first paragraph becomes the objective, and any
// "- [ ]" checkbox lines become acceptance criteria. The raw issue body is
// preserved in a collapsed <details> section. It returns the issue title
// and body so the caller can feed them into vector recall.
func prepGenerateTodo(ctx context.Context, forgeURL, token, org, repo string, issueNum int, outputDir string, dryRun bool) (string, string, error) {
	cli.Print("Generating TODO from issue #%d...\n", issueNum)
	endpoint := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", forgeURL, org, repo, issueNum)
	resp, err := forgeGet(ctx, endpoint, token)
	if err != nil {
		return "", "", log.E("workspace.prep.todo", "issue API request failed", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", "", log.E("workspace.prep.todo", fmt.Sprintf("failed to fetch issue #%d: %d", issueNum, resp.StatusCode), nil)
	}
	var issue forgeIssue
	if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil {
		return "", "", log.E("workspace.prep.todo", "failed to decode issue", err)
	}
	title := issue.Title
	if title == "" {
		title = "Untitled"
	}
	objective := extractObjective(issue.Body)
	checklist := extractChecklist(issue.Body)
	var b strings.Builder
	fmt.Fprintf(&b, "# TASK: %s\n\n", title)
	fmt.Fprintf(&b, "**Status:** ready\n")
	fmt.Fprintf(&b, "**Source:** %s/%s/%s/issues/%d\n", forgeURL, org, repo, issueNum)
	fmt.Fprintf(&b, "**Created:** %s\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Fprintf(&b, "**Repo:** %s/%s\n", org, repo)
	b.WriteString("\n---\n\n")
	fmt.Fprintf(&b, "## Objective\n\n%s\n", objective)
	b.WriteString("\n---\n\n")
	b.WriteString("## Acceptance Criteria\n\n")
	if len(checklist) > 0 {
		for _, item := range checklist {
			fmt.Fprintf(&b, "- [ ] %s\n", item)
		}
	} else {
		b.WriteString("_No checklist items found in issue. Agent should define acceptance criteria._\n")
	}
	b.WriteString("\n---\n\n")
	b.WriteString("## Implementation Checklist\n\n")
	b.WriteString("_To be filled by the agent during planning._\n")
	b.WriteString("\n---\n\n")
	b.WriteString("## Notes\n\n")
	b.WriteString("Full issue body preserved below for reference.\n\n")
	// Keep the raw issue body so nothing is lost in extraction.
	b.WriteString("<details>\n<summary>Original Issue</summary>\n\n")
	b.WriteString(issue.Body)
	b.WriteString("\n\n</details>\n")
	if dryRun {
		cli.Print(" [would write] todo.md from: %s\n", title)
	} else {
		if err := coreio.Local.Write(filepath.Join(outputDir, "todo.md"), b.String()); err != nil {
			return title, issue.Body, log.E("workspace.prep.todo", "failed to write todo.md", err)
		}
		cli.Print(" todo.md generated from: %s\n", title)
	}
	return title, issue.Body, nil
}
// prepGenerateTodoSkeleton writes a placeholder todo.md when no Forge issue
// was supplied (or fetching it failed); the agent fills in the sections.
func prepGenerateTodoSkeleton(org, repo, outputDir string, dryRun bool) {
	sections := []string{
		"# TASK: [Define task]\n\n",
		"**Status:** ready\n",
		fmt.Sprintf("**Created:** %s\n", time.Now().Format("2006-01-02 15:04:05")),
		fmt.Sprintf("**Repo:** %s/%s\n", org, repo),
		"\n---\n\n",
		"## Objective\n\n_Define the objective._\n",
		"\n---\n\n",
		"## Acceptance Criteria\n\n- [ ] _Define criteria_\n",
		"\n---\n\n",
		"## Implementation Checklist\n\n_To be filled by the agent._\n",
	}
	body := strings.Join(sections, "")
	if dryRun {
		cli.Print(" [would write] todo.md skeleton\n")
		return
	}
	_ = coreio.Local.Write(filepath.Join(outputDir, "todo.md"), body)
	cli.Print(" todo.md skeleton generated (no --issue provided)\n")
}
// --- Step 4: Generate context from OpenBrain ---
// prepGenerateContext queries the OpenBrain vector DB — once for general
// repo knowledge and, when an issue is known, once for task-specific
// context — and writes the recalled memories to context.md. It returns the
// total number of memories written. If the brain service is unreachable, a
// stub context.md is written and 0 is returned. Connection details come
// from CORE_API_URL (default http://localhost:8000) and CORE_API_TOKEN.
func prepGenerateContext(ctx context.Context, repo, issueTitle, issueBody, outputDir string, dryRun bool) int {
	cli.Print("Querying vector DB for context...\n")
	apiURL := os.Getenv("CORE_API_URL")
	if apiURL == "" {
		apiURL = "http://localhost:8000"
	}
	apiToken := os.Getenv("CORE_API_TOKEN")
	client := lifecycle.NewClient(apiURL, apiToken)
	// Query 1: Repo-specific knowledge
	repoResult, err := client.Recall(ctx, lifecycle.RecallRequest{
		Query:   "How does " + repo + " work? Architecture and key interfaces.",
		TopK:    10,
		Project: repo,
	})
	if err != nil {
		cli.Print(" %s BrainService unavailable: %v\n", cli.WarningStyle.Render("warn"), err)
		writeBrainUnavailable(repo, outputDir, dryRun)
		return 0
	}
	repoMemories := repoResult.Memories
	repoScores := repoResult.Scores
	// Query 2: Issue-specific context (body truncated to 500 bytes to keep
	// the query small); recall failures here are silent — repo knowledge
	// alone is still useful.
	var issueMemories []lifecycle.Memory
	var issueScores map[string]float64
	if issueTitle != "" {
		query := issueTitle
		if len(issueBody) > 500 {
			query += " " + issueBody[:500]
		} else if issueBody != "" {
			query += " " + issueBody
		}
		issueResult, err := client.Recall(ctx, lifecycle.RecallRequest{
			Query: query,
			TopK:  5,
		})
		if err == nil {
			issueMemories = issueResult.Memories
			issueScores = issueResult.Scores
		}
	}
	totalMemories := len(repoMemories) + len(issueMemories)
	var b strings.Builder
	fmt.Fprintf(&b, "# Agent Context — %s\n\n", repo)
	b.WriteString("> Auto-generated by `core workspace prep`. Query the vector DB for more.\n\n")
	b.WriteString("## Repo Knowledge\n\n")
	if len(repoMemories) > 0 {
		for i, mem := range repoMemories {
			// Scores are keyed by memory ID; missing entries render as 0.
			score := repoScores[mem.ID]
			project := mem.Project
			if project == "" {
				project = "unknown"
			}
			memType := mem.Type
			if memType == "" {
				memType = "memory"
			}
			fmt.Fprintf(&b, "### %d. %s [%s] (score: %.3f)\n\n", i+1, project, memType, score)
			fmt.Fprintf(&b, "%s\n\n", mem.Content)
		}
	} else {
		b.WriteString("_No repo-specific memories found. The vector DB may not have been seeded for this repo._\n\n")
	}
	b.WriteString("## Task-Relevant Context\n\n")
	if len(issueMemories) > 0 {
		for i, mem := range issueMemories {
			score := issueScores[mem.ID]
			project := mem.Project
			if project == "" {
				project = "unknown"
			}
			memType := mem.Type
			if memType == "" {
				memType = "memory"
			}
			fmt.Fprintf(&b, "### %d. %s [%s] (score: %.3f)\n\n", i+1, project, memType, score)
			fmt.Fprintf(&b, "%s\n\n", mem.Content)
		}
	} else if issueTitle != "" {
		b.WriteString("_No task-relevant memories found._\n\n")
	} else {
		b.WriteString("_No issue provided — skipped task-specific recall._\n\n")
	}
	if dryRun {
		cli.Print(" [would write] context.md with %d memories\n", totalMemories)
	} else {
		_ = coreio.Local.Write(filepath.Join(outputDir, "context.md"), b.String())
		cli.Print(" context.md generated with %d memories\n", totalMemories)
	}
	return totalMemories
}
func writeBrainUnavailable(repo, outputDir string, dryRun bool) {
var b strings.Builder
fmt.Fprintf(&b, "# Agent Context — %s\n\n", repo)
b.WriteString("> Vector DB was unavailable when this workspace was prepared.\n")
b.WriteString("> Run `core workspace prep` again once Ollama/Qdrant are reachable.\n")
if !dryRun {
_ = coreio.Local.Write(filepath.Join(outputDir, "context.md"), b.String())
}
}
// --- Helpers ---
// prepHTTPClient is shared by all prep API requests so TCP/TLS connections
// are reused, instead of allocating a fresh client (and connection pool)
// per call as the previous implementation did.
var prepHTTPClient = &http.Client{Timeout: 30 * time.Second}

// forgeGet issues an authenticated GET against the Forge API and returns
// the raw response. The caller owns resp.Body and must close it.
func forgeGet(ctx context.Context, endpoint, token string) (*http.Response, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "token "+token)
	return prepHTTPClient.Do(req)
}
// nonAlphanumeric matches every character that is unsafe in a filename:
// anything outside [a-zA-Z0-9_], '-' and '.'.
var nonAlphanumeric = regexp.MustCompile(`[^a-zA-Z0-9_\-.]`)

// sanitiseFilename maps a wiki page title to a filesystem-safe name by
// replacing each unsafe character with '-'.
func sanitiseFilename(title string) string {
	safe := nonAlphanumeric.ReplaceAllString(title, "-")
	return safe
}
func extractObjective(body string) string {
if body == "" {
return "_No description provided._"
}
parts := strings.SplitN(body, "\n\n", 2)
first := strings.TrimSpace(parts[0])
if len(first) > 500 {
return first[:497] + "..."
}
return first
}
// checklistItemRe matches one markdown checkbox line ("- [ ] text",
// "- [x] text" or "- [X] text") and captures the text. Compiled once at
// package scope instead of on every call, per Go regexp best practice.
var checklistItemRe = regexp.MustCompile(`- \[[ xX]\] (.+)`)

// extractChecklist returns the trimmed text of every markdown checkbox item
// found in body, or nil when there are none.
func extractChecklist(body string) []string {
	var items []string
	for _, m := range checklistItemRe.FindAllStringSubmatch(body, -1) {
		items = append(items, strings.TrimSpace(m[1]))
	}
	return items
}

View file

@ -1,465 +0,0 @@
// cmd_task.go implements task workspace isolation using git worktrees.
//
// Each task gets an isolated workspace at .core/workspace/p{epic}/i{issue}/
// containing git worktrees of required repos. This prevents agents from
// writing to the implementor's working tree.
//
// Safety checks enforce that workspaces cannot be removed if they contain
// uncommitted changes or unpushed branches.
package workspace
import (
"context"
"fmt"
"os/exec"
"path/filepath"
"strconv"
"strings"
"forge.lthn.ai/core/cli/pkg/cli"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-scm/repos"
)
// Flags shared by the `workspace task` subcommands, bound in addTaskCommands.
var (
	taskEpic   int      // epic/project number
	taskIssue  int      // issue number
	taskRepos  []string // repos to include; empty = all registry repos
	taskForce  bool     // skip safety checks when removing
	taskBranch string   // branch override; empty = issue/{issue}
)
// addTaskCommands registers the `task` subcommand tree (create, remove,
// list, status, plus the agent subcommands) on parent and binds the
// package-level task* flag variables.
func addTaskCommands(parent *cli.Command) {
	taskCmd := &cli.Command{
		Use:   "task",
		Short: "Manage isolated task workspaces for agents",
	}
	createCmd := &cli.Command{
		Use:   "create",
		Short: "Create an isolated task workspace with git worktrees",
		Long: `Creates a workspace at .core/workspace/p{epic}/i{issue}/ with git
worktrees for each specified repo. Each worktree gets a fresh branch
(issue/{id} by default) so agents work in isolation.`,
		RunE: runTaskCreate,
	}
	createCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number")
	createCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number")
	createCmd.Flags().StringSliceVar(&taskRepos, "repo", nil, "Repos to include (default: all from registry)")
	createCmd.Flags().StringVar(&taskBranch, "branch", "", "Branch name (default: issue/{issue})")
	// MarkFlagRequired only fails for unknown flag names (programmer
	// error), so results are deliberately discarded here and below.
	_ = createCmd.MarkFlagRequired("epic")
	_ = createCmd.MarkFlagRequired("issue")
	removeCmd := &cli.Command{
		Use:   "remove",
		Short: "Remove a task workspace (with safety checks)",
		Long: `Removes a task workspace after checking for uncommitted changes and
unpushed branches. Use --force to skip safety checks.`,
		RunE: runTaskRemove,
	}
	removeCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number")
	removeCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number")
	removeCmd.Flags().BoolVar(&taskForce, "force", false, "Skip safety checks")
	_ = removeCmd.MarkFlagRequired("epic")
	_ = removeCmd.MarkFlagRequired("issue")
	listCmd := &cli.Command{
		Use:   "list",
		Short: "List all task workspaces",
		RunE:  runTaskList,
	}
	statusCmd := &cli.Command{
		Use:   "status",
		Short: "Show status of a task workspace",
		RunE:  runTaskStatus,
	}
	statusCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number")
	statusCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number")
	_ = statusCmd.MarkFlagRequired("epic")
	_ = statusCmd.MarkFlagRequired("issue")
	addAgentCommands(taskCmd)
	taskCmd.AddCommand(createCmd, removeCmd, listCmd, statusCmd)
	parent.AddCommand(taskCmd)
}
// taskWorkspacePath returns the path for a task workspace.
func taskWorkspacePath(root string, epic, issue int) string {
return filepath.Join(root, ".core", "workspace", fmt.Sprintf("p%d", epic), fmt.Sprintf("i%d", issue))
}
// runTaskCreate creates .core/workspace/p{epic}/i{issue}/ and adds a git
// worktree (on --branch, default issue/{issue}) for every requested repo.
// Repos that are not cloned locally, or whose worktree cannot be added, are
// counted as skipped and reported rather than aborting the run.
func runTaskCreate(cmd *cli.Command, args []string) error {
	ctx := context.Background()
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace — run from workspace root or a package directory")
	}
	wsPath := taskWorkspacePath(root, taskEpic, taskIssue)
	if coreio.Local.IsDir(wsPath) {
		return cli.Err("task workspace already exists: %s", wsPath)
	}
	branch := taskBranch
	if branch == "" {
		branch = fmt.Sprintf("issue/%d", taskIssue)
	}
	// Determine repos to include (--repo wins; otherwise the registry)
	repoNames := taskRepos
	if len(repoNames) == 0 {
		repoNames, err = registryRepoNames(root)
		if err != nil {
			return coreerr.E("taskCreate", "failed to load registry", err)
		}
	}
	if len(repoNames) == 0 {
		return cli.Err("no repos specified and no registry found")
	}
	// Resolve package paths (relative PackagesDir is anchored at root)
	config, _ := LoadConfig(root)
	pkgDir := "./packages"
	if config != nil && config.PackagesDir != "" {
		pkgDir = config.PackagesDir
	}
	if !filepath.IsAbs(pkgDir) {
		pkgDir = filepath.Join(root, pkgDir)
	}
	if err := coreio.Local.EnsureDir(wsPath); err != nil {
		return coreerr.E("taskCreate", "failed to create workspace directory", err)
	}
	cli.Print("Creating task workspace: %s\n", cli.ValueStyle.Render(fmt.Sprintf("p%d/i%d", taskEpic, taskIssue)))
	cli.Print("Branch: %s\n", cli.ValueStyle.Render(branch))
	cli.Print("Path: %s\n\n", cli.DimStyle.Render(wsPath))
	var created, skipped int
	for _, repoName := range repoNames {
		repoPath := filepath.Join(pkgDir, repoName)
		// Only repos with a local .git can host a worktree.
		if !coreio.Local.IsDir(filepath.Join(repoPath, ".git")) {
			cli.Print(" %s %s (not cloned, skipping)\n", cli.DimStyle.Render("·"), repoName)
			skipped++
			continue
		}
		worktreePath := filepath.Join(wsPath, repoName)
		cli.Print(" %s %s... ", cli.DimStyle.Render("·"), repoName)
		if err := createWorktree(ctx, repoPath, worktreePath, branch); err != nil {
			cli.Print("%s\n", cli.ErrorStyle.Render("x "+err.Error()))
			skipped++
			continue
		}
		cli.Print("%s\n", cli.SuccessStyle.Render("ok"))
		created++
	}
	cli.Print("\n%s %d worktrees created", cli.SuccessStyle.Render("Done:"), created)
	if skipped > 0 {
		cli.Print(", %d skipped", skipped)
	}
	cli.Print("\n")
	return nil
}
// runTaskRemove deletes the workspace for --epic/--issue. Unless --force is
// set, removal is refused while any worktree has uncommitted changes or
// unpushed commits. Worktrees are detached from their repos via git before
// the directory tree is deleted, and an emptied p{epic}/ parent is cleaned
// up afterwards.
func runTaskRemove(cmd *cli.Command, args []string) error {
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	wsPath := taskWorkspacePath(root, taskEpic, taskIssue)
	if !coreio.Local.IsDir(wsPath) {
		return cli.Err("task workspace does not exist: p%d/i%d", taskEpic, taskIssue)
	}
	if !taskForce {
		dirty, reasons := checkWorkspaceSafety(wsPath)
		if dirty {
			cli.Print("%s Cannot remove workspace p%d/i%d:\n", cli.ErrorStyle.Render("Blocked:"), taskEpic, taskIssue)
			for _, r := range reasons {
				cli.Print(" %s %s\n", cli.ErrorStyle.Render("·"), r)
			}
			cli.Print("\nUse --force to override or resolve the issues first.\n")
			return coreerr.E("taskRemove", "workspace has unresolved changes", nil)
		}
	}
	// Remove worktrees first (so git knows they're gone)
	entries, err := coreio.Local.List(wsPath)
	if err != nil {
		return coreerr.E("taskRemove", "failed to list workspace", err)
	}
	config, _ := LoadConfig(root)
	pkgDir := "./packages"
	if config != nil && config.PackagesDir != "" {
		pkgDir = config.PackagesDir
	}
	if !filepath.IsAbs(pkgDir) {
		pkgDir = filepath.Join(root, pkgDir)
	}
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		worktreePath := filepath.Join(wsPath, entry.Name())
		repoPath := filepath.Join(pkgDir, entry.Name())
		// Remove worktree from git (best-effort; errors ignored inside)
		if coreio.Local.IsDir(filepath.Join(repoPath, ".git")) {
			removeWorktree(repoPath, worktreePath)
		}
	}
	// Remove the workspace directory
	if err := coreio.Local.DeleteAll(wsPath); err != nil {
		return coreerr.E("taskRemove", "failed to remove workspace directory", err)
	}
	// Clean up empty parent (p{epic}/) if it's now empty; failure harmless
	epicDir := filepath.Dir(wsPath)
	if entries, err := coreio.Local.List(epicDir); err == nil && len(entries) == 0 {
		coreio.Local.DeleteAll(epicDir)
	}
	cli.Print("%s Removed workspace p%d/i%d\n", cli.SuccessStyle.Render("Done:"), taskEpic, taskIssue)
	return nil
}
// runTaskList prints every task workspace under .core/workspace, the number
// of repo worktrees each contains, and whether it is clean or dirty.
func runTaskList(cmd *cli.Command, args []string) error {
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	wsRoot := filepath.Join(root, ".core", "workspace")
	if !coreio.Local.IsDir(wsRoot) {
		cli.Println("No task workspaces found.")
		return nil
	}
	epics, err := coreio.Local.List(wsRoot)
	if err != nil {
		return coreerr.E("taskList", "failed to list workspaces", err)
	}
	found := false
	for _, epicEntry := range epics {
		// Epic directories are named p{n}; anything else is ignored.
		if !epicEntry.IsDir() || !strings.HasPrefix(epicEntry.Name(), "p") {
			continue
		}
		epicDir := filepath.Join(wsRoot, epicEntry.Name())
		issues, err := coreio.Local.List(epicDir)
		if err != nil {
			continue
		}
		for _, issueEntry := range issues {
			// Issue directories are named i{n}.
			if !issueEntry.IsDir() || !strings.HasPrefix(issueEntry.Name(), "i") {
				continue
			}
			found = true
			wsPath := filepath.Join(epicDir, issueEntry.Name())
			// Count worktrees (one subdirectory per repo)
			entries, _ := coreio.Local.List(wsPath)
			dirCount := 0
			for _, e := range entries {
				if e.IsDir() {
					dirCount++
				}
			}
			// Check safety (reasons are not needed for the listing)
			dirty, _ := checkWorkspaceSafety(wsPath)
			status := cli.SuccessStyle.Render("clean")
			if dirty {
				status = cli.ErrorStyle.Render("dirty")
			}
			cli.Print(" %s/%s %d repos %s\n",
				epicEntry.Name(), issueEntry.Name(),
				dirCount, status)
		}
	}
	if !found {
		cli.Println("No task workspaces found.")
	}
	return nil
}
// runTaskStatus prints, for each worktree in the --epic/--issue workspace,
// its current branch, whether it has uncommitted changes, and how many
// commits are unpushed relative to its upstream.
func runTaskStatus(cmd *cli.Command, args []string) error {
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	wsPath := taskWorkspacePath(root, taskEpic, taskIssue)
	if !coreio.Local.IsDir(wsPath) {
		return cli.Err("task workspace does not exist: p%d/i%d", taskEpic, taskIssue)
	}
	cli.Print("Workspace: %s\n", cli.ValueStyle.Render(fmt.Sprintf("p%d/i%d", taskEpic, taskIssue)))
	cli.Print("Path: %s\n\n", cli.DimStyle.Render(wsPath))
	entries, err := coreio.Local.List(wsPath)
	if err != nil {
		return coreerr.E("taskStatus", "failed to list workspace", err)
	}
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		worktreePath := filepath.Join(wsPath, entry.Name())
		// Get branch
		branch := gitOutput(worktreePath, "rev-parse", "--abbrev-ref", "HEAD")
		branch = strings.TrimSpace(branch)
		// Get status (porcelain: one line per changed file)
		status := gitOutput(worktreePath, "status", "--porcelain")
		statusLabel := cli.SuccessStyle.Render("clean")
		if strings.TrimSpace(status) != "" {
			lines := len(strings.Split(strings.TrimSpace(status), "\n"))
			statusLabel = cli.ErrorStyle.Render(fmt.Sprintf("%d changes", lines))
		}
		// Get unpushed commits vs upstream (empty if none, or no upstream)
		unpushed := gitOutput(worktreePath, "log", "--oneline", "@{u}..HEAD")
		unpushedLabel := ""
		if trimmed := strings.TrimSpace(unpushed); trimmed != "" {
			count := len(strings.Split(trimmed, "\n"))
			unpushedLabel = cli.WarningStyle.Render(fmt.Sprintf(" %d unpushed", count))
		}
		cli.Print(" %s %s %s%s\n",
			cli.RepoStyle.Render(entry.Name()),
			cli.DimStyle.Render(branch),
			statusLabel,
			unpushedLabel)
	}
	return nil
}
// createWorktree adds a git worktree at worktreePath checked out on branch.
// It first tries `git worktree add -b` (create a new branch); if git
// reports the branch already exists, it retries without -b to reuse the
// existing branch. Any other git failure is returned with git's own output
// as the message.
func createWorktree(ctx context.Context, repoPath, worktreePath, branch string) error {
	// Check if branch exists on remote first
	cmd := exec.CommandContext(ctx, "git", "worktree", "add", "-b", branch, worktreePath)
	cmd.Dir = repoPath
	output, err := cmd.CombinedOutput()
	if err != nil {
		errStr := strings.TrimSpace(string(output))
		// If branch already exists, try without -b
		if strings.Contains(errStr, "already exists") {
			cmd = exec.CommandContext(ctx, "git", "worktree", "add", worktreePath, branch)
			cmd.Dir = repoPath
			output, err = cmd.CombinedOutput()
			if err != nil {
				return coreerr.E("createWorktree", strings.TrimSpace(string(output)), nil)
			}
			return nil
		}
		return coreerr.E("createWorktree", errStr, nil)
	}
	return nil
}
// removeWorktree detaches a worktree from its repo and prunes stale
// worktree records. Errors are deliberately ignored: the caller deletes the
// directory afterwards regardless, and prune is pure housekeeping.
func removeWorktree(repoPath, worktreePath string) {
	for _, argv := range [][]string{
		{"worktree", "remove", worktreePath},
		{"worktree", "prune"},
	} {
		c := exec.Command("git", argv...)
		c.Dir = repoPath
		_ = c.Run()
	}
}
// checkWorkspaceSafety checks all worktrees in a workspace for uncommitted
// changes and unpushed commits, returning dirty=true plus one human-readable
// reason per finding. A workspace that cannot be listed is reported clean —
// the caller will hit the underlying filesystem error on its own.
func checkWorkspaceSafety(wsPath string) (dirty bool, reasons []string) {
	entries, err := coreio.Local.List(wsPath)
	if err != nil {
		return false, nil
	}
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		worktreePath := filepath.Join(wsPath, entry.Name())
		// Check for uncommitted changes (porcelain is empty when clean)
		status := gitOutput(worktreePath, "status", "--porcelain")
		if strings.TrimSpace(status) != "" {
			dirty = true
			reasons = append(reasons, fmt.Sprintf("%s: has uncommitted changes", entry.Name()))
		}
		// Check for unpushed commits (empty when none, or no upstream set)
		unpushed := gitOutput(worktreePath, "log", "--oneline", "@{u}..HEAD")
		if strings.TrimSpace(unpushed) != "" {
			dirty = true
			count := len(strings.Split(strings.TrimSpace(unpushed), "\n"))
			reasons = append(reasons, fmt.Sprintf("%s: %d unpushed commits", entry.Name(), count))
		}
	}
	return dirty, reasons
}
// gitOutput runs `git` with the given args in dir and returns whatever was
// captured on stdout. The exit status is deliberately discarded: callers
// treat empty output as "nothing to report".
func gitOutput(dir string, args ...string) string {
	gitCmd := exec.Command("git", args...)
	gitCmd.Dir = dir
	stdout, _ := gitCmd.Output() // best-effort; failure yields empty/partial output
	return string(stdout)
}
// registryRepoNames returns the names of all cloneable, non-meta repos from
// the workspace registry (repos.yaml). Note: the root parameter is
// currently unused — the registry is located via coreio.Local.
func registryRepoNames(root string) ([]string, error) {
	// Try to find repos.yaml
	regPath, err := repos.FindRegistry(coreio.Local)
	if err != nil {
		return nil, err
	}
	reg, err := repos.LoadRegistry(coreio.Local, regPath)
	if err != nil {
		return nil, err
	}
	var names []string
	for _, repo := range reg.List() {
		// Only include cloneable repos (a nil Clone means cloneable)
		if repo.Clone != nil && !*repo.Clone {
			continue
		}
		// Skip meta repos
		if repo.Type == "meta" {
			continue
		}
		names = append(names, repo.Name)
	}
	return names, nil
}
// epicBranchName returns the git branch used for an EPIC, e.g. "epic/101".
func epicBranchName(id int) string {
	branch := "epic/" + strconv.Itoa(id)
	return branch
}

View file

@ -1,109 +0,0 @@
package workspace
import (
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupTestRepo initialises a throwaway git repository with a single empty
// commit under dir/name and returns its path.
func setupTestRepo(t *testing.T, dir, name string) string {
	t.Helper()
	repoPath := filepath.Join(dir, name)
	require.NoError(t, os.MkdirAll(repoPath, 0755))
	run := func(argv []string) {
		cmd := exec.Command(argv[0], argv[1:]...)
		cmd.Dir = repoPath
		out, err := cmd.CombinedOutput()
		require.NoError(t, err, "cmd %v failed: %s", argv, string(out))
	}
	run([]string{"git", "init"})
	run([]string{"git", "config", "user.email", "test@test.com"})
	run([]string{"git", "config", "user.name", "Test"})
	run([]string{"git", "commit", "--allow-empty", "-m", "initial"})
	return repoPath
}
// TestTaskWorkspacePath verifies the on-disk layout of a task workspace.
func TestTaskWorkspacePath(t *testing.T) {
	got := taskWorkspacePath("/home/user/Code/host-uk", 101, 343)
	assert.Equal(t, "/home/user/Code/host-uk/.core/workspace/p101/i343", got)
}
// TestCreateWorktree_Good verifies that createWorktree adds a worktree on a
// fresh branch and that the worktree checks out that branch.
func TestCreateWorktree_Good(t *testing.T) {
	tmp := t.TempDir()
	repoPath := setupTestRepo(t, tmp, "test-repo")
	worktreePath := filepath.Join(tmp, "workspace", "test-repo")
	err := createWorktree(t.Context(), repoPath, worktreePath, "issue/123")
	require.NoError(t, err)
	// Verify worktree exists (.git here is a file pointing back at the
	// main repository, hence FileExists rather than DirExists)
	assert.DirExists(t, worktreePath)
	assert.FileExists(t, filepath.Join(worktreePath, ".git"))
	// Verify branch
	branch := gitOutput(worktreePath, "rev-parse", "--abbrev-ref", "HEAD")
	assert.Equal(t, "issue/123", trimNL(branch))
}
// TestCreateWorktree_BranchExists exercises the fallback path: when the
// branch already exists, createWorktree retries without -b and reuses it.
func TestCreateWorktree_BranchExists(t *testing.T) {
	tmp := t.TempDir()
	repoPath := setupTestRepo(t, tmp, "test-repo")
	// Create branch first
	cmd := exec.Command("git", "branch", "issue/456")
	cmd.Dir = repoPath
	require.NoError(t, cmd.Run())
	worktreePath := filepath.Join(tmp, "workspace", "test-repo")
	err := createWorktree(t.Context(), repoPath, worktreePath, "issue/456")
	require.NoError(t, err)
	assert.DirExists(t, worktreePath)
}
func TestCheckWorkspaceSafety_Clean(t *testing.T) {
tmp := t.TempDir()
wsPath := filepath.Join(tmp, "workspace")
require.NoError(t, os.MkdirAll(wsPath, 0755))
repoPath := setupTestRepo(t, tmp, "origin-repo")
worktreePath := filepath.Join(wsPath, "origin-repo")
require.NoError(t, createWorktree(t.Context(), repoPath, worktreePath, "test-branch"))
dirty, reasons := checkWorkspaceSafety(wsPath)
assert.False(t, dirty)
assert.Empty(t, reasons)
}
func TestCheckWorkspaceSafety_Dirty(t *testing.T) {
tmp := t.TempDir()
wsPath := filepath.Join(tmp, "workspace")
require.NoError(t, os.MkdirAll(wsPath, 0755))
repoPath := setupTestRepo(t, tmp, "origin-repo")
worktreePath := filepath.Join(wsPath, "origin-repo")
require.NoError(t, createWorktree(t.Context(), repoPath, worktreePath, "test-branch"))
// Create uncommitted file
require.NoError(t, os.WriteFile(filepath.Join(worktreePath, "dirty.txt"), []byte("dirty"), 0644))
dirty, reasons := checkWorkspaceSafety(wsPath)
assert.True(t, dirty)
assert.Contains(t, reasons[0], "uncommitted changes")
}
func TestEpicBranchName(t *testing.T) {
assert.Equal(t, "epic/101", epicBranchName(101))
assert.Equal(t, "epic/42", epicBranchName(42))
}
// trimNL strips a single trailing newline from s, typically the output of
// a git command. It is safe on empty strings and leaves s unchanged when
// it does not end with a newline (the previous version panicked on ""
// and unconditionally dropped the last byte).
func trimNL(s string) string {
	if n := len(s); n > 0 && s[n-1] == '\n' {
		return s[:n-1]
	}
	return s
}

View file

@ -1,90 +0,0 @@
package workspace
import (
"strings"
"forge.lthn.ai/core/cli/pkg/cli"
)
// AddWorkspaceCommands registers workspace management commands.
func AddWorkspaceCommands(root *cli.Command) {
	workspaceCmd := &cli.Command{
		Use:   "workspace",
		Short: "Manage workspace configuration",
		RunE:  runWorkspaceInfo,
	}
	activeCmd := &cli.Command{
		Use:   "active [package]",
		Short: "Show or set the active package",
		RunE:  runWorkspaceActive,
	}
	workspaceCmd.AddCommand(activeCmd)
	addTaskCommands(workspaceCmd)
	addPrepCommands(workspaceCmd)
	root.AddCommand(workspaceCmd)
}
// runWorkspaceInfo prints the active package, the packages directory, and
// (when configured) the default package types of the current workspace.
func runWorkspaceInfo(cmd *cli.Command, args []string) error {
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	cfg, err := LoadConfig(root)
	if err != nil {
		return err
	}
	if cfg == nil {
		return cli.Err("workspace config not found")
	}
	cli.Print("Active: %s\n", cli.ValueStyle.Render(cfg.Active))
	cli.Print("Packages: %s\n", cli.DimStyle.Render(cfg.PackagesDir))
	if len(cfg.DefaultOnly) > 0 {
		cli.Print("Types: %s\n", cli.DimStyle.Render(strings.Join(cfg.DefaultOnly, ", ")))
	}
	return nil
}
// runWorkspaceActive shows the currently active package when invoked with
// no arguments, or sets and persists it when given a package name.
func runWorkspaceActive(cmd *cli.Command, args []string) error {
	root, err := FindWorkspaceRoot()
	if err != nil {
		return cli.Err("not in a workspace")
	}
	cfg, err := LoadConfig(root)
	if err != nil {
		return err
	}
	if cfg == nil {
		cfg = DefaultConfig()
	}
	if len(args) == 0 {
		// Display mode: report the current active package.
		if cfg.Active == "" {
			cli.Println("No active package set")
		} else {
			cli.Text(cfg.Active)
		}
		return nil
	}
	// Set mode: persist the requested package unless it is already active.
	target := args[0]
	if target == cfg.Active {
		cli.Print("Active package is already %s\n", cli.ValueStyle.Render(target))
		return nil
	}
	cfg.Active = target
	if err := SaveConfig(root, cfg); err != nil {
		return err
	}
	cli.Print("Active package set to %s\n", cli.SuccessStyle.Render(target))
	return nil
}

View file

@ -1,103 +0,0 @@
package workspace
import (
"os"
"path/filepath"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
"gopkg.in/yaml.v3"
)
// WorkspaceConfig holds workspace-level configuration from .core/workspace.yaml.
type WorkspaceConfig struct {
	Version     int      `yaml:"version"`      // Config schema version; only 1 is accepted by LoadConfig.
	Active      string   `yaml:"active"`       // Active package name
	DefaultOnly []string `yaml:"default_only"` // Default types for setup
	PackagesDir string   `yaml:"packages_dir"` // Where packages are cloned
}
// DefaultConfig returns a fresh configuration populated with defaults:
// schema version 1 and packages cloned under ./packages.
func DefaultConfig() *WorkspaceConfig {
	cfg := WorkspaceConfig{Version: 1, PackagesDir: "./packages"}
	return &cfg
}
// LoadConfig tries to load workspace.yaml from the given directory's .core
// subfolder, walking up parent directories until one is found.
// Returns (nil, nil) when no config file exists anywhere up the tree.
func LoadConfig(dir string) (*WorkspaceConfig, error) {
	path := filepath.Join(dir, ".core", "workspace.yaml")
	data, err := coreio.Local.Read(path)
	if err != nil {
		// coreio.Local.Read does not expose a typed not-found error, so we
		// distinguish "file missing" from "read failure" by re-checking
		// existence explicitly with IsFile.
		if !coreio.Local.IsFile(path) {
			// Not found here — recurse into the parent directory; the walk
			// terminates at the filesystem root, where Dir(dir) == dir.
			parent := filepath.Dir(dir)
			if parent != dir {
				return LoadConfig(parent)
			}
			// No workspace.yaml found anywhere - return nil to indicate no config
			return nil, nil
		}
		return nil, coreerr.E("LoadConfig", "failed to read workspace config", err)
	}
	// Unmarshal on top of the defaults so fields absent from the YAML keep
	// their default values.
	config := DefaultConfig()
	if err := yaml.Unmarshal([]byte(data), config); err != nil {
		return nil, coreerr.E("LoadConfig", "failed to parse workspace config", err)
	}
	if config.Version != 1 {
		return nil, coreerr.E("LoadConfig", "unsupported workspace config version", nil)
	}
	return config, nil
}
// SaveConfig saves the configuration to the given directory's .core/workspace.yaml,
// creating the .core directory if needed.
func SaveConfig(dir string, config *WorkspaceConfig) error {
	coreDir := filepath.Join(dir, ".core")
	if err := coreio.Local.EnsureDir(coreDir); err != nil {
		return coreerr.E("SaveConfig", "failed to create .core directory", err)
	}
	data, err := yaml.Marshal(config)
	if err != nil {
		return coreerr.E("SaveConfig", "failed to marshal workspace config", err)
	}
	target := filepath.Join(coreDir, "workspace.yaml")
	if err := coreio.Local.Write(target, string(data)); err != nil {
		return coreerr.E("SaveConfig", "failed to write workspace config", err)
	}
	return nil
}
// FindWorkspaceRoot searches upward from the working directory for the
// root directory containing .core/workspace.yaml.
func FindWorkspaceRoot() (string, error) {
	cwd, err := os.Getwd()
	if err != nil {
		return "", err
	}
	// Walk toward the filesystem root, where filepath.Dir returns its
	// argument unchanged.
	for dir := cwd; ; dir = filepath.Dir(dir) {
		if coreio.Local.IsFile(filepath.Join(dir, ".core", "workspace.yaml")) {
			return dir, nil
		}
		if filepath.Dir(dir) == dir {
			return "", coreerr.E("FindWorkspaceRoot", "not in a workspace", nil)
		}
	}
}

View file

@ -1,395 +0,0 @@
package jobrunner
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Journal: NewJournal error path ---

// TestNewJournal_Bad_EmptyBaseDir: an empty base directory must be rejected.
func TestNewJournal_Bad_EmptyBaseDir(t *testing.T) {
	_, err := NewJournal("")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "base directory is required")
}

// TestNewJournal_Good: a valid directory yields a non-nil journal.
func TestNewJournal_Good(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	assert.NotNil(t, j)
}

// --- Journal: sanitizePathComponent additional cases ---

// TestSanitizePathComponent_Good_ValidNames: safe names pass unchanged.
func TestSanitizePathComponent_Good_ValidNames(t *testing.T) {
	tests := []struct {
		input string
		want  string
	}{
		{"host-uk", "host-uk"},
		{"core", "core"},
		{"my_repo", "my_repo"},
		{"repo.v2", "repo.v2"},
		{"A123", "A123"},
	}
	for _, tc := range tests {
		got, err := sanitizePathComponent(tc.input)
		require.NoError(t, err, "input: %q", tc.input)
		assert.Equal(t, tc.want, got)
	}
}

// TestSanitizePathComponent_Bad_Invalid: path separators, traversal
// sequences, leading dots and other unsafe names are rejected.
func TestSanitizePathComponent_Bad_Invalid(t *testing.T) {
	tests := []struct {
		name  string
		input string
	}{
		{"empty", ""},
		{"spaces", " "},
		{"dotdot", ".."},
		{"dot", "."},
		{"slash", "foo/bar"},
		{"backslash", `foo\bar`},
		{"special", "org$bad"},
		{"leading-dot", ".hidden"},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := sanitizePathComponent(tc.input)
			assert.Error(t, err, "input: %q", tc.input)
		})
	}
}
// --- Journal: Append with readonly directory ---

// TestJournal_Append_Bad_ReadonlyDir: Append must fail when the journal's
// base directory is not writable. Skipped for root, which ignores chmod.
func TestJournal_Append_Bad_ReadonlyDir(t *testing.T) {
	if os.Getuid() == 0 {
		t.Skip("chmod does not restrict root")
	}
	// Create a dir that we then make readonly (only works as non-root).
	dir := t.TempDir()
	readonlyDir := filepath.Join(dir, "readonly")
	require.NoError(t, os.MkdirAll(readonlyDir, 0o755))
	require.NoError(t, os.Chmod(readonlyDir, 0o444))
	t.Cleanup(func() { _ = os.Chmod(readonlyDir, 0o755) })
	j, err := NewJournal(readonlyDir)
	require.NoError(t, err)
	signal := &PipelineSignal{
		RepoOwner: "test-owner",
		RepoName:  "test-repo",
	}
	result := &ActionResult{
		Action:    "test",
		Timestamp: time.Now(),
	}
	err = j.Append(signal, result)
	// Should fail because MkdirAll cannot create subdirectories in readonly dir.
	assert.Error(t, err)
}

// --- Poller: error-returning source ---

// errorSource is a JobSource stub whose Poll always fails.
type errorSource struct {
	name string
}

func (e *errorSource) Name() string { return e.name }
func (e *errorSource) Poll(_ context.Context) ([]*PipelineSignal, error) {
	return nil, fmt.Errorf("poll error")
}
func (e *errorSource) Report(_ context.Context, _ *ActionResult) error { return nil }

// TestPoller_RunOnce_Good_SourceError: a failing source is logged and
// skipped; handlers must never run for it.
func TestPoller_RunOnce_Good_SourceError(t *testing.T) {
	src := &errorSource{name: "broken-source"}
	handler := &mockHandler{name: "test"}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
	})
	err := p.RunOnce(context.Background())
	require.NoError(t, err) // Poll errors are logged, not returned
	handler.mu.Lock()
	defer handler.mu.Unlock()
	assert.Empty(t, handler.executed, "handler should not be called when poll fails")
}
// --- Poller: error-returning handler ---

// errorHandler is a JobHandler stub that matches every signal and always
// fails in Execute.
type errorHandler struct {
	name string
}

func (e *errorHandler) Name() string                 { return e.name }
func (e *errorHandler) Match(_ *PipelineSignal) bool { return true }
func (e *errorHandler) Execute(_ context.Context, _ *PipelineSignal) (*ActionResult, error) {
	return nil, fmt.Errorf("handler error")
}

// TestPoller_RunOnce_Good_HandlerError: a failing handler is logged and no
// result is reported back to the source.
func TestPoller_RunOnce_Good_HandlerError(t *testing.T) {
	sig := &PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 1,
		PRNumber:    1,
		RepoOwner:   "test",
		RepoName:    "repo",
	}
	src := &mockSource{
		name:    "test-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &errorHandler{name: "broken-handler"}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
	})
	err := p.RunOnce(context.Background())
	require.NoError(t, err) // Handler errors are logged, not returned
	// Source should not have received a report (handler errored out).
	src.mu.Lock()
	defer src.mu.Unlock()
	assert.Empty(t, src.reports)
}
// --- Poller: with Journal integration ---

// TestPoller_RunOnce_Good_WithJournal: a handled signal must be appended
// to the journal under <base>/<owner>/<repo>/<UTC date>.jsonl.
func TestPoller_RunOnce_Good_WithJournal(t *testing.T) {
	dir := t.TempDir()
	journal, err := NewJournal(dir)
	require.NoError(t, err)
	sig := &PipelineSignal{
		EpicNumber:  10,
		ChildNumber: 3,
		PRNumber:    55,
		RepoOwner:   "host-uk",
		RepoName:    "core",
		PRState:     "OPEN",
		CheckStatus: "SUCCESS",
		Mergeable:   "MERGEABLE",
	}
	src := &mockSource{
		name:    "test-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &mockHandler{
		name: "test-handler",
		matchFn: func(s *PipelineSignal) bool {
			return true
		},
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
		Journal:  journal,
	})
	err = p.RunOnce(context.Background())
	require.NoError(t, err)
	handler.mu.Lock()
	require.Len(t, handler.executed, 1)
	handler.mu.Unlock()
	// Verify the journal file was written.
	date := time.Now().UTC().Format("2006-01-02")
	journalPath := filepath.Join(dir, "host-uk", "core", date+".jsonl")
	_, statErr := os.Stat(journalPath)
	assert.NoError(t, statErr, "journal file should exist at %s", journalPath)
}
// --- Poller: error-returning Report ---

// reportErrorSource is a JobSource stub whose Report always fails; Poll
// returns the canned signals under a mutex.
type reportErrorSource struct {
	name    string
	signals []*PipelineSignal
	mu      sync.Mutex
}

func (r *reportErrorSource) Name() string { return r.name }
func (r *reportErrorSource) Poll(_ context.Context) ([]*PipelineSignal, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.signals, nil
}
func (r *reportErrorSource) Report(_ context.Context, _ *ActionResult) error {
	return fmt.Errorf("report error")
}

// TestPoller_RunOnce_Good_ReportError: a Report failure is logged but the
// handler still executes.
func TestPoller_RunOnce_Good_ReportError(t *testing.T) {
	sig := &PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 1,
		PRNumber:    1,
		RepoOwner:   "test",
		RepoName:    "repo",
	}
	src := &reportErrorSource{
		name:    "report-fail-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &mockHandler{
		name:    "test-handler",
		matchFn: func(s *PipelineSignal) bool { return true },
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
	})
	err := p.RunOnce(context.Background())
	require.NoError(t, err) // Report errors are logged, not returned
	handler.mu.Lock()
	defer handler.mu.Unlock()
	assert.Len(t, handler.executed, 1, "handler should still execute even though report fails")
}
// --- Poller: multiple sources and handlers ---

// TestPoller_RunOnce_Good_MultipleSources: signals from every configured
// source must reach a matching handler.
func TestPoller_RunOnce_Good_MultipleSources(t *testing.T) {
	sig1 := &PipelineSignal{
		EpicNumber: 1, ChildNumber: 1, PRNumber: 1,
		RepoOwner: "org1", RepoName: "repo1",
	}
	sig2 := &PipelineSignal{
		EpicNumber: 2, ChildNumber: 2, PRNumber: 2,
		RepoOwner: "org2", RepoName: "repo2",
	}
	src1 := &mockSource{name: "source-1", signals: []*PipelineSignal{sig1}}
	src2 := &mockSource{name: "source-2", signals: []*PipelineSignal{sig2}}
	handler := &mockHandler{
		name:    "catch-all",
		matchFn: func(s *PipelineSignal) bool { return true },
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src1, src2},
		Handlers: []JobHandler{handler},
	})
	err := p.RunOnce(context.Background())
	require.NoError(t, err)
	handler.mu.Lock()
	defer handler.mu.Unlock()
	assert.Len(t, handler.executed, 2)
}

// --- Poller: Run with immediate cancellation ---

// TestPoller_Run_Good_ImmediateCancel: Run performs one initial cycle and
// returns context.Canceled once the context is cancelled.
func TestPoller_Run_Good_ImmediateCancel(t *testing.T) {
	src := &mockSource{name: "source", signals: nil}
	p := NewPoller(PollerConfig{
		Sources:      []JobSource{src},
		PollInterval: 1 * time.Hour, // long interval
	})
	ctx, cancel := context.WithCancel(context.Background())
	// Cancel after the first RunOnce completes.
	go func() {
		time.Sleep(50 * time.Millisecond)
		cancel()
	}()
	err := p.Run(ctx)
	assert.ErrorIs(t, err, context.Canceled)
	assert.Equal(t, 1, p.Cycle()) // One cycle from the initial RunOnce
}
// --- Journal: Append with journal error logging ---

// TestPoller_RunOnce_Good_JournalAppendError: journal write failures are
// logged but must not block handler execution. Skipped for root, which
// ignores chmod restrictions.
func TestPoller_RunOnce_Good_JournalAppendError(t *testing.T) {
	if os.Getuid() == 0 {
		t.Skip("chmod does not restrict root")
	}
	// Use a directory that will cause journal writes to fail.
	dir := t.TempDir()
	journal, err := NewJournal(dir)
	require.NoError(t, err)
	// Make the journal directory read-only to trigger append errors.
	require.NoError(t, os.Chmod(dir, 0o444))
	t.Cleanup(func() { _ = os.Chmod(dir, 0o755) })
	sig := &PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 1,
		PRNumber:    1,
		RepoOwner:   "test",
		RepoName:    "repo",
	}
	src := &mockSource{
		name:    "test-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &mockHandler{
		name:    "test-handler",
		matchFn: func(s *PipelineSignal) bool { return true },
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
		Journal:  journal,
	})
	err = p.RunOnce(context.Background())
	// Journal errors are logged, not returned.
	require.NoError(t, err)
	handler.mu.Lock()
	defer handler.mu.Unlock()
	assert.Len(t, handler.executed, 1, "handler should still execute even when journal fails")
}

// --- Poller: Cycle counter increments ---

// TestPoller_Cycle_Good_Increments: Cycle counts completed RunOnce calls.
func TestPoller_Cycle_Good_Increments(t *testing.T) {
	src := &mockSource{name: "source", signals: nil}
	p := NewPoller(PollerConfig{
		Sources: []JobSource{src},
	})
	assert.Equal(t, 0, p.Cycle())
	_ = p.RunOnce(context.Background())
	assert.Equal(t, 1, p.Cycle())
	_ = p.RunOnce(context.Background())
	assert.Equal(t, 2, p.Cycle())
}

View file

@ -1,114 +0,0 @@
package forgejo
import (
"regexp"
"strconv"
forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// epicChildRe matches checklist items: - [ ] #42 or - [x] #42
var epicChildRe = regexp.MustCompile(`- \[([ x])\] #(\d+)`)

// parseEpicChildren extracts child issue numbers from an epic body's
// checklist. Checked items ("[x]") land in checked, unchecked ("[ ]") in
// unchecked, preserving document order.
func parseEpicChildren(body string) (unchecked []int, checked []int) {
	for _, match := range epicChildRe.FindAllStringSubmatch(body, -1) {
		num, err := strconv.Atoi(match[2])
		if err != nil {
			continue
		}
		switch match[1] {
		case "x":
			checked = append(checked, num)
		default:
			unchecked = append(unchecked, num)
		}
	}
	return unchecked, checked
}
// linkedPRRe matches "#N" references in PR bodies.
var linkedPRRe = regexp.MustCompile(`#(\d+)`)

// findLinkedPR finds the first PR whose body references the given issue
// number, or nil when no PR mentions it.
func findLinkedPR(prs []*forgejosdk.PullRequest, issueNumber int) *forgejosdk.PullRequest {
	want := strconv.Itoa(issueNumber)
	for _, pr := range prs {
		for _, ref := range linkedPRRe.FindAllStringSubmatch(pr.Body, -1) {
			if ref[1] == want {
				return pr
			}
		}
	}
	return nil
}
// mapPRState maps Forgejo's PR state and merged flag to a canonical
// string: MERGED, OPEN, or CLOSED (also the fallback for unknown states).
func mapPRState(pr *forgejosdk.PullRequest) string {
	switch {
	case pr.HasMerged:
		return "MERGED"
	case pr.State == forgejosdk.StateOpen:
		return "OPEN"
	default:
		return "CLOSED"
	}
}
// mapMergeable maps Forgejo's boolean Mergeable field to a canonical
// string; merged PRs report UNKNOWN since mergeability no longer applies.
func mapMergeable(pr *forgejosdk.PullRequest) string {
	switch {
	case pr.HasMerged:
		return "UNKNOWN"
	case pr.Mergeable:
		return "MERGEABLE"
	default:
		return "CONFLICTING"
	}
}
// mapCombinedStatus maps a Forgejo CombinedStatus to SUCCESS/FAILURE/PENDING.
// A nil status or one with no checks counts as PENDING.
func mapCombinedStatus(cs *forgejosdk.CombinedStatus) string {
	if cs == nil || cs.TotalCount == 0 {
		return "PENDING"
	}
	switch cs.State {
	case forgejosdk.StatusSuccess:
		return "SUCCESS"
	case forgejosdk.StatusFailure, forgejosdk.StatusError:
		return "FAILURE"
	}
	return "PENDING"
}
// buildSignal creates a PipelineSignal from Forgejo API data, mapping the
// PR's state, mergeability, and head SHA into canonical signal fields.
func buildSignal(
	owner, repo string,
	epicNumber, childNumber int,
	pr *forgejosdk.PullRequest,
	checkStatus string,
) *jobrunner.PipelineSignal {
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  epicNumber,
		ChildNumber: childNumber,
		PRNumber:    int(pr.Index),
		RepoOwner:   owner,
		RepoName:    repo,
		PRState:     mapPRState(pr),
		IsDraft:     false, // SDK v2.2.0 doesn't expose Draft; treat as non-draft
		Mergeable:   mapMergeable(pr),
		CheckStatus: checkStatus,
	}
	if head := pr.Head; head != nil {
		signal.LastCommitSHA = head.Sha
	}
	return signal
}

View file

@ -1,205 +0,0 @@
package forgejo
import (
"testing"
forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
"github.com/stretchr/testify/assert"
)
// mapPRState: open, merged, closed, and unknown states.
func TestMapPRState_Good_Open(t *testing.T) {
	pr := &forgejosdk.PullRequest{State: forgejosdk.StateOpen, HasMerged: false}
	assert.Equal(t, "OPEN", mapPRState(pr))
}

func TestMapPRState_Good_Merged(t *testing.T) {
	pr := &forgejosdk.PullRequest{State: forgejosdk.StateClosed, HasMerged: true}
	assert.Equal(t, "MERGED", mapPRState(pr))
}

func TestMapPRState_Good_Closed(t *testing.T) {
	pr := &forgejosdk.PullRequest{State: forgejosdk.StateClosed, HasMerged: false}
	assert.Equal(t, "CLOSED", mapPRState(pr))
}

func TestMapPRState_Good_UnknownState(t *testing.T) {
	// Any unknown state should default to CLOSED.
	pr := &forgejosdk.PullRequest{State: "weird", HasMerged: false}
	assert.Equal(t, "CLOSED", mapPRState(pr))
}

// mapMergeable: mergeable, conflicting, and already-merged PRs.
func TestMapMergeable_Good_Mergeable(t *testing.T) {
	pr := &forgejosdk.PullRequest{Mergeable: true, HasMerged: false}
	assert.Equal(t, "MERGEABLE", mapMergeable(pr))
}

func TestMapMergeable_Good_Conflicting(t *testing.T) {
	pr := &forgejosdk.PullRequest{Mergeable: false, HasMerged: false}
	assert.Equal(t, "CONFLICTING", mapMergeable(pr))
}

func TestMapMergeable_Good_Merged(t *testing.T) {
	pr := &forgejosdk.PullRequest{HasMerged: true}
	assert.Equal(t, "UNKNOWN", mapMergeable(pr))
}

// mapCombinedStatus: success/failure/error/pending plus the nil and
// zero-count edge cases, which both report PENDING.
func TestMapCombinedStatus_Good_Success(t *testing.T) {
	cs := &forgejosdk.CombinedStatus{
		State:      forgejosdk.StatusSuccess,
		TotalCount: 1,
	}
	assert.Equal(t, "SUCCESS", mapCombinedStatus(cs))
}

func TestMapCombinedStatus_Good_Failure(t *testing.T) {
	cs := &forgejosdk.CombinedStatus{
		State:      forgejosdk.StatusFailure,
		TotalCount: 1,
	}
	assert.Equal(t, "FAILURE", mapCombinedStatus(cs))
}

func TestMapCombinedStatus_Good_Error(t *testing.T) {
	cs := &forgejosdk.CombinedStatus{
		State:      forgejosdk.StatusError,
		TotalCount: 1,
	}
	assert.Equal(t, "FAILURE", mapCombinedStatus(cs))
}

func TestMapCombinedStatus_Good_Pending(t *testing.T) {
	cs := &forgejosdk.CombinedStatus{
		State:      forgejosdk.StatusPending,
		TotalCount: 1,
	}
	assert.Equal(t, "PENDING", mapCombinedStatus(cs))
}

func TestMapCombinedStatus_Good_Nil(t *testing.T) {
	assert.Equal(t, "PENDING", mapCombinedStatus(nil))
}

func TestMapCombinedStatus_Good_ZeroCount(t *testing.T) {
	cs := &forgejosdk.CombinedStatus{
		State:      forgejosdk.StatusSuccess,
		TotalCount: 0,
	}
	assert.Equal(t, "PENDING", mapCombinedStatus(cs))
}
// parseEpicChildren: mixed, empty, all-checked and all-unchecked bodies.
func TestParseEpicChildren_Good_Mixed(t *testing.T) {
	body := "## Sprint\n- [x] #1\n- [ ] #2\n- [x] #3\n- [ ] #4\nSome text\n"
	unchecked, checked := parseEpicChildren(body)
	assert.Equal(t, []int{2, 4}, unchecked)
	assert.Equal(t, []int{1, 3}, checked)
}

func TestParseEpicChildren_Good_NoCheckboxes(t *testing.T) {
	body := "This is just a normal issue with no checkboxes."
	unchecked, checked := parseEpicChildren(body)
	assert.Nil(t, unchecked)
	assert.Nil(t, checked)
}

func TestParseEpicChildren_Good_AllChecked(t *testing.T) {
	body := "- [x] #10\n- [x] #20\n"
	unchecked, checked := parseEpicChildren(body)
	assert.Nil(t, unchecked)
	assert.Equal(t, []int{10, 20}, checked)
}

func TestParseEpicChildren_Good_AllUnchecked(t *testing.T) {
	body := "- [ ] #5\n- [ ] #6\n"
	unchecked, checked := parseEpicChildren(body)
	assert.Equal(t, []int{5, 6}, unchecked)
	assert.Nil(t, checked)
}

// findLinkedPR: first matching PR wins; nil for no match or nil input.
func TestFindLinkedPR_Good(t *testing.T) {
	prs := []*forgejosdk.PullRequest{
		{Index: 10, Body: "Fixes #5"},
		{Index: 11, Body: "Resolves #7"},
		{Index: 12, Body: "Nothing here"},
	}
	pr := findLinkedPR(prs, 7)
	assert.NotNil(t, pr)
	assert.Equal(t, int64(11), pr.Index)
}

func TestFindLinkedPR_Good_NotFound(t *testing.T) {
	prs := []*forgejosdk.PullRequest{
		{Index: 10, Body: "Fixes #5"},
	}
	pr := findLinkedPR(prs, 99)
	assert.Nil(t, pr)
}

func TestFindLinkedPR_Good_Nil(t *testing.T) {
	pr := findLinkedPR(nil, 1)
	assert.Nil(t, pr)
}

// buildSignal: full field mapping from an open mergeable PR.
func TestBuildSignal_Good(t *testing.T) {
	pr := &forgejosdk.PullRequest{
		Index:     42,
		State:     forgejosdk.StateOpen,
		Mergeable: true,
		Head:      &forgejosdk.PRBranchInfo{Sha: "deadbeef"},
	}
	sig := buildSignal("org", "repo", 10, 5, pr, "SUCCESS")
	assert.Equal(t, 10, sig.EpicNumber)
	assert.Equal(t, 5, sig.ChildNumber)
	assert.Equal(t, 42, sig.PRNumber)
	assert.Equal(t, "org", sig.RepoOwner)
	assert.Equal(t, "repo", sig.RepoName)
	assert.Equal(t, "OPEN", sig.PRState)
	assert.Equal(t, "MERGEABLE", sig.Mergeable)
	assert.Equal(t, "SUCCESS", sig.CheckStatus)
	assert.Equal(t, "deadbeef", sig.LastCommitSHA)
	assert.False(t, sig.IsDraft)
}

// buildSignal: a nil Head must leave LastCommitSHA empty.
func TestBuildSignal_Good_NilHead(t *testing.T) {
	pr := &forgejosdk.PullRequest{
		Index:     1,
		State:     forgejosdk.StateClosed,
		HasMerged: true,
	}
	sig := buildSignal("org", "repo", 1, 2, pr, "PENDING")
	assert.Equal(t, "", sig.LastCommitSHA)
	assert.Equal(t, "MERGED", sig.PRState)
}
// TestSplitRepo_Good: table test for "owner/repo" parsing, including the
// rejection of missing, empty-owner, and empty-repo inputs.
func TestSplitRepo_Good(t *testing.T) {
	tests := []struct {
		input string
		owner string
		repo  string
		err   bool
	}{
		{"host-uk/core", "host-uk", "core", false},
		{"a/b", "a", "b", false},
		{"org/repo-name", "org", "repo-name", false},
		{"invalid", "", "", true},
		{"", "", "", true},
		{"/repo", "", "", true},
		{"owner/", "", "", true},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			owner, repo, err := splitRepo(tt.input)
			if tt.err {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.owner, owner)
				assert.Equal(t, tt.repo, repo)
			}
		})
	}
}

View file

@ -1,173 +0,0 @@
package forgejo
import (
"context"
"fmt"
"strings"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
"forge.lthn.ai/core/go-log"
)
// Config configures a ForgejoSource.
type Config struct {
	Repos []string // "owner/repo" format
}

// ForgejoSource polls a Forgejo instance for pipeline signals from epic issues.
type ForgejoSource struct {
	repos []string      // repositories to poll, in "owner/repo" format
	forge *forge.Client // API client used for all Forgejo calls
}

// New creates a ForgejoSource using the given forge client.
func New(cfg Config, client *forge.Client) *ForgejoSource {
	return &ForgejoSource{
		repos: cfg.Repos,
		forge: client,
	}
}

// Name returns the source identifier.
func (s *ForgejoSource) Name() string {
	return "forgejo"
}
// Poll fetches epics and their linked PRs from all configured repositories,
// returning a PipelineSignal for each unchecked child that has a linked PR.
// Per-repository failures are logged and skipped so one bad repo does not
// block the rest.
func (s *ForgejoSource) Poll(ctx context.Context) ([]*jobrunner.PipelineSignal, error) {
	var out []*jobrunner.PipelineSignal
	for _, full := range s.repos {
		owner, repo, err := splitRepo(full)
		if err != nil {
			log.Error("invalid repo format", "repo", full, "err", err)
			continue
		}
		sigs, err := s.pollRepo(ctx, owner, repo)
		if err != nil {
			log.Error("poll repo failed", "repo", full, "err", err)
			continue
		}
		out = append(out, sigs...)
	}
	return out, nil
}
// Report posts the action result as a comment on the epic issue.
// A nil result is a no-op.
func (s *ForgejoSource) Report(ctx context.Context, result *jobrunner.ActionResult) error {
	if result == nil {
		return nil
	}
	outcome := "failed"
	if result.Success {
		outcome = "succeeded"
	}
	body := fmt.Sprintf("**jobrunner** `%s` %s for #%d (PR #%d)", result.Action, outcome, result.ChildNumber, result.PRNumber)
	if result.Error != "" {
		body += fmt.Sprintf("\n\n```\n%s\n```", result.Error)
	}
	return s.forge.CreateIssueComment(result.RepoOwner, result.RepoName, int64(result.EpicNumber), body)
}
// pollRepo fetches epics and PRs for a single repository and builds one
// signal per actionable child: children with a linked PR get a full
// PR-state signal; assigned children without a PR get a NeedsCoding signal.
func (s *ForgejoSource) pollRepo(_ context.Context, owner, repo string) ([]*jobrunner.PipelineSignal, error) {
	// Fetch epic issues (label=epic, state=open).
	issues, err := s.forge.ListIssues(owner, repo, forge.ListIssuesOpts{State: "open"})
	if err != nil {
		return nil, log.E("forgejo.pollRepo", "fetch issues", err)
	}
	// Filter to epics only.
	var epics []epicInfo
	for _, issue := range issues {
		for _, label := range issue.Labels {
			if label.Name == "epic" {
				epics = append(epics, epicInfo{
					Number: int(issue.Index),
					Body:   issue.Body,
				})
				break
			}
		}
	}
	if len(epics) == 0 {
		return nil, nil
	}
	// Fetch all open PRs (and also merged/closed to catch MERGED state).
	prs, err := s.forge.ListPullRequests(owner, repo, "all")
	if err != nil {
		return nil, log.E("forgejo.pollRepo", "fetch PRs", err)
	}
	var signals []*jobrunner.PipelineSignal
	for _, epic := range epics {
		// Only unchecked checklist items are still in flight; checked ones
		// are done and need no signal.
		unchecked, _ := parseEpicChildren(epic.Body)
		for _, childNum := range unchecked {
			pr := findLinkedPR(prs, childNum)
			if pr == nil {
				// No PR yet — check if the child issue is assigned (needs coding).
				childIssue, err := s.forge.GetIssue(owner, repo, int64(childNum))
				if err != nil {
					log.Error("fetch child issue failed", "repo", owner+"/"+repo, "issue", childNum, "err", err)
					continue
				}
				// Only the first assignee is considered; unassigned children
				// produce no signal at all.
				if len(childIssue.Assignees) > 0 && childIssue.Assignees[0].UserName != "" {
					sig := &jobrunner.PipelineSignal{
						EpicNumber:  epic.Number,
						ChildNumber: childNum,
						RepoOwner:   owner,
						RepoName:    repo,
						NeedsCoding: true,
						Assignee:    childIssue.Assignees[0].UserName,
						IssueTitle:  childIssue.Title,
						IssueBody:   childIssue.Body,
					}
					signals = append(signals, sig)
				}
				continue
			}
			// Get combined commit status for the PR's head SHA. A status
			// fetch failure degrades to PENDING rather than dropping the
			// signal.
			checkStatus := "PENDING"
			if pr.Head != nil && pr.Head.Sha != "" {
				cs, err := s.forge.GetCombinedStatus(owner, repo, pr.Head.Sha)
				if err != nil {
					log.Error("fetch combined status failed", "repo", owner+"/"+repo, "sha", pr.Head.Sha, "err", err)
				} else {
					checkStatus = mapCombinedStatus(cs)
				}
			}
			sig := buildSignal(owner, repo, epic.Number, childNum, pr, checkStatus)
			signals = append(signals, sig)
		}
	}
	return signals, nil
}
// epicInfo is a lightweight view of an epic issue: its index number and
// the raw markdown body containing the child checklist.
type epicInfo struct {
	Number int
	Body   string
}
// splitRepo parses "owner/repo" into its components. Both parts must be
// non-empty; anything after the first slash belongs to the repo name.
func splitRepo(full string) (string, string, error) {
	owner, repo, ok := strings.Cut(full, "/")
	if !ok || owner == "" || repo == "" {
		return "", "", log.E("forgejo.splitRepo", fmt.Sprintf("expected owner/repo format, got %q", full), nil)
	}
	return owner, repo, nil
}

View file

@ -1,320 +0,0 @@
package forgejo
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// TestForgejoSource_Poll_Good_InvalidRepo: a malformed repo entry is
// logged and skipped; Poll still succeeds with no signals.
func TestForgejoSource_Poll_Good_InvalidRepo(t *testing.T) {
	// Invalid repo format should be logged and skipped, not error.
	s := New(Config{Repos: []string{"invalid-no-slash"}}, nil)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	assert.Empty(t, signals)
}

// TestForgejoSource_Poll_Good_MultipleRepos: each configured repo is
// polled against the stub server and yields its own signal.
func TestForgejoSource_Poll_Good_MultipleRepos(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			// Return one epic per repo.
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #2\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			prs := []map[string]any{
				{
					"number":    10,
					"body":      "Fixes #2",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "abc", "ref": "fix", "label": "fix"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		case strings.Contains(path, "/status"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"state":       "success",
				"total_count": 1,
				"statuses":    []any{},
			})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org-a/repo-1", "org-b/repo-2"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	assert.Len(t, signals, 2)
}
// TestForgejoSource_Poll_Good_NeedsCoding: an assigned child issue with no
// linked PR must yield a NeedsCoding signal carrying the assignee and
// issue title/body.
func TestForgejoSource_Poll_Good_NeedsCoding(t *testing.T) {
	// When a child issue has no linked PR but is assigned, NeedsCoding should be true.
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues/5"):
			// Child issue with assignee.
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number":    5,
				"title":     "Implement feature",
				"body":      "Please implement this.",
				"state":     "open",
				"assignees": []map[string]any{{"login": "darbs-claude", "username": "darbs-claude"}},
			})
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #5\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			// No PRs linked.
			_ = json.NewEncoder(w).Encode([]any{})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"test-org/test-repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	sig := signals[0]
	assert.True(t, sig.NeedsCoding)
	assert.Equal(t, "darbs-claude", sig.Assignee)
	assert.Equal(t, "Implement feature", sig.IssueTitle)
	assert.Equal(t, "Please implement this.", sig.IssueBody)
	assert.Equal(t, 5, sig.ChildNumber)
}
func TestForgejoSource_Poll_Good_MergedPR(t *testing.T) {
srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
w.Header().Set("Content-Type", "application/json")
switch {
case strings.Contains(path, "/issues"):
issues := []map[string]any{
{
"number": 1,
"body": "- [ ] #3\n",
"labels": []map[string]string{{"name": "epic"}},
"state": "open",
},
}
_ = json.NewEncoder(w).Encode(issues)
case strings.Contains(path, "/pulls"):
prs := []map[string]any{
{
"number": 20,
"body": "Fixes #3",
"state": "closed",
"mergeable": false,
"merged": true,
"head": map[string]string{"sha": "merged123", "ref": "fix", "label": "fix"},
},
}
_ = json.NewEncoder(w).Encode(prs)
case strings.Contains(path, "/status"):
_ = json.NewEncoder(w).Encode(map[string]any{
"state": "success",
"total_count": 1,
"statuses": []any{},
})
default:
w.WriteHeader(http.StatusOK)
}
})))
defer srv.Close()
client := newTestClient(t, srv.URL)
s := New(Config{Repos: []string{"org/repo"}}, client)
signals, err := s.Poll(context.Background())
require.NoError(t, err)
require.Len(t, signals, 1)
assert.Equal(t, "MERGED", signals[0].PRState)
assert.Equal(t, "UNKNOWN", signals[0].Mergeable)
}
// When the linked PR payload carries no head branch info there is no
// commit SHA to look up a combined status for, so the signal's
// CheckStatus remains PENDING.
func TestForgejoSource_Poll_Good_NoHeadSHA(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// One open epic with unchecked child #3.
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #3\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		// Open PR linked to #3 but deliberately missing its head field.
		case strings.Contains(path, "/pulls"):
			prs := []map[string]any{
				{
					"number":    20,
					"body":      "Fixes #3",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					// No head field.
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	// Without head SHA, check status stays PENDING.
	assert.Equal(t, "PENDING", signals[0].CheckStatus)
}
// Report must be a no-op (and return no error) when handed a nil result.
func TestForgejoSource_Report_Good_Nil(t *testing.T) {
	src := New(Config{}, nil)
	assert.NoError(t, src.Report(context.Background(), nil))
}
// Report on a failed ActionResult posts a comment whose body mentions
// both the failure and the underlying error string.
func TestForgejoSource_Report_Good_Failed(t *testing.T) {
	var capturedBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Capture the comment body that Report posts.
		var body map[string]string
		_ = json.NewDecoder(r.Body).Decode(&body)
		capturedBody = body["body"]
		_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{}, client)
	result := &jobrunner.ActionResult{
		Action:      "dispatch",
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  1,
		ChildNumber: 2,
		PRNumber:    3,
		Success:     false,
		Error:       "transfer failed",
	}
	err := s.Report(context.Background(), result)
	require.NoError(t, err)
	assert.Contains(t, capturedBody, "failed")
	assert.Contains(t, capturedBody, "transfer failed")
}
// When the issues API fails, poll should continue with other repos.
// Here every endpoint 500s, so Poll returns no error and no signals.
func TestForgejoSource_Poll_Good_APIErrors(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	assert.Empty(t, signals)
}
// Polling with no configured repos yields no signals and no error.
func TestForgejoSource_Poll_Good_EmptyRepos(t *testing.T) {
	src := New(Config{Repos: []string{}}, nil)
	got, err := src.Poll(context.Background())
	require.NoError(t, err)
	assert.Empty(t, got)
}
// Open issues that lack the "epic" label must be ignored entirely.
func TestForgejoSource_Poll_Good_NonEpicIssues(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			// Issues without the "epic" label.
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #2\n",
					"labels": []map[string]string{{"name": "bug"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	assert.Empty(t, signals, "non-epic issues should not generate signals")
}

View file

@ -1,672 +0,0 @@
package forgejo
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// --- Signal parsing and filtering tests ---
// An empty epic body yields nil slices for both unchecked and checked
// children.
func TestParseEpicChildren_Good_EmptyBody(t *testing.T) {
	open, done := parseEpicChildren("")
	assert.Nil(t, open)
	assert.Nil(t, done)
}
// Checkbox lines are extracted even when interleaved with headings,
// prose, and plain (non-checkbox) list items.
func TestParseEpicChildren_Good_MixedContent(t *testing.T) {
	epicBody := `## Epic: Refactor Auth
Some description of the epic.
### Tasks
- [x] #10 Migrate session store
- [ ] #11 Update OAuth flow
- [x] #12 Fix token refresh
- [ ] #13 Add 2FA support
### Notes
This is a note, not a checkbox.
- Regular list item
- Another item
`
	open, done := parseEpicChildren(epicBody)
	assert.Equal(t, []int{11, 13}, open)
	assert.Equal(t, []int{10, 12}, done)
}
// Four- and five-digit issue references parse without truncation.
func TestParseEpicChildren_Good_LargeIssueNumbers(t *testing.T) {
	open, done := parseEpicChildren("- [ ] #9999\n- [x] #10000\n")
	assert.Equal(t, []int{9999}, open)
	assert.Equal(t, []int{10000}, done)
}
// A run of consecutive unchecked boxes comes back in document order,
// with no checked children reported.
func TestParseEpicChildren_Good_ConsecutiveCheckboxes(t *testing.T) {
	open, done := parseEpicChildren("- [ ] #1\n- [ ] #2\n- [ ] #3\n- [ ] #4\n- [ ] #5\n")
	assert.Equal(t, []int{1, 2, 3, 4, 5}, open)
	assert.Nil(t, done)
}
// --- findLinkedPR tests ---
// A PR whose body references several issues matches a lookup for each
// of those issues.
func TestFindLinkedPR_Good_MultipleReferencesInBody(t *testing.T) {
	prs := []*forgejosdk.PullRequest{
		{Index: 10, Body: "Fixes #5 and relates to #7"},
		{Index: 11, Body: "Closes #8"},
	}
	// PR #10 references both #7 and #5, so either lookup resolves to it.
	for _, issue := range []int{7, 5} {
		got := findLinkedPR(prs, issue)
		require.NotNil(t, got)
		assert.Equal(t, int64(10), got.Index)
	}
}
// A PR with an empty body never matches; the later PR that actually
// references the issue is returned instead.
func TestFindLinkedPR_Good_EmptyBodyPR(t *testing.T) {
	prs := []*forgejosdk.PullRequest{
		{Index: 10, Body: ""},
		{Index: 11, Body: "Fixes #7"},
	}
	got := findLinkedPR(prs, 7)
	require.NotNil(t, got)
	assert.Equal(t, int64(11), got.Index)
}
// When multiple PRs reference the same issue, the first in slice order
// is the one returned.
func TestFindLinkedPR_Good_FirstMatchWins(t *testing.T) {
	prs := []*forgejosdk.PullRequest{
		{Index: 10, Body: "Fixes #7"},
		{Index: 11, Body: "Also fixes #7"},
	}
	got := findLinkedPR(prs, 7)
	require.NotNil(t, got)
	assert.Equal(t, int64(10), got.Index)
}
// An empty PR list yields no match.
func TestFindLinkedPR_Good_EmptySlice(t *testing.T) {
	assert.Nil(t, findLinkedPR([]*forgejosdk.PullRequest{}, 1))
}
// --- mapPRState edge case ---
// HasMerged takes precedence: a PR flagged as merged maps to MERGED
// even while its State field still reads open.
func TestMapPRState_Good_MergedOverridesState(t *testing.T) {
	merged := &forgejosdk.PullRequest{State: forgejosdk.StateOpen, HasMerged: true}
	assert.Equal(t, "MERGED", mapPRState(merged))
}
// --- mapCombinedStatus edge cases ---
// A warning (neither success nor failure) combined state maps to the
// PENDING default.
func TestMapCombinedStatus_Good_WarningState(t *testing.T) {
	combined := &forgejosdk.CombinedStatus{
		State:      forgejosdk.StatusWarning,
		TotalCount: 1,
	}
	assert.Equal(t, "PENDING", mapCombinedStatus(combined))
}
// --- buildSignal edge cases ---
// A closed, unmergeable, unmerged PR produces CLOSED/CONFLICTING signal
// fields and carries the head SHA and supplied check status through.
func TestBuildSignal_Good_ClosedPR(t *testing.T) {
	closedPR := &forgejosdk.PullRequest{
		Index:     5,
		State:     forgejosdk.StateClosed,
		Mergeable: false,
		HasMerged: false,
		Head:      &forgejosdk.PRBranchInfo{Sha: "abc"},
	}
	got := buildSignal("org", "repo", 1, 2, closedPR, "FAILURE")
	assert.Equal(t, "CLOSED", got.PRState)
	assert.Equal(t, "CONFLICTING", got.Mergeable)
	assert.Equal(t, "FAILURE", got.CheckStatus)
	assert.Equal(t, "abc", got.LastCommitSHA)
}
// A merged PR maps to PRState MERGED with mergeability UNKNOWN, and the
// PR number and head SHA are copied into the signal.
func TestBuildSignal_Good_MergedPR(t *testing.T) {
	mergedPR := &forgejosdk.PullRequest{
		Index:     99,
		State:     forgejosdk.StateClosed,
		Mergeable: false,
		HasMerged: true,
		Head:      &forgejosdk.PRBranchInfo{Sha: "merged123"},
	}
	got := buildSignal("owner", "repo", 10, 5, mergedPR, "SUCCESS")
	assert.Equal(t, "MERGED", got.PRState)
	assert.Equal(t, "UNKNOWN", got.Mergeable)
	assert.Equal(t, 99, got.PRNumber)
	assert.Equal(t, "merged123", got.LastCommitSHA)
}
// --- splitRepo edge cases ---
// A bare "/" yields empty owner and repo parts and must be rejected.
func TestSplitRepo_Bad_OnlySlash(t *testing.T) {
	_, _, splitErr := splitRepo("/")
	assert.Error(t, splitErr)
}
// Only the first slash separates owner from repo; everything after it
// (including further slashes) belongs to the repo name.
func TestSplitRepo_Bad_MultipleSlashes(t *testing.T) {
	owner, name, err := splitRepo("a/b/c")
	require.NoError(t, err)
	assert.Equal(t, "a", owner)
	assert.Equal(t, "b/c", name)
}
// --- Poll with combined status failure ---

// A failing combined status on the PR head commit is surfaced as
// CheckStatus FAILURE while the PR's state and mergeability pass
// through unchanged.
func TestForgejoSource_Poll_Good_CombinedStatusFailure(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// One open epic with unchecked child #2.
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #2\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		// Open, mergeable PR linked to #2.
		case strings.Contains(path, "/pulls"):
			prs := []map[string]any{
				{
					"number":    10,
					"body":      "Fixes #2",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "fail123", "ref": "feature", "label": "feature"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		// Combined status reports a CI failure for the head SHA.
		case strings.Contains(path, "/status"):
			status := map[string]any{
				"state":       "failure",
				"total_count": 2,
				"statuses":    []map[string]any{{"status": "failure", "context": "ci"}},
			}
			_ = json.NewEncoder(w).Encode(status)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	assert.Equal(t, "FAILURE", signals[0].CheckStatus)
	assert.Equal(t, "OPEN", signals[0].PRState)
	assert.Equal(t, "MERGEABLE", signals[0].Mergeable)
}
// --- Poll with combined status error ---

// A 500 from the combined-status endpoint must not abort the poll; the
// signal is still emitted with CheckStatus falling back to PENDING.
func TestForgejoSource_Poll_Good_CombinedStatusError(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// One open epic with unchecked child #3.
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #3\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		// Open but conflicting PR linked to #3.
		case strings.Contains(path, "/pulls"):
			prs := []map[string]any{
				{
					"number":    20,
					"body":      "Fixes #3",
					"state":     "open",
					"mergeable": false,
					"merged":    false,
					"head":      map[string]string{"sha": "err123", "ref": "fix", "label": "fix"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		// Combined status endpoint returns 500 — should fall back to PENDING.
		case strings.Contains(path, "/status"):
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	// Combined status API error -> falls back to PENDING.
	assert.Equal(t, "PENDING", signals[0].CheckStatus)
	assert.Equal(t, "CONFLICTING", signals[0].Mergeable)
}
// --- Poll with child that has no assignee (NeedsCoding path, no assignee) ---

// An unchecked child with neither a linked PR nor an assignee yields no
// signal at all.
func TestForgejoSource_Poll_Good_ChildNoAssignee(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// NOTE: the more specific "/issues/5" case must precede the
		// generic "/issues" case — strings.Contains matches both.
		case strings.Contains(path, "/issues/5"):
			// Child issue with no assignee.
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number":    5,
				"title":     "Unassigned task",
				"body":      "No one is working on this.",
				"state":     "open",
				"assignees": []map[string]any{},
			})
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #5\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			_ = json.NewEncoder(w).Encode([]any{})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	// No signal should be emitted when child has no assignee and no PR.
	assert.Empty(t, signals)
}
// --- Poll with child issue fetch failure ---

// A failure fetching an individual child issue is swallowed: the child
// is skipped and Poll still returns without error.
func TestForgejoSource_Poll_Good_ChildFetchFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// Specific child route first; the generic "/issues" route below
		// would otherwise also match this path.
		case strings.Contains(path, "/issues/5"):
			// Child issue fetch fails.
			w.WriteHeader(http.StatusInternalServerError)
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #5\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			_ = json.NewEncoder(w).Encode([]any{})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	// Child fetch error should be logged and skipped, not returned as error.
	assert.Empty(t, signals)
}
// --- Poll with multiple epics ---

// Two open epics, each with one unchecked child and a linked PR, yield
// one signal per epic with the epic/child/PR numbers matched up.
func TestForgejoSource_Poll_Good_MultipleEpics(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #3\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
				{
					"number": 2,
					"body":   "- [ ] #4\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			// PR #10 is linked to child #3, PR #11 to child #4.
			prs := []map[string]any{
				{
					"number":    10,
					"body":      "Fixes #3",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "aaa", "ref": "f1", "label": "f1"},
				},
				{
					"number":    11,
					"body":      "Fixes #4",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "bbb", "ref": "f2", "label": "f2"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		case strings.Contains(path, "/status"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"state":       "success",
				"total_count": 1,
				"statuses":    []any{},
			})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 2)
	assert.Equal(t, 1, signals[0].EpicNumber)
	assert.Equal(t, 3, signals[0].ChildNumber)
	assert.Equal(t, 10, signals[0].PRNumber)
	assert.Equal(t, 2, signals[1].EpicNumber)
	assert.Equal(t, 4, signals[1].ChildNumber)
	assert.Equal(t, 11, signals[1].PRNumber)
}
// --- Report with nil result ---

// A nil ActionResult is ignored without error.
func TestForgejoSource_Report_Good_NilResult(t *testing.T) {
	src := New(Config{}, nil)
	assert.NoError(t, src.Report(context.Background(), nil))
}
// --- Report constructs correct comment body ---

// A successful result is reported as a comment on the epic issue, and
// the body names the action, outcome, child issue and PR.
func TestForgejoSource_Report_Good_SuccessFormat(t *testing.T) {
	var capturedPath string
	var capturedBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		capturedPath = r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		// Capture the posted comment body.
		var body map[string]string
		_ = json.NewDecoder(r.Body).Decode(&body)
		capturedBody = body["body"]
		_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{}, client)
	result := &jobrunner.ActionResult{
		Action:      "tick_parent",
		RepoOwner:   "core",
		RepoName:    "go-scm",
		EpicNumber:  5,
		ChildNumber: 10,
		PRNumber:    20,
		Success:     true,
	}
	err := s.Report(context.Background(), result)
	require.NoError(t, err)
	// Comment should be on the epic issue.
	assert.Contains(t, capturedPath, "/issues/5/comments")
	assert.Contains(t, capturedBody, "tick_parent")
	assert.Contains(t, capturedBody, "succeeded")
	assert.Contains(t, capturedBody, "#10")
	assert.Contains(t, capturedBody, "PR #20")
}
// A failed result's comment mentions the failure and includes the
// error detail verbatim.
func TestForgejoSource_Report_Good_FailureWithError(t *testing.T) {
	var capturedBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Capture the posted comment body.
		var body map[string]string
		_ = json.NewDecoder(r.Body).Decode(&body)
		capturedBody = body["body"]
		_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{}, client)
	result := &jobrunner.ActionResult{
		Action:      "enable_auto_merge",
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  1,
		ChildNumber: 2,
		PRNumber:    3,
		Success:     false,
		Error:       "merge conflict detected",
	}
	err := s.Report(context.Background(), result)
	require.NoError(t, err)
	assert.Contains(t, capturedBody, "failed")
	assert.Contains(t, capturedBody, "merge conflict detected")
}
// --- Poll filters only epic-labelled issues ---

// Of three open issues, only the two carrying the "epic" label (even
// alongside other labels) generate signals; the bug-labelled issue is
// skipped despite having a linked PR.
func TestForgejoSource_Poll_Good_MixedLabels(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #2\n",
					"labels": []map[string]string{{"name": "epic"}, {"name": "priority-high"}},
					"state":  "open",
				},
				{
					"number": 3,
					"body":   "- [ ] #4\n",
					"labels": []map[string]string{{"name": "bug"}},
					"state":  "open",
				},
				{
					"number": 5,
					"body":   "- [ ] #6\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			// One PR per child, including the child of the non-epic issue.
			prs := []map[string]any{
				{
					"number":    10,
					"body":      "Fixes #2",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "sha1", "ref": "f1", "label": "f1"},
				},
				{
					"number":    11,
					"body":      "Fixes #4",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "sha2", "ref": "f2", "label": "f2"},
				},
				{
					"number":    12,
					"body":      "Fixes #6",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "sha3", "ref": "f3", "label": "f3"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		case strings.Contains(path, "/status"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"state":       "success",
				"total_count": 1,
				"statuses":    []any{},
			})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	// Only issues #1 and #5 have the "epic" label.
	require.Len(t, signals, 2)
	assert.Equal(t, 1, signals[0].EpicNumber)
	assert.Equal(t, 2, signals[0].ChildNumber)
	assert.Equal(t, 5, signals[1].EpicNumber)
	assert.Equal(t, 6, signals[1].ChildNumber)
}
// --- Poll with PRs error after issues succeed ---

// When listing PRs fails after the issues fetch succeeded, the repo is
// skipped entirely: Poll returns no error and no signals.
//
// Note: the previous version maintained a callCount incremented inside
// the handler but never asserted on it — dead state that was also a
// data-race hazard under -race. It is removed, and the client is built
// via the shared newTestClient helper for consistency with the other
// tests in this file.
func TestForgejoSource_Poll_Good_PRsAPIError(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// One open epic with unchecked child #2.
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1,
					"body":   "- [ ] #2\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		// PR listing fails for this repo.
		case strings.Contains(path, "/pulls"):
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	// PR API failure -> repo is skipped, no signals.
	assert.Empty(t, signals)
}
// --- New creates source correctly ---

// New wires the configured repos into the source and reports the fixed
// "forgejo" name.
func TestForgejoSource_New_Good(t *testing.T) {
	src := New(Config{Repos: []string{"a/b", "c/d"}}, nil)
	assert.Equal(t, "forgejo", src.Name())
	assert.Equal(t, []string{"a/b", "c/d"}, src.repos)
}

View file

@ -1,409 +0,0 @@
package forgejo
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// ---------------------------------------------------------------------------
// Supplementary Forgejo signal source tests — extends Phase 3 coverage
// ---------------------------------------------------------------------------
// Two epics, each with multiple unchecked children that have linked PRs.
// Checked children (#13) are excluded, so three signals result.
func TestForgejoSource_Poll_Good_MultipleEpicsMultipleChildren(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 10,
					"body":   "## Sprint\n- [ ] #11\n- [ ] #12\n- [x] #13\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
				{
					"number": 20,
					"body":   "## Sprint 2\n- [ ] #21\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			// One linked PR per unchecked child (#11, #12, #21).
			prs := []map[string]any{
				{
					"number": 30, "body": "Fixes #11", "state": "open",
					"mergeable": true, "merged": false,
					"head": map[string]string{"sha": "aaa111", "ref": "fix-11", "label": "fix-11"},
				},
				{
					"number": 31, "body": "Fixes #12", "state": "open",
					"mergeable": false, "merged": false,
					"head": map[string]string{"sha": "bbb222", "ref": "fix-12", "label": "fix-12"},
				},
				{
					"number": 32, "body": "Resolves #21", "state": "open",
					"mergeable": true, "merged": false,
					"head": map[string]string{"sha": "ccc333", "ref": "fix-21", "label": "fix-21"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		case strings.Contains(path, "/status"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"state": "success", "total_count": 1, "statuses": []any{},
			})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	// Epic 10 has #11 and #12 unchecked; epic 20 has #21 unchecked. Total 3 signals.
	require.Len(t, signals, 3, "expected three signals from two epics")
	childNumbers := map[int]bool{}
	for _, sig := range signals {
		childNumbers[sig.ChildNumber] = true
	}
	assert.True(t, childNumbers[11])
	assert.True(t, childNumbers[12])
	assert.True(t, childNumbers[21])
}
// When combined status fetch fails, check status should default to PENDING.
// The atomic flag also proves the status endpoint really was consulted.
func TestForgejoSource_Poll_Good_CombinedStatusFetchErrorFallsToPending(t *testing.T) {
	var statusFetched atomic.Bool
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1, "body": "- [ ] #2\n",
					"labels": []map[string]string{{"name": "epic"}}, "state": "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			prs := []map[string]any{
				{
					"number": 10, "body": "Fixes #2", "state": "open",
					"mergeable": true, "merged": false,
					"head": map[string]string{"sha": "sha123", "ref": "fix", "label": "fix"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		// Record the status lookup, then fail it.
		case strings.Contains(path, "/status"):
			statusFetched.Store(true)
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	assert.True(t, statusFetched.Load(), "status endpoint should have been called")
	assert.Equal(t, "PENDING", signals[0].CheckStatus, "failed status fetch should default to PENDING")
}
// First repo fails (issues endpoint 500), second repo succeeds.
// A per-repo failure must not prevent later repos from being polled.
func TestForgejoSource_Poll_Good_MixedReposFirstFailsSecondSucceeds(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// bad-org/bad-repo: issue listing fails outright.
		case strings.Contains(path, "/repos/bad-org/bad-repo/issues"):
			w.WriteHeader(http.StatusInternalServerError)
		// good-org/good-repo: a normal epic with one unchecked child.
		case strings.Contains(path, "/repos/good-org/good-repo/issues"):
			issues := []map[string]any{
				{
					"number": 1, "body": "- [ ] #2\n",
					"labels": []map[string]string{{"name": "epic"}}, "state": "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/repos/good-org/good-repo/pulls"):
			prs := []map[string]any{
				{
					"number": 10, "body": "Fixes #2", "state": "open",
					"mergeable": true, "merged": false,
					"head": map[string]string{"sha": "abc", "ref": "fix", "label": "fix"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		case strings.Contains(path, "/status"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"state": "success", "total_count": 1, "statuses": []any{},
			})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"bad-org/bad-repo", "good-org/good-repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1, "only the good repo should produce signals")
	assert.Equal(t, "good-org", signals[0].RepoOwner)
	assert.Equal(t, "good-repo", signals[0].RepoName)
}
// Table test: the comment body produced by Report must contain the
// action name, outcome wording, child/PR references, and — on failure —
// the error detail.
func TestForgejoSource_Report_Good_CommentBodyTable(t *testing.T) {
	tests := []struct {
		name         string
		result       *jobrunner.ActionResult
		wantContains []string
	}{
		{
			name: "successful action",
			result: &jobrunner.ActionResult{
				Action: "enable_auto_merge", RepoOwner: "org", RepoName: "repo",
				EpicNumber: 10, ChildNumber: 11, PRNumber: 20, Success: true,
			},
			wantContains: []string{"enable_auto_merge", "succeeded", "#11", "PR #20"},
		},
		{
			name: "failed action with error",
			result: &jobrunner.ActionResult{
				Action: "tick_parent", RepoOwner: "org", RepoName: "repo",
				EpicNumber: 10, ChildNumber: 11, PRNumber: 20,
				Success: false, Error: "rate limit exceeded",
			},
			wantContains: []string{"tick_parent", "failed", "#11", "PR #20", "rate limit exceeded"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var capturedBody string
			srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				// Capture the posted comment body.
				var body map[string]string
				_ = json.NewDecoder(r.Body).Decode(&body)
				capturedBody = body["body"]
				_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
			})))
			defer srv.Close()
			client := newTestClient(t, srv.URL)
			s := New(Config{}, client)
			err := s.Report(context.Background(), tt.result)
			require.NoError(t, err)
			for _, want := range tt.wantContains {
				assert.Contains(t, capturedBody, want)
			}
		})
	}
}
// The Report comment must be POSTed to the comments endpoint of the
// epic issue (not the child issue or the PR).
func TestForgejoSource_Report_Good_PostsToCorrectEpicIssue(t *testing.T) {
	var capturedPath string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Only the POST (the comment creation) is of interest.
		if r.Method == http.MethodPost {
			capturedPath = r.URL.Path
		}
		_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{}, client)
	result := &jobrunner.ActionResult{
		Action: "merge", RepoOwner: "test-org", RepoName: "test-repo",
		EpicNumber: 42, ChildNumber: 7, PRNumber: 99, Success: true,
	}
	err := s.Report(context.Background(), result)
	require.NoError(t, err)
	expected := fmt.Sprintf("/api/v1/repos/%s/%s/issues/%d/comments", result.RepoOwner, result.RepoName, result.EpicNumber)
	assert.Equal(t, expected, capturedPath, "comment should be posted on the epic issue")
}
// Verify that all expected signal fields are populated correctly: epic,
// child and PR numbers, repo identity, PR state/mergeability, combined
// check status, head SHA, the NeedsCoding flag and RepoFullName().
func TestForgejoSource_Poll_Good_SignalFieldCompleteness(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 100, "body": "## Work\n- [ ] #101\n- [x] #102\n",
					"labels": []map[string]string{{"name": "epic"}}, "state": "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			prs := []map[string]any{
				{
					"number": 200, "body": "Closes #101", "state": "open",
					"mergeable": true, "merged": false,
					"head": map[string]string{"sha": "deadbeef", "ref": "feature", "label": "feature"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		case strings.Contains(path, "/status"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"state": "success", "total_count": 2,
				"statuses": []map[string]any{{"status": "success"}, {"status": "success"}},
			})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"acme/widgets"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	sig := signals[0]
	assert.Equal(t, 100, sig.EpicNumber)
	assert.Equal(t, 101, sig.ChildNumber)
	assert.Equal(t, 200, sig.PRNumber)
	assert.Equal(t, "acme", sig.RepoOwner)
	assert.Equal(t, "widgets", sig.RepoName)
	assert.Equal(t, "OPEN", sig.PRState)
	assert.Equal(t, "MERGEABLE", sig.Mergeable)
	assert.Equal(t, "SUCCESS", sig.CheckStatus)
	assert.Equal(t, "deadbeef", sig.LastCommitSHA)
	assert.False(t, sig.NeedsCoding)
	assert.Equal(t, "acme/widgets", sig.RepoFullName())
}
// An epic whose children are all ticked produces no signals: there is
// nothing left to drive.
func TestForgejoSource_Poll_Good_AllChildrenCheckedNoSignals(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1, "body": "- [x] #2\n- [x] #3\n",
					"labels": []map[string]string{{"name": "epic"}}, "state": "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			_ = json.NewEncoder(w).Encode([]any{})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	assert.Empty(t, signals, "all children checked means no work to do")
}
// An unchecked child with an assignee but no linked PR produces a
// NeedsCoding signal carrying the assignee, issue title/body, and a
// zero PR number.
func TestForgejoSource_Poll_Good_NeedsCodingSignalFields(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// Specific child route first — the generic "/issues" route below
		// would otherwise also match this path.
		case strings.Contains(path, "/issues/7"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 7, "title": "Implement authentication",
				"body": "Add OAuth2 support.", "state": "open",
				"assignees": []map[string]any{{"login": "agent-bot", "username": "agent-bot"}},
			})
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 1, "body": "- [ ] #7\n",
					"labels": []map[string]string{{"name": "epic"}}, "state": "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		case strings.Contains(path, "/pulls"):
			_ = json.NewEncoder(w).Encode([]any{})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"org/repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	sig := signals[0]
	assert.True(t, sig.NeedsCoding)
	assert.Equal(t, "agent-bot", sig.Assignee)
	assert.Equal(t, "Implement authentication", sig.IssueTitle)
	assert.Contains(t, sig.IssueBody, "OAuth2 support")
	assert.Equal(t, 0, sig.PRNumber, "PRNumber should be zero for NeedsCoding signals")
}

View file

@ -1,177 +0,0 @@
package forgejo
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// withVersion wraps an HTTP handler so the Forgejo SDK's /api/v1/version
// probe (issued during NewClient initialization) always succeeds; every
// other request falls through to next.
func withVersion(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !strings.HasSuffix(r.URL.Path, "/version") {
			next.ServeHTTP(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"version":"9.0.0"}`))
	})
}
// newTestClient constructs a forge.Client aimed at the given test-server URL,
// aborting the test immediately if construction fails.
func newTestClient(t *testing.T, url string) *forge.Client {
	t.Helper()
	c, err := forge.New(url, "test-token")
	require.NoError(t, err)
	return c
}
// TestForgejoSource_Name checks that the source identifies itself as "forgejo".
func TestForgejoSource_Name(t *testing.T) {
	src := New(Config{}, nil)
	assert.Equal(t, "forgejo", src.Name())
}
// TestForgejoSource_Poll_Good drives a complete poll cycle: one epic (#10)
// with an unchecked child (#11) whose body links an open, mergeable PR (#20)
// with passing checks, then verifies every field of the resulting signal.
func TestForgejoSource_Poll_Good(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		switch {
		// List issues — return one epic
		case strings.Contains(path, "/issues"):
			issues := []map[string]any{
				{
					"number": 10,
					"body":   "## Tasks\n- [ ] #11\n- [x] #12\n",
					"labels": []map[string]string{{"name": "epic"}},
					"state":  "open",
				},
			}
			_ = json.NewEncoder(w).Encode(issues)
		// List PRs — return one open PR linked to #11
		case strings.Contains(path, "/pulls"):
			prs := []map[string]any{
				{
					"number":    20,
					"body":      "Fixes #11",
					"state":     "open",
					"mergeable": true,
					"merged":    false,
					"head":      map[string]string{"sha": "abc123", "ref": "feature", "label": "feature"},
				},
			}
			_ = json.NewEncoder(w).Encode(prs)
		// Combined status
		case strings.Contains(path, "/status"):
			status := map[string]any{
				"state":       "success",
				"total_count": 1,
				"statuses":    []map[string]any{{"status": "success", "context": "ci"}},
			}
			_ = json.NewEncoder(w).Encode(status)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{Repos: []string{"test-org/test-repo"}}, client)
	signals, err := s.Poll(context.Background())
	require.NoError(t, err)
	require.Len(t, signals, 1)
	sig := signals[0]
	assert.Equal(t, 10, sig.EpicNumber)
	assert.Equal(t, 11, sig.ChildNumber)
	assert.Equal(t, 20, sig.PRNumber)
	assert.Equal(t, "OPEN", sig.PRState)
	assert.Equal(t, "MERGEABLE", sig.Mergeable)
	assert.Equal(t, "SUCCESS", sig.CheckStatus)
	assert.Equal(t, "test-org", sig.RepoOwner)
	assert.Equal(t, "test-repo", sig.RepoName)
	assert.Equal(t, "abc123", sig.LastCommitSHA)
}
// TestForgejoSource_Poll_NoEpics verifies that polling a repository whose
// issue list is empty yields no signals and no error.
func TestForgejoSource_Poll_NoEpics(t *testing.T) {
	handler := withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode([]any{})
	}))
	srv := httptest.NewServer(handler)
	defer srv.Close()

	source := New(Config{Repos: []string{"test-org/test-repo"}}, newTestClient(t, srv.URL))
	signals, err := source.Poll(context.Background())
	require.NoError(t, err)
	assert.Empty(t, signals)
}
// TestForgejoSource_Report_Good posts an ActionResult back to the forge and
// checks that the generated comment body names both the action and its
// outcome ("succeeded").
func TestForgejoSource_Report_Good(t *testing.T) {
	var capturedBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Capture the comment body from whatever POST the client sends.
		var body map[string]string
		_ = json.NewDecoder(r.Body).Decode(&body)
		capturedBody = body["body"]
		_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
	})))
	defer srv.Close()
	client := newTestClient(t, srv.URL)
	s := New(Config{}, client)
	result := &jobrunner.ActionResult{
		Action:      "enable_auto_merge",
		RepoOwner:   "test-org",
		RepoName:    "test-repo",
		EpicNumber:  10,
		ChildNumber: 11,
		PRNumber:    20,
		Success:     true,
	}
	err := s.Report(context.Background(), result)
	require.NoError(t, err)
	assert.Contains(t, capturedBody, "enable_auto_merge")
	assert.Contains(t, capturedBody, "succeeded")
}
// TestParseEpicChildren verifies that checked and unchecked task-list
// references are split into the correct buckets, preserving their order.
func TestParseEpicChildren(t *testing.T) {
	const body = "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n- [x] #3\n"
	open, done := parseEpicChildren(body)
	assert.Equal(t, []int{7, 8}, open)
	assert.Equal(t, []int{1, 3}, done)
}
// TestFindLinkedPR verifies that searching a nil PR slice returns nil rather
// than panicking.
func TestFindLinkedPR(t *testing.T) {
	assert.Nil(t, findLinkedPR(nil, 7))
}
// TestSplitRepo covers the happy "owner/repo" path plus the two malformed
// inputs: a string without a slash and the empty string.
func TestSplitRepo(t *testing.T) {
	owner, name, err := splitRepo("host-uk/core")
	require.NoError(t, err)
	assert.Equal(t, "host-uk", owner)
	assert.Equal(t, "core", name)

	_, _, err = splitRepo("invalid")
	assert.Error(t, err)

	_, _, err = splitRepo("")
	assert.Error(t, err)
}

View file

@ -1,88 +0,0 @@
package handlers
import (
"context"
"fmt"
"time"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
const (
	// ColorAgentComplete is the hex colour applied to the agent-completed label.
	ColorAgentComplete = "#0e8a16" // Green
)
// CompletionHandler manages issue state when an agent finishes work.
type CompletionHandler struct {
	forge *forge.Client // Forgejo API client used for label and comment updates.
}
// NewCompletionHandler creates a handler for agent completion events, using
// client for all Forgejo API interactions.
func NewCompletionHandler(client *forge.Client) *CompletionHandler {
	return &CompletionHandler{forge: client}
}
// Name returns the handler identifier, "completion".
func (h *CompletionHandler) Name() string {
	return "completion"
}
// Match returns true if the signal indicates an agent has finished a task,
// i.e. its Type is exactly "agent_completion".
func (h *CompletionHandler) Match(signal *jobrunner.PipelineSignal) bool {
	return signal.Type == "agent_completion"
}
// Execute updates the issue labels based on the completion status.
//
// Success path: swap in-progress for agent-completed and post the signal's
// Message as a comment when it is non-empty. Failure path: apply agent-failed
// and post a failure comment, appending signal.Error when present.
// Removing the in-progress label and posting comments are best effort — their
// errors are deliberately ignored; only EnsureLabel / AddIssueLabels failures
// abort with an error.
func (h *CompletionHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
	start := time.Now()
	// Remove in-progress label. A lookup failure means the label does not
	// exist, which is not an error here.
	if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil {
		_ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID)
	}
	if signal.Success {
		completeLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentComplete, ColorAgentComplete)
		if err != nil {
			return nil, coreerr.E("completion.Execute", "ensure label "+LabelAgentComplete, err)
		}
		if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{completeLabel.ID}); err != nil {
			return nil, coreerr.E("completion.Execute", "add completed label", err)
		}
		if signal.Message != "" {
			_ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), signal.Message)
		}
	} else {
		failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed)
		if err != nil {
			return nil, coreerr.E("completion.Execute", "ensure label "+LabelAgentFailed, err)
		}
		if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID}); err != nil {
			return nil, coreerr.E("completion.Execute", "add failed label", err)
		}
		msg := "Agent reported failure."
		if signal.Error != "" {
			msg += fmt.Sprintf("\n\nError: %s", signal.Error)
		}
		_ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), msg)
	}
	// The ActionResult reports the handler's own outcome; Success here means
	// the handler ran, not that the agent's task succeeded.
	return &jobrunner.ActionResult{
		Action:      "completion",
		RepoOwner:   signal.RepoOwner,
		RepoName:    signal.RepoName,
		EpicNumber:  signal.EpicNumber,
		ChildNumber: signal.ChildNumber,
		Success:     true,
		Timestamp:   time.Now(),
		Duration:    time.Since(start),
	}, nil
}

View file

@ -1,291 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// TestCompletion_Name_Good verifies the handler reports its identifier.
func TestCompletion_Name_Good(t *testing.T) {
	assert.Equal(t, "completion", NewCompletionHandler(nil).Name())
}
// TestCompletion_Match_Good_AgentCompletion verifies signals of type
// "agent_completion" are matched.
func TestCompletion_Match_Good_AgentCompletion(t *testing.T) {
	handler := NewCompletionHandler(nil)
	assert.True(t, handler.Match(&jobrunner.PipelineSignal{Type: "agent_completion"}))
}
// TestCompletion_Match_Bad_WrongType verifies signals with a different type
// are rejected.
func TestCompletion_Match_Bad_WrongType(t *testing.T) {
	handler := NewCompletionHandler(nil)
	assert.False(t, handler.Match(&jobrunner.PipelineSignal{Type: "pr_update"}))
}
// TestCompletion_Match_Bad_EmptyType verifies the zero-value signal (empty
// type) is rejected.
func TestCompletion_Match_Bad_EmptyType(t *testing.T) {
	handler := NewCompletionHandler(nil)
	assert.False(t, handler.Match(&jobrunner.PipelineSignal{}))
}
// TestCompletion_Execute_Good_Success drives the success path end to end:
// the in-progress label is removed, agent-completed is created and applied,
// and the signal's message is posted as a comment.
func TestCompletion_Execute_Good_Success(t *testing.T) {
	var labelRemoved bool
	var labelAdded bool
	var commentPosted bool
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// GetLabelByName (in-progress) — GET labels to find in-progress.
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
			})
		// RemoveIssueLabel (in-progress).
		case r.Method == http.MethodDelete && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/labels/1":
			labelRemoved = true
			w.WriteHeader(http.StatusNoContent)
		// EnsureLabel (agent-completed) — POST to create.
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"})
		// AddIssueLabels.
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/labels":
			labelAdded = true
			_ = json.NewEncoder(w).Encode([]map[string]any{{"id": 2, "name": "agent-completed"}})
		// CreateIssueComment.
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/comments":
			commentPosted = true
			var body map[string]string
			_ = json.NewDecoder(r.Body).Decode(&body)
			commentBody = body["body"]
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "test-org",
		RepoName:    "test-repo",
		ChildNumber: 5,
		EpicNumber:  3,
		Success:     true,
		Message:     "Task completed successfully",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "completion", result.Action)
	assert.Equal(t, "test-org", result.RepoOwner)
	assert.Equal(t, "test-repo", result.RepoName)
	assert.Equal(t, 3, result.EpicNumber)
	assert.Equal(t, 5, result.ChildNumber)
	assert.True(t, labelRemoved, "in-progress label should be removed")
	assert.True(t, labelAdded, "agent-completed label should be added")
	assert.True(t, commentPosted, "comment should be posted")
	assert.Contains(t, commentBody, "Task completed successfully")
}
// TestCompletion_Execute_Good_Failure drives the failure path: agent-failed
// is created and applied, and the posted comment includes the signal's Error
// text. The handler's own result is still Success=true.
func TestCompletion_Execute_Good_Failure(t *testing.T) {
	var labelAdded bool
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// Empty label list: the in-progress removal is skipped, and
		// EnsureLabel must create agent-failed via POST.
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/labels":
			labelAdded = true
			_ = json.NewEncoder(w).Encode([]map[string]any{{"id": 3, "name": "agent-failed"}})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/comments":
			var body map[string]string
			_ = json.NewDecoder(r.Body).Decode(&body)
			commentBody = body["body"]
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "test-org",
		RepoName:    "test-repo",
		ChildNumber: 5,
		EpicNumber:  3,
		Success:     false,
		Error:       "tests failed",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success) // The handler itself succeeded
	assert.Equal(t, "completion", result.Action)
	assert.True(t, labelAdded, "agent-failed label should be added")
	assert.Contains(t, commentBody, "Agent reported failure")
	assert.Contains(t, commentBody, "tests failed")
}
// TestCompletion_Execute_Good_FailureNoError verifies that when the signal
// reports failure but carries no Error text, the posted comment omits the
// "Error:" detail line.
func TestCompletion_Execute_Good_FailureNoError(t *testing.T) {
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/comments":
			var body map[string]string
			_ = json.NewDecoder(r.Body).Decode(&body)
			commentBody = body["body"]
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 1,
		Success:     false,
		Error:       "", // No error message.
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Contains(t, commentBody, "Agent reported failure")
	assert.NotContains(t, commentBody, "Error:") // No error detail.
}
// TestCompletion_Execute_Good_SuccessNoMessage verifies that a successful
// completion with an empty Message applies the label but posts no comment.
func TestCompletion_Execute_Good_SuccessNoMessage(t *testing.T) {
	var commentPosted bool
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/comments":
			commentPosted = true
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 1,
		Success:     true,
		Message:     "", // No message.
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.False(t, commentPosted, "no comment should be posted when message is empty")
}
// TestCompletion_Execute_Bad_EnsureLabelFails verifies that a 500 from label
// creation surfaces as an "ensure label" error from Execute.
func TestCompletion_Execute_Bad_EnsureLabelFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			// Return empty so EnsureLabel tries to create.
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			// Label creation fails.
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 1,
		Success:     true,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "ensure label")
}

View file

@ -1,704 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
agentci "forge.lthn.ai/core/agent/pkg/orchestrator"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// --- Dispatch: Execute with invalid repo name ---

// TestDispatch_Execute_Bad_InvalidRepoNameSpecialChars feeds a repo name
// containing shell-unsafe characters and expects Execute to reject it with an
// "invalid repo name" error.
func TestDispatch_Execute_Bad_InvalidRepoNameSpecialChars(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "valid-org",
		RepoName:    "repo$bad!",
		ChildNumber: 1,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "invalid repo name")
}
// --- Dispatch: Execute when EnsureLabel fails ---

// TestDispatch_Execute_Bad_EnsureLabelCreationFails verifies that a 500 from
// label creation surfaces as an "ensure label" error from dispatch Execute.
func TestDispatch_Execute_Bad_EnsureLabelCreationFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// Empty label list forces EnsureLabel to attempt creation.
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 1,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "ensure label")
}
// dispatchMockServer creates a standard mock server for dispatch tests.
// It handles all the Forgejo API calls needed for a full dispatch flow.
// GetIssue deliberately returns 404 so handlers fall through to the full
// dispatch path instead of short-circuiting on an existing label state.
func dispatchMockServer(t *testing.T) *httptest.Server {
	t.Helper()
	return httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// GetLabelByName / list labels
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
				{"id": 2, "name": "agent-ready", "color": "#00ff00"},
			})
		// CreateLabel (shouldn't normally be needed since we return it above)
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress", "color": "#1d76db"})
		// GetIssue (returns issue with no label to trigger the full dispatch flow)
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			w.WriteHeader(http.StatusNotFound) // Issue not found => full dispatch flow
		// AssignIssue
		case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5})
		// AddIssueLabels
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{{"id": 1, "name": "in-progress"}})
		// RemoveIssueLabel
		case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"):
			w.WriteHeader(http.StatusNoContent)
		// CreateIssueComment
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/comments"):
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": "dispatched"})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
}
// --- Dispatch: Execute when GetIssue returns 404 (full dispatch path) ---

// TestDispatch_Execute_Good_GetIssueNotFound exercises the full dispatch flow
// when the child issue lookup 404s, asserting a "dispatch" action results.
func TestDispatch_Execute_Good_GetIssueNotFound(t *testing.T) {
	srv := dispatchMockServer(t)
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 5,
		EpicNumber:  3,
		IssueTitle:  "Test issue",
		IssueBody:   "Test body",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.Equal(t, "dispatch", result.Action)
}
// --- Completion: Execute when AddIssueLabels fails for success case ---

// TestCompletion_Execute_Bad_AddCompleteLabelFails verifies that a 500 while
// attaching the agent-completed label surfaces as an "add completed label"
// error on the success path.
func TestCompletion_Execute_Bad_AddCompleteLabelFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/repo/labels"):
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"})
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"):
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 5,
		Success:     true,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "add completed label")
}
// --- Completion: Execute when AddIssueLabels fails for failure case ---

// TestCompletion_Execute_Bad_AddFailLabelFails verifies that a 500 while
// attaching the agent-failed label surfaces as an "add failed label" error on
// the failure path.
func TestCompletion_Execute_Bad_AddFailLabelFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/repo/labels"):
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"})
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"):
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 5,
		Success:     false,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "add failed label")
}
// --- Completion: Execute with EnsureLabel failure on failure path ---

// TestCompletion_Execute_Bad_FailedPathEnsureLabelFails verifies that failing
// to create the agent-failed label yields an "ensure label" error on the
// failure path.
func TestCompletion_Execute_Bad_FailedPathEnsureLabelFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/labels"):
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	sig := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 1,
		Success:     false,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "ensure label")
}
// --- EnableAutoMerge: additional edge case ---

// TestEnableAutoMerge_Match_Bad_PendingChecks verifies that an open,
// mergeable, non-draft PR whose checks are still PENDING is not matched.
func TestEnableAutoMerge_Match_Bad_PendingChecks(t *testing.T) {
	handler := NewEnableAutoMergeHandler(nil)
	signal := &jobrunner.PipelineSignal{
		PRState:     "OPEN",
		IsDraft:     false,
		Mergeable:   "MERGEABLE",
		CheckStatus: "PENDING",
	}
	assert.False(t, handler.Match(signal))
}
// TestEnableAutoMerge_Execute_Bad_InternalServerError verifies that a 500 from
// the forge produces a non-error ActionResult with Success=false and a
// "merge failed" message, rather than an Execute error.
func TestEnableAutoMerge_Execute_Bad_InternalServerError(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewEnableAutoMergeHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "org",
		RepoName:  "repo",
		PRNumber:  1,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.False(t, result.Success)
	assert.Contains(t, result.Error, "merge failed")
}
// --- PublishDraft: Match with MERGED state ---

// TestPublishDraft_Match_Bad_MergedState verifies that a draft PR which is
// already MERGED is not matched even though its checks succeeded.
func TestPublishDraft_Match_Bad_MergedState(t *testing.T) {
	handler := NewPublishDraftHandler(nil)
	signal := &jobrunner.PipelineSignal{
		IsDraft:     true,
		PRState:     "MERGED",
		CheckStatus: "SUCCESS",
	}
	assert.False(t, handler.Match(signal))
}
// --- SendFixCommand: Execute merge conflict message ---

// TestSendFixCommand_Execute_Good_MergeConflictMessage verifies that a
// CONFLICTING PR triggers a comment asking to fix the merge conflict.
func TestSendFixCommand_Execute_Good_MergeConflictMessage(t *testing.T) {
	var capturedBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Capture whatever POST body the handler sends as the comment.
		if r.Method == http.MethodPost {
			var body map[string]string
			_ = json.NewDecoder(r.Body).Decode(&body)
			capturedBody = body["body"]
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewSendFixCommandHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "org",
		RepoName:  "repo",
		PRNumber:  1,
		Mergeable: "CONFLICTING",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Contains(t, capturedBody, "fix the merge conflict")
}
// --- DismissReviews: Execute with stale review that gets dismissed ---

// TestDismissReviews_Execute_Good_StaleReviewDismissed verifies that a stale,
// undismissed REQUEST_CHANGES review is dismissed via the /dismissals
// endpoint.
func TestDismissReviews_Execute_Good_StaleReviewDismissed(t *testing.T) {
	var dismissCalled bool
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Review listing: one stale REQUEST_CHANGES review.
		if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/reviews") {
			reviews := []map[string]any{
				{
					"id": 1, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true,
					"body": "fix it", "commit_id": "abc123",
				},
			}
			_ = json.NewEncoder(w).Encode(reviews)
			return
		}
		if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/dismissals") {
			dismissCalled = true
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "state": "DISMISSED"})
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewDismissReviewsHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:       "org",
		RepoName:        "repo",
		PRNumber:        1,
		PRState:         "OPEN",
		ThreadsTotal:    1,
		ThreadsResolved: 0,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.True(t, dismissCalled)
}
// --- TickParent: Execute ticks and closes ---

// TestTickParent_Execute_Good_TicksCheckboxAndCloses verifies that a MERGED
// PR causes the epic's "- [ ] #7" checkbox to be ticked and child issue #7 to
// be closed.
func TestTickParent_Execute_Good_TicksCheckboxAndCloses(t *testing.T) {
	epicBody := "## Tasks\n- [ ] #7\n- [ ] #8\n"
	var editedBody string
	var closedIssue bool
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// Epic fetch.
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   epicBody,
				"title":  "Epic",
			})
		// Epic edit: capture the rewritten body.
		case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/42"):
			var body map[string]any
			_ = json.NewDecoder(r.Body).Decode(&body)
			if b, ok := body["body"].(string); ok {
				editedBody = b
			}
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   editedBody,
				"title":  "Epic",
			})
		// Child close.
		case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/7"):
			closedIssue = true
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 7,
				"state":  "closed",
			})
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewTickParentHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  42,
		ChildNumber: 7,
		PRNumber:    99,
		PRState:     "MERGED",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Contains(t, editedBody, "- [x] #7")
	assert.True(t, closedIssue)
}
// --- Dispatch: DualRun mode ---

// TestDispatch_Execute_Good_DualRunModeDispatch exercises the dispatch flow
// with a DualRun-enabled agent config and asserts a "dispatch" action.
func TestDispatch_Execute_Good_DualRunModeDispatch(t *testing.T) {
	srv := dispatchMockServer(t)
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := agentci.NewSpinner(
		agentci.ClothoConfig{Strategy: "clotho-verified"},
		map[string]agentci.AgentConfig{
			"darbs-claude": {
				Host:     "localhost",
				QueueDir: "/tmp/nonexistent-queue",
				Active:   true,
				Model:    "sonnet",
				DualRun:  true,
			},
		},
	)
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 5,
		EpicNumber:  3,
		IssueTitle:  "Test issue",
		IssueBody:   "Test body",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.Equal(t, "dispatch", result.Action)
}
// --- TickParent: ChildNumber not found in epic body ---

// TestTickParent_Execute_Good_ChildNotInBody verifies Execute still succeeds
// when the epic body has no checkbox referencing the merged child.
func TestTickParent_Execute_Good_ChildNotInBody(t *testing.T) {
	epicBody := "## Tasks\n- [ ] #99\n"
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42") {
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   epicBody,
				"title":  "Epic",
			})
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewTickParentHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  42,
		ChildNumber: 50,
		PRNumber:    100,
		PRState:     "MERGED",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
}
// --- Dispatch: AssignIssue fails (warn, continue) ---

// TestDispatch_Execute_Good_AssignIssueFails exercises the path where issue
// assignment returns 500 but dispatch continues (warn-and-continue).
func TestDispatch_Execute_Good_AssignIssueFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
				{"id": 2, "name": "agent-ready", "color": "#00ff00"},
			})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"})
		// GetIssue returns issue with NO special labels
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{
				"id": 5, "number": 5, "title": "Test Issue",
				"labels": []map[string]any{},
			})
		// AssignIssue FAILS
		case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte(`{"message":"assign failed"}`))
		// AddIssueLabels succeeds
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{{"id": 1, "name": "in-progress"}})
		case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"):
			w.WriteHeader(http.StatusNoContent)
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/comments"):
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": "dispatched"})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 5,
		PRNumber:    10,
		RepoOwner:   "org",
		RepoName:    "repo",
		Assignee:    "darbs-claude",
		IssueTitle:  "Test Issue",
		IssueBody:   "Test body",
	}
	// Should not return error because AssignIssue failure is only a warning.
	result, err := h.Execute(context.Background(), signal)
	// secureTransfer will fail because SSH isn't available, but we exercised the assign-error path.
	// NOTE(review): result and err are deliberately unchecked — this test only
	// drives the warn-and-continue path. Consider asserting on err once the
	// transfer step can be stubbed.
	_ = result
	_ = err
}
// --- Dispatch: AddIssueLabels fails ---

// TestDispatch_Execute_Bad_AddIssueLabelsError verifies that a failure to add
// the in-progress label is fatal: Execute returns an error mentioning
// "add in-progress label" (unlike the assign step, which only warns).
func TestDispatch_Execute_Bad_AddIssueLabelsError(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
			})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"})
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{
				"id": 5, "number": 5, "title": "Test Issue",
				"labels": []map[string]any{},
			})
		case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5})
		// AddIssueLabels FAILS
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"):
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte(`{"message":"label add failed"}`))
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 5,
		PRNumber:    10,
		RepoOwner:   "org",
		RepoName:    "repo",
		Assignee:    "darbs-claude",
		IssueTitle:  "Test Issue",
		IssueBody:   "Test body",
	}
	_, err := h.Execute(context.Background(), signal)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "add in-progress label")
}
// --- Dispatch: GetIssue returns issue with existing labels not matching ---

// TestDispatch_Execute_Good_IssueFoundNoSpecialLabels checks that an issue
// carrying only unrelated labels (e.g. "enhancement") does NOT trigger the
// double-dispatch skip, so Execute proceeds past the label check. The SSH
// stage may fail in CI, hence no hard assertions on the outcome.
func TestDispatch_Execute_Good_IssueFoundNoSpecialLabels(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
				{"id": 2, "name": "agent-ready", "color": "#00ff00"},
			})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"})
		// GetIssue returns issue with unrelated labels
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{
				"id": 5, "number": 5, "title": "Test Issue",
				"labels": []map[string]any{
					{"id": 10, "name": "enhancement"},
				},
			})
		case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5})
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{{"id": 1, "name": "in-progress"}})
		case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"):
			w.WriteHeader(http.StatusNoContent)
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/comments"):
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": "dispatched"})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 5,
		PRNumber:    10,
		RepoOwner:   "org",
		RepoName:    "repo",
		Assignee:    "darbs-claude",
		IssueTitle:  "Test Issue",
		IssueBody:   "Test body",
	}
	// Execute will proceed past label check and try SSH (which fails).
	result, err := h.Execute(context.Background(), signal)
	// Should either succeed (if somehow SSH works) or fail at secureTransfer.
	_ = result
	_ = err
}

View file

@ -1,290 +0,0 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"path/filepath"
"time"
agentci "forge.lthn.ai/core/agent/pkg/orchestrator"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
coreerr "forge.lthn.ai/core/go-log"
)
// Issue labels (and their colours) used to track agent work state on Forgejo.
// Execute adds LabelInProgress on dispatch and skips issues already carrying
// LabelInProgress or LabelAgentComplete; failDispatch applies LabelAgentFailed.
const (
	LabelAgentReady    = "agent-ready"
	LabelInProgress    = "in-progress"
	LabelAgentFailed   = "agent-failed"
	LabelAgentComplete = "agent-completed"
	ColorInProgress    = "#1d76db" // Blue
	ColorAgentFailed   = "#c0392b" // Red
)
// DispatchTicket is the JSON payload written to the agent's queue.
// The ForgeToken is transferred separately via a .env file with 0600 permissions.
type DispatchTicket struct {
	ID           string `json:"id"`           // owner-repo-issue-unixtime, unique per dispatch
	RepoOwner    string `json:"repo_owner"`   // sanitized repository owner
	RepoName     string `json:"repo_name"`    // sanitized repository name
	IssueNumber  int    `json:"issue_number"` // child issue the agent should work on
	IssueTitle   string `json:"issue_title"`
	IssueBody    string `json:"issue_body"`
	TargetBranch string `json:"target_branch"`
	EpicNumber   int    `json:"epic_number"`
	ForgeURL     string `json:"forge_url"`
	ForgeUser    string `json:"forgejo_user"` // note: JSON key differs from field name
	Model        string `json:"model,omitempty"`
	Runner       string `json:"runner,omitempty"`
	VerifyModel  string `json:"verify_model,omitempty"` // set only for dual-run mode
	DualRun      bool   `json:"dual_run"`
	CreatedAt    string `json:"created_at"` // RFC3339 UTC timestamp
}
// DispatchHandler dispatches coding work to remote agent machines via SSH.
type DispatchHandler struct {
	forge    *forge.Client    // Forgejo API client for labels/comments/assignment
	forgeURL string           // base URL embedded in dispatched tickets
	token    string           // API token; shipped to agents via a 0600 .env file
	spinner  *agentci.Spinner // agent registry and Clotho run-mode planner
}
// NewDispatchHandler creates a handler that dispatches tickets to agent machines.
func NewDispatchHandler(client *forge.Client, forgeURL, token string, spinner *agentci.Spinner) *DispatchHandler {
	h := new(DispatchHandler)
	h.forge = client
	h.forgeURL = forgeURL
	h.token = token
	h.spinner = spinner
	return h
}
// Name returns the handler identifier used in ActionResult.Action.
func (h *DispatchHandler) Name() string {
	return "dispatch"
}
// Match returns true for signals where a child issue needs coding (no PR yet)
// and the assignee is a known agent (by config key or Forgejo username).
func (h *DispatchHandler) Match(signal *jobrunner.PipelineSignal) bool {
	if signal.NeedsCoding {
		if _, _, known := h.spinner.FindByForgejoUser(signal.Assignee); known {
			return true
		}
	}
	return false
}
// Execute creates a ticket JSON and transfers it securely to the agent's queue directory.
//
// Flow: resolve agent → sanitize repo identifiers → ensure/apply labels
// (skipping issues already in progress or completed) → plan run mode →
// build and transfer the ticket plus a separate 0600 token .env file →
// comment on the issue. Transfer failures are reported in the ActionResult
// rather than as a returned error, so the pipeline records them.
func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
	start := time.Now()
	agentName, agent, ok := h.spinner.FindByForgejoUser(signal.Assignee)
	if !ok {
		return nil, coreerr.E("handlers.Dispatch.Execute", "unknown agent: "+signal.Assignee, nil)
	}
	// Sanitize inputs to prevent path traversal.
	safeOwner, err := agentci.SanitizePath(signal.RepoOwner)
	if err != nil {
		return nil, coreerr.E("handlers.Dispatch.Execute", "invalid repo owner", err)
	}
	safeRepo, err := agentci.SanitizePath(signal.RepoName)
	if err != nil {
		return nil, coreerr.E("handlers.Dispatch.Execute", "invalid repo name", err)
	}
	// Ensure in-progress label exists on repo.
	inProgressLabel, err := h.forge.EnsureLabel(safeOwner, safeRepo, LabelInProgress, ColorInProgress)
	if err != nil {
		return nil, coreerr.E("handlers.Dispatch.Execute", "ensure label "+LabelInProgress, err)
	}
	// Check if already in progress to prevent double-dispatch.
	// A GetIssue error is deliberately ignored: dispatch proceeds without the check.
	issue, err := h.forge.GetIssue(safeOwner, safeRepo, int64(signal.ChildNumber))
	if err == nil {
		for _, l := range issue.Labels {
			if l.Name == LabelInProgress || l.Name == LabelAgentComplete {
				coreerr.Info("issue already processed, skipping", "issue", signal.ChildNumber)
				return &jobrunner.ActionResult{
					Action:    "dispatch",
					Success:   true,
					Timestamp: time.Now(),
					Duration:  time.Since(start),
				}, nil
			}
		}
	}
	// Assign agent and add in-progress label.
	// Assignment failure is non-fatal (warn only); label failure is fatal.
	if err := h.forge.AssignIssue(safeOwner, safeRepo, int64(signal.ChildNumber), []string{signal.Assignee}); err != nil {
		coreerr.Warn("failed to assign agent, continuing", "err", err)
	}
	if err := h.forge.AddIssueLabels(safeOwner, safeRepo, int64(signal.ChildNumber), []int64{inProgressLabel.ID}); err != nil {
		return nil, coreerr.E("handlers.Dispatch.Execute", "add in-progress label", err)
	}
	// Remove agent-ready label if present.
	if readyLabel, err := h.forge.GetLabelByName(safeOwner, safeRepo, LabelAgentReady); err == nil {
		_ = h.forge.RemoveIssueLabel(safeOwner, safeRepo, int64(signal.ChildNumber), readyLabel.ID)
	}
	// Clotho planning — determine execution mode.
	runMode := h.spinner.DeterminePlan(signal, agentName)
	verifyModel := ""
	if runMode == agentci.ModeDual {
		verifyModel = h.spinner.GetVerifierModel(agentName)
	}
	// Build ticket.
	targetBranch := "new" // TODO: resolve from epic or repo default
	ticketID := fmt.Sprintf("%s-%s-%d-%d", safeOwner, safeRepo, signal.ChildNumber, time.Now().Unix())
	ticket := DispatchTicket{
		ID:           ticketID,
		RepoOwner:    safeOwner,
		RepoName:     safeRepo,
		IssueNumber:  signal.ChildNumber,
		IssueTitle:   signal.IssueTitle,
		IssueBody:    signal.IssueBody,
		TargetBranch: targetBranch,
		EpicNumber:   signal.EpicNumber,
		ForgeURL:     h.forgeURL,
		ForgeUser:    signal.Assignee,
		Model:        agent.Model,
		Runner:       agent.Runner,
		VerifyModel:  verifyModel,
		DualRun:      runMode == agentci.ModeDual,
		CreatedAt:    time.Now().UTC().Format(time.RFC3339),
	}
	ticketJSON, err := json.MarshalIndent(ticket, "", "  ")
	if err != nil {
		h.failDispatch(signal, "Failed to marshal ticket JSON")
		return nil, coreerr.E("handlers.Dispatch.Execute", "marshal ticket", err)
	}
	// Check if ticket already exists on agent (dedup).
	ticketName := fmt.Sprintf("ticket-%s-%s-%d.json", safeOwner, safeRepo, signal.ChildNumber)
	if h.ticketExists(ctx, agent, ticketName) {
		coreerr.Info("ticket already queued, skipping", "ticket", ticketName, "agent", signal.Assignee)
		return &jobrunner.ActionResult{
			Action:      "dispatch",
			RepoOwner:   safeOwner,
			RepoName:    safeRepo,
			EpicNumber:  signal.EpicNumber,
			ChildNumber: signal.ChildNumber,
			Success:     true,
			Timestamp:   time.Now(),
			Duration:    time.Since(start),
		}, nil
	}
	// Transfer ticket JSON.
	// NOTE(review): filepath.Join uses the local OS separator; on a Windows
	// dispatcher this would produce a backslashed path for a remote POSIX
	// shell — path.Join may be the safer choice here. Confirm deployment OS.
	remoteTicketPath := filepath.Join(agent.QueueDir, ticketName)
	if err := h.secureTransfer(ctx, agent, remoteTicketPath, ticketJSON, 0644); err != nil {
		h.failDispatch(signal, fmt.Sprintf("Ticket transfer failed: %v", err))
		return &jobrunner.ActionResult{
			Action:      "dispatch",
			RepoOwner:   safeOwner,
			RepoName:    safeRepo,
			EpicNumber:  signal.EpicNumber,
			ChildNumber: signal.ChildNumber,
			Success:     false,
			Error:       fmt.Sprintf("transfer ticket: %v", err),
			Timestamp:   time.Now(),
			Duration:    time.Since(start),
		}, nil
	}
	// Transfer token via separate .env file with 0600 permissions.
	envContent := fmt.Sprintf("FORGE_TOKEN=%s\n", h.token)
	remoteEnvPath := filepath.Join(agent.QueueDir, fmt.Sprintf(".env.%s", ticketID))
	if err := h.secureTransfer(ctx, agent, remoteEnvPath, []byte(envContent), 0600); err != nil {
		// Clean up the ticket if env transfer fails.
		_ = h.runRemote(ctx, agent, fmt.Sprintf("rm -f %s", agentci.EscapeShellArg(remoteTicketPath)))
		h.failDispatch(signal, fmt.Sprintf("Token transfer failed: %v", err))
		return &jobrunner.ActionResult{
			Action:      "dispatch",
			RepoOwner:   safeOwner,
			RepoName:    safeRepo,
			EpicNumber:  signal.EpicNumber,
			ChildNumber: signal.ChildNumber,
			Success:     false,
			Error:       fmt.Sprintf("transfer token: %v", err),
			Timestamp:   time.Now(),
			Duration:    time.Since(start),
		}, nil
	}
	// Comment on issue.
	modeStr := "Standard"
	if runMode == agentci.ModeDual {
		modeStr = "Clotho Verified (Dual Run)"
	}
	comment := fmt.Sprintf("Dispatched to **%s** agent queue.\nMode: **%s**", signal.Assignee, modeStr)
	_ = h.forge.CreateIssueComment(safeOwner, safeRepo, int64(signal.ChildNumber), comment)
	return &jobrunner.ActionResult{
		Action:      "dispatch",
		RepoOwner:   safeOwner,
		RepoName:    safeRepo,
		EpicNumber:  signal.EpicNumber,
		ChildNumber: signal.ChildNumber,
		Success:     true,
		Timestamp:   time.Now(),
		Duration:    time.Since(start),
	}, nil
}
// failDispatch handles cleanup when dispatch fails (adds failed label, removes in-progress).
// All cleanup steps are best-effort: label and comment errors are deliberately
// ignored so that cleanup never masks the original dispatch failure. Uses the
// raw (unsanitized) signal fields, since Execute has already validated them
// before any code path that reaches here.
func (h *DispatchHandler) failDispatch(signal *jobrunner.PipelineSignal, reason string) {
	if failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed); err == nil {
		_ = h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID})
	}
	if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil {
		_ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID)
	}
	_ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), fmt.Sprintf("Agent dispatch failed: %s", reason))
}
// secureTransfer writes data to a remote path via SSH stdin, preventing command injection.
// The payload travels over stdin (never on a command line), the remote path is
// shell-escaped, and the file mode is applied with chmod after the write.
func (h *DispatchHandler) secureTransfer(ctx context.Context, agent agentci.AgentConfig, remotePath string, data []byte, mode int) error {
	safeRemotePath := agentci.EscapeShellArg(remotePath)
	// %o renders the mode in octal for chmod.
	remoteCmd := fmt.Sprintf("cat > %s && chmod %o %s", safeRemotePath, mode, safeRemotePath)
	cmd := agentci.SecureSSHCommandContext(ctx, agent.Host, remoteCmd)
	cmd.Stdin = bytes.NewReader(data)
	output, err := cmd.CombinedOutput()
	if err != nil {
		return coreerr.E("dispatch.transfer", fmt.Sprintf("ssh to %s failed: %s", agent.Host, string(output)), err)
	}
	return nil
}
// runRemote executes a command on the agent via SSH.
// Output is discarded; only the exit status is reported. The caller is
// responsible for shell-escaping any untrusted parts of cmdStr.
func (h *DispatchHandler) runRemote(ctx context.Context, agent agentci.AgentConfig, cmdStr string) error {
	cmd := agentci.SecureSSHCommandContext(ctx, agent.Host, cmdStr)
	return cmd.Run()
}
// ticketExists checks if a ticket file already exists in queue, active, or done.
// Returns false on any error (including a ticket name that fails sanitization),
// which makes dedup fail-open: a dispatch proceeds if the check cannot run.
func (h *DispatchHandler) ticketExists(ctx context.Context, agent agentci.AgentConfig, ticketName string) bool {
	safeTicket, err := agentci.SanitizePath(ticketName)
	if err != nil {
		return false
	}
	// NOTE(review): qDir is interpolated unescaped, unlike secureTransfer which
	// runs paths through EscapeShellArg. Quoting here would break `~` expansion
	// in configured queue dirs (e.g. "~/ai-work/queue") — confirm QueueDir is
	// operator-controlled/trusted before tightening this.
	qDir := agent.QueueDir
	checkCmd := fmt.Sprintf(
		"test -f %s/%s || test -f %s/../active/%s || test -f %s/../done/%s",
		qDir, safeTicket, qDir, safeTicket, qDir, safeTicket,
	)
	cmd := agentci.SecureSSHCommandContext(ctx, agent.Host, checkCmd)
	return cmd.Run() == nil
}

View file

@ -1,327 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
agentci "forge.lthn.ai/core/agent/pkg/orchestrator"
"forge.lthn.ai/core/agent/pkg/jobrunner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestSpinner creates a Spinner with the given agents for testing.
// The "direct" strategy avoids any Clotho dual-run planning in tests.
func newTestSpinner(agents map[string]agentci.AgentConfig) *agentci.Spinner {
	return agentci.NewSpinner(agentci.ClothoConfig{Strategy: "direct"}, agents)
}
// --- Match tests ---
// Dispatch matches only when the signal needs coding AND the assignee resolves
// to a configured agent; each test below varies exactly one of those inputs.

func TestDispatch_Match_Good_NeedsCoding(t *testing.T) {
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true},
	})
	h := NewDispatchHandler(nil, "", "", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
	}
	assert.True(t, h.Match(sig))
}

func TestDispatch_Match_Good_MultipleAgents(t *testing.T) {
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true},
		"local-codex":  {Host: "localhost", QueueDir: "~/ai-work/queue", Active: true},
	})
	h := NewDispatchHandler(nil, "", "", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "local-codex",
	}
	assert.True(t, h.Match(sig))
}

// A signal with an open PR (NeedsCoding=false) must not re-dispatch.
func TestDispatch_Match_Bad_HasPR(t *testing.T) {
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true},
	})
	h := NewDispatchHandler(nil, "", "", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: false,
		PRNumber:    7,
		Assignee:    "darbs-claude",
	}
	assert.False(t, h.Match(sig))
}

func TestDispatch_Match_Bad_UnknownAgent(t *testing.T) {
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true},
	})
	h := NewDispatchHandler(nil, "", "", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "unknown-user",
	}
	assert.False(t, h.Match(sig))
}

func TestDispatch_Match_Bad_NotAssigned(t *testing.T) {
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true},
	})
	h := NewDispatchHandler(nil, "", "", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "",
	}
	assert.False(t, h.Match(sig))
}

func TestDispatch_Match_Bad_EmptyAgentMap(t *testing.T) {
	spinner := newTestSpinner(map[string]agentci.AgentConfig{})
	h := NewDispatchHandler(nil, "", "", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
	}
	assert.False(t, h.Match(sig))
}
// --- Name test ---

// TestDispatch_Name_Good pins the handler identifier reported to the pipeline.
func TestDispatch_Name_Good(t *testing.T) {
	spinner := newTestSpinner(nil)
	h := NewDispatchHandler(nil, "", "", spinner)
	assert.Equal(t, "dispatch", h.Name())
}
// --- Execute tests ---

// TestDispatch_Execute_Bad_UnknownAgent verifies Execute fails fast — before
// any Forgejo API calls — when the assignee is not a configured agent.
func TestDispatch_Execute_Bad_UnknownAgent(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "nonexistent-agent",
		RepoOwner:   "host-uk",
		RepoName:    "core",
		ChildNumber: 1,
	}
	_, err := h.Execute(context.Background(), sig)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unknown agent")
}
// TestDispatch_TicketJSON_Good round-trips a fully-populated ticket through
// JSON and pins every wire key — critically, that the Forgejo user serializes
// under "forgejo_user" and that no token ever appears in the ticket payload.
func TestDispatch_TicketJSON_Good(t *testing.T) {
	ticket := DispatchTicket{
		ID:           "host-uk-core-5-1234567890",
		RepoOwner:    "host-uk",
		RepoName:     "core",
		IssueNumber:  5,
		IssueTitle:   "Fix the thing",
		IssueBody:    "Please fix this bug",
		TargetBranch: "new",
		EpicNumber:   3,
		ForgeURL:     "https://forge.lthn.ai",
		ForgeUser:    "darbs-claude",
		Model:        "sonnet",
		Runner:       "claude",
		DualRun:      false,
		CreatedAt:    "2026-02-09T12:00:00Z",
	}
	data, err := json.MarshalIndent(ticket, "", "  ")
	require.NoError(t, err)
	var decoded map[string]any
	err = json.Unmarshal(data, &decoded)
	require.NoError(t, err)
	assert.Equal(t, "host-uk-core-5-1234567890", decoded["id"])
	assert.Equal(t, "host-uk", decoded["repo_owner"])
	assert.Equal(t, "core", decoded["repo_name"])
	// JSON numbers decode as float64 in a map[string]any.
	assert.Equal(t, float64(5), decoded["issue_number"])
	assert.Equal(t, "Fix the thing", decoded["issue_title"])
	assert.Equal(t, "Please fix this bug", decoded["issue_body"])
	assert.Equal(t, "new", decoded["target_branch"])
	assert.Equal(t, float64(3), decoded["epic_number"])
	assert.Equal(t, "https://forge.lthn.ai", decoded["forge_url"])
	assert.Equal(t, "darbs-claude", decoded["forgejo_user"])
	assert.Equal(t, "sonnet", decoded["model"])
	assert.Equal(t, "claude", decoded["runner"])
	// Token should NOT be present in the ticket.
	_, hasToken := decoded["forge_token"]
	assert.False(t, hasToken, "forge_token must not be in ticket JSON")
}
// TestDispatch_TicketJSON_Good_DualRun checks that the dual-run flag and the
// verifier model survive a JSON marshal/unmarshal round trip.
func TestDispatch_TicketJSON_Good_DualRun(t *testing.T) {
	ticket := DispatchTicket{
		ID:          "test-dual",
		RepoOwner:   "host-uk",
		RepoName:    "core",
		IssueNumber: 1,
		ForgeURL:    "https://forge.lthn.ai",
		Model:       "gemini-2.0-flash",
		VerifyModel: "gemini-1.5-pro",
		DualRun:     true,
	}
	data, err := json.Marshal(ticket)
	require.NoError(t, err)
	var roundtrip DispatchTicket
	err = json.Unmarshal(data, &roundtrip)
	require.NoError(t, err)
	assert.True(t, roundtrip.DualRun)
	assert.Equal(t, "gemini-1.5-pro", roundtrip.VerifyModel)
}
// TestDispatch_TicketJSON_Good_OmitsEmptyModelRunner verifies the omitempty
// tags on Model and Runner: empty values must not emit JSON keys.
func TestDispatch_TicketJSON_Good_OmitsEmptyModelRunner(t *testing.T) {
	ticket := DispatchTicket{
		ID:           "test-1",
		RepoOwner:    "host-uk",
		RepoName:     "core",
		IssueNumber:  1,
		TargetBranch: "new",
		ForgeURL:     "https://forge.lthn.ai",
	}
	data, err := json.MarshalIndent(ticket, "", "  ")
	require.NoError(t, err)
	var decoded map[string]any
	err = json.Unmarshal(data, &decoded)
	require.NoError(t, err)
	_, hasModel := decoded["model"]
	_, hasRunner := decoded["runner"]
	assert.False(t, hasModel, "model should be omitted when empty")
	assert.False(t, hasRunner, "runner should be omitted when empty")
}
// TestDispatch_TicketJSON_Good_ModelRunnerVariants table-tests JSON round
// trips for the supported model/runner combinations, including runners that
// use their default (empty) model.
func TestDispatch_TicketJSON_Good_ModelRunnerVariants(t *testing.T) {
	tests := []struct {
		name   string
		model  string
		runner string
	}{
		{"claude-sonnet", "sonnet", "claude"},
		{"claude-opus", "opus", "claude"},
		{"codex-default", "", "codex"},
		{"gemini-default", "", "gemini"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ticket := DispatchTicket{
				ID:           "test-" + tt.name,
				RepoOwner:    "host-uk",
				RepoName:     "core",
				IssueNumber:  1,
				TargetBranch: "new",
				ForgeURL:     "https://forge.lthn.ai",
				Model:        tt.model,
				Runner:       tt.runner,
			}
			data, err := json.Marshal(ticket)
			require.NoError(t, err)
			var roundtrip DispatchTicket
			err = json.Unmarshal(data, &roundtrip)
			require.NoError(t, err)
			assert.Equal(t, tt.model, roundtrip.Model)
			assert.Equal(t, tt.runner, roundtrip.Runner)
		})
	}
}
// TestDispatch_Execute_Good_PostsComment drives a happy-path dispatch against
// a stub Forgejo server and checks that a comment naming the agent is posted
// on the child issue. The SSH transfer stage can fail in CI (no agent host),
// so the comment assertions only apply when the dispatch reported success.
//
// Fix: Encode return values are now explicitly discarded (`_ =`), matching
// every sibling test in this file and keeping the file errcheck-clean.
func TestDispatch_Execute_Good_PostsComment(t *testing.T) {
	var commentPosted bool
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/host-uk/core/labels":
			_ = json.NewEncoder(w).Encode([]any{})
			return
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/labels":
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress", "color": "#1d76db"})
			return
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5, "labels": []any{}, "title": "Test"})
			return
		case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5})
			return
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5/labels":
			_ = json.NewEncoder(w).Encode([]any{map[string]any{"id": 1, "name": "in-progress"}})
			return
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5/comments":
			// Capture the comment body so the test can assert on it below.
			commentPosted = true
			var body map[string]string
			_ = json.NewDecoder(r.Body).Decode(&body)
			commentBody = body["body"]
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]})
			return
		}
		w.WriteHeader(http.StatusOK)
		_ = json.NewEncoder(w).Encode(map[string]any{})
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "host-uk",
		RepoName:    "core",
		ChildNumber: 5,
		EpicNumber:  3,
		IssueTitle:  "Test issue",
		IssueBody:   "Test body",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.Equal(t, "dispatch", result.Action)
	assert.Equal(t, "host-uk", result.RepoOwner)
	assert.Equal(t, "core", result.RepoName)
	assert.Equal(t, 3, result.EpicNumber)
	assert.Equal(t, 5, result.ChildNumber)
	if result.Success {
		assert.True(t, commentPosted)
		assert.Contains(t, commentBody, "darbs-claude")
	}
}

View file

@ -1,58 +0,0 @@
package handlers
import (
"context"
"fmt"
"time"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// EnableAutoMergeHandler merges a PR that is ready using squash strategy.
type EnableAutoMergeHandler struct {
	forge *forge.Client // Forgejo API client used to perform the merge
}

// NewEnableAutoMergeHandler creates a handler that merges ready PRs.
func NewEnableAutoMergeHandler(f *forge.Client) *EnableAutoMergeHandler {
	return &EnableAutoMergeHandler{forge: f}
}
// Name returns the handler identifier used in ActionResult.Action.
func (h *EnableAutoMergeHandler) Name() string {
	return "enable_auto_merge"
}
// Match returns true when the PR is open, not a draft, mergeable, checks
// are passing, and there are no unresolved review threads.
func (h *EnableAutoMergeHandler) Match(signal *jobrunner.PipelineSignal) bool {
	switch {
	case signal.PRState != "OPEN":
		return false
	case signal.IsDraft:
		return false
	case signal.Mergeable != "MERGEABLE":
		return false
	case signal.CheckStatus != "SUCCESS":
		return false
	case signal.HasUnresolvedThreads():
		return false
	default:
		return true
	}
}
// Execute merges the pull request with squash strategy.
// A merge failure is reported via ActionResult (Success=false, Error set),
// not as a returned error, so the pipeline records the attempt either way.
func (h *EnableAutoMergeHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
	began := time.Now()
	mergeErr := h.forge.MergePullRequest(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber), "squash")
	res := &jobrunner.ActionResult{
		Action:    "enable_auto_merge",
		RepoOwner: signal.RepoOwner,
		RepoName:  signal.RepoName,
		PRNumber:  signal.PRNumber,
	}
	res.Success = mergeErr == nil
	res.Timestamp = time.Now()
	res.Duration = time.Since(began)
	if mergeErr != nil {
		res.Error = fmt.Sprintf("merge failed: %v", mergeErr)
	}
	return res, nil
}

View file

@ -1,105 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// TestEnableAutoMerge_Match_Good: all gate conditions satisfied → match.
func TestEnableAutoMerge_Match_Good(t *testing.T) {
	h := NewEnableAutoMergeHandler(nil)
	sig := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		IsDraft:         false,
		Mergeable:       "MERGEABLE",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    0,
		ThreadsResolved: 0,
	}
	assert.True(t, h.Match(sig))
}

// A draft PR must never be auto-merged, even if everything else is green.
func TestEnableAutoMerge_Match_Bad_Draft(t *testing.T) {
	h := NewEnableAutoMergeHandler(nil)
	sig := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		IsDraft:         true,
		Mergeable:       "MERGEABLE",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    0,
		ThreadsResolved: 0,
	}
	assert.False(t, h.Match(sig))
}

// Unresolved review threads (5 total, only 3 resolved) block the merge.
func TestEnableAutoMerge_Match_Bad_UnresolvedThreads(t *testing.T) {
	h := NewEnableAutoMergeHandler(nil)
	sig := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		IsDraft:         false,
		Mergeable:       "MERGEABLE",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    5,
		ThreadsResolved: 3,
	}
	assert.False(t, h.Match(sig))
}
// TestEnableAutoMerge_Execute_Good verifies Execute issues exactly one
// POST to the Forgejo merge endpoint for the right repo and PR number.
func TestEnableAutoMerge_Execute_Good(t *testing.T) {
	var capturedPath string
	var capturedMethod string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		capturedMethod = r.Method
		capturedPath = r.URL.Path
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewEnableAutoMergeHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "host-uk",
		RepoName:  "core-php",
		PRNumber:  55,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "enable_auto_merge", result.Action)
	assert.Equal(t, http.MethodPost, capturedMethod)
	assert.Equal(t, "/api/v1/repos/host-uk/core-php/pulls/55/merge", capturedPath)
}
// TestEnableAutoMerge_Execute_Bad_MergeFailed: a 409 from the merge endpoint
// yields Success=false with the error recorded, but no returned error —
// merge failures are reported through the ActionResult.
func TestEnableAutoMerge_Execute_Bad_MergeFailed(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusConflict)
		_ = json.NewEncoder(w).Encode(map[string]string{"message": "merge conflict"})
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewEnableAutoMergeHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "host-uk",
		RepoName:  "core-php",
		PRNumber:  55,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.False(t, result.Success)
	assert.Contains(t, result.Error, "merge failed")
}

View file

@ -1,583 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
agentci "forge.lthn.ai/core/agent/pkg/orchestrator"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// --- Name tests for all handlers ---
// Pin each handler's identifier; these strings key pipeline results/config.

func TestEnableAutoMerge_Name_Good(t *testing.T) {
	h := NewEnableAutoMergeHandler(nil)
	assert.Equal(t, "enable_auto_merge", h.Name())
}

func TestPublishDraft_Name_Good(t *testing.T) {
	h := NewPublishDraftHandler(nil)
	assert.Equal(t, "publish_draft", h.Name())
}

func TestDismissReviews_Name_Good(t *testing.T) {
	h := NewDismissReviewsHandler(nil)
	assert.Equal(t, "dismiss_reviews", h.Name())
}

func TestSendFixCommand_Name_Good(t *testing.T) {
	h := NewSendFixCommandHandler(nil)
	assert.Equal(t, "send_fix_command", h.Name())
}

func TestTickParent_Name_Good(t *testing.T) {
	h := NewTickParentHandler(nil)
	assert.Equal(t, "tick_parent", h.Name())
}
// --- Additional Match tests ---
func TestEnableAutoMerge_Match_Bad_Closed(t *testing.T) {
h := NewEnableAutoMergeHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "CLOSED",
Mergeable: "MERGEABLE",
CheckStatus: "SUCCESS",
}
assert.False(t, h.Match(sig))
}
func TestEnableAutoMerge_Match_Bad_ChecksFailing(t *testing.T) {
h := NewEnableAutoMergeHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "OPEN",
Mergeable: "MERGEABLE",
CheckStatus: "FAILURE",
}
assert.False(t, h.Match(sig))
}
func TestEnableAutoMerge_Match_Bad_Conflicting(t *testing.T) {
h := NewEnableAutoMergeHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "OPEN",
Mergeable: "CONFLICTING",
CheckStatus: "SUCCESS",
}
assert.False(t, h.Match(sig))
}
func TestPublishDraft_Match_Bad_Closed(t *testing.T) {
h := NewPublishDraftHandler(nil)
sig := &jobrunner.PipelineSignal{
IsDraft: true,
PRState: "CLOSED",
CheckStatus: "SUCCESS",
}
assert.False(t, h.Match(sig))
}
func TestDismissReviews_Match_Bad_Closed(t *testing.T) {
h := NewDismissReviewsHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "CLOSED",
ThreadsTotal: 3,
ThreadsResolved: 1,
}
assert.False(t, h.Match(sig))
}
func TestDismissReviews_Match_Bad_NoThreads(t *testing.T) {
h := NewDismissReviewsHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "OPEN",
ThreadsTotal: 0,
ThreadsResolved: 0,
}
assert.False(t, h.Match(sig))
}
func TestSendFixCommand_Match_Bad_Closed(t *testing.T) {
h := NewSendFixCommandHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "CLOSED",
Mergeable: "CONFLICTING",
}
assert.False(t, h.Match(sig))
}
func TestSendFixCommand_Match_Bad_NoIssues(t *testing.T) {
h := NewSendFixCommandHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "OPEN",
Mergeable: "MERGEABLE",
CheckStatus: "SUCCESS",
}
assert.False(t, h.Match(sig))
}
// TestSendFixCommand_Match_Good_ThreadsFailure: failing checks combined
// with unresolved review threads do trigger a fix command.
func TestSendFixCommand_Match_Good_ThreadsFailure(t *testing.T) {
	handler := NewSendFixCommandHandler(nil)
	signal := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		Mergeable:       "MERGEABLE",
		CheckStatus:     "FAILURE",
		ThreadsTotal:    2,
		ThreadsResolved: 0,
	}
	assert.True(t, handler.Match(signal))
}
// TestTickParent_Match_Bad_Closed: only merged PRs tick the parent epic;
// a closed (unmerged) PR must not match.
func TestTickParent_Match_Bad_Closed(t *testing.T) {
	handler := NewTickParentHandler(nil)
	signal := &jobrunner.PipelineSignal{PRState: "CLOSED"}
	assert.False(t, handler.Match(signal))
}
// --- Additional Execute tests ---
// TestPublishDraft_Execute_Bad_ServerError verifies that a 500 from the
// forge API surfaces as a failed result (Success=false with a
// "publish draft failed" message) rather than a Go error.
func TestPublishDraft_Execute_Bad_ServerError(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewPublishDraftHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "org",
		RepoName:  "repo",
		PRNumber:  1,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.False(t, result.Success)
	assert.Contains(t, result.Error, "publish draft failed")
}
// TestSendFixCommand_Execute_Good_Reviews: with failing checks and
// unresolved threads, Execute posts an issue comment whose body contains
// "fix the code reviews". The POST body is captured for assertion.
func TestSendFixCommand_Execute_Good_Reviews(t *testing.T) {
	var capturedBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodPost {
			b, _ := io.ReadAll(r.Body)
			capturedBody = string(b)
			w.WriteHeader(http.StatusCreated)
			_, _ = w.Write([]byte(`{"id":1}`))
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewSendFixCommandHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:       "org",
		RepoName:        "repo",
		PRNumber:        5,
		PRState:         "OPEN",
		Mergeable:       "MERGEABLE",
		CheckStatus:     "FAILURE",
		ThreadsTotal:    2,
		ThreadsResolved: 0,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Contains(t, capturedBody, "fix the code reviews")
}
// TestSendFixCommand_Execute_Bad_CommentFails: when posting the comment
// returns a 500, Execute reports a failed result containing
// "post comment failed" (no Go error).
func TestSendFixCommand_Execute_Bad_CommentFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewSendFixCommandHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "org",
		RepoName:  "repo",
		PRNumber:  1,
		Mergeable: "CONFLICTING",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.False(t, result.Success)
	assert.Contains(t, result.Error, "post comment failed")
}
// TestTickParent_Execute_Good_AlreadyTicked: when the epic body already
// has the child's checkbox ticked ("- [x] #7"), Execute still reports
// success with action "tick_parent" (idempotent behavior).
func TestTickParent_Execute_Good_AlreadyTicked(t *testing.T) {
	epicBody := "## Tasks\n- [x] #7\n- [ ] #8\n"
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42") {
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   epicBody,
				"title":  "Epic",
			})
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewTickParentHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  42,
		ChildNumber: 7,
		PRNumber:    99,
		PRState:     "MERGED",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "tick_parent", result.Action)
}
// TestTickParent_Execute_Bad_FetchEpicFails: a 404 when fetching the epic
// makes Execute return a Go error containing "fetch epic".
func TestTickParent_Execute_Bad_FetchEpicFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewTickParentHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  999,
		ChildNumber: 1,
		PRState:     "MERGED",
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "fetch epic")
}
// TestTickParent_Execute_Bad_EditEpicFails: the epic is fetched fine, but
// the PATCH that updates its body fails with 500; Execute reports a
// failed result containing "edit epic failed".
func TestTickParent_Execute_Bad_EditEpicFails(t *testing.T) {
	epicBody := "## Tasks\n- [ ] #7\n"
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   epicBody,
				"title":  "Epic",
			})
		case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/42"):
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewTickParentHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  42,
		ChildNumber: 7,
		PRNumber:    99,
		PRState:     "MERGED",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.False(t, result.Success)
	assert.Contains(t, result.Error, "edit epic failed")
}
// TestTickParent_Execute_Bad_CloseChildFails: ticking the epic checkbox
// succeeds (PATCH /issues/42 returns the updated body), but closing the
// child issue (PATCH /issues/7) fails with 500; Execute reports a failed
// result containing "close child issue failed".
func TestTickParent_Execute_Bad_CloseChildFails(t *testing.T) {
	epicBody := "## Tasks\n- [ ] #7\n"
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   epicBody,
				"title":  "Epic",
			})
		case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/42"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   strings.Replace(epicBody, "- [ ] #7", "- [x] #7", 1),
				"title":  "Epic",
			})
		case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/7"):
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.WriteHeader(http.StatusOK)
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewTickParentHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:   "org",
		RepoName:    "repo",
		EpicNumber:  42,
		ChildNumber: 7,
		PRNumber:    99,
		PRState:     "MERGED",
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.False(t, result.Success)
	assert.Contains(t, result.Error, "close child issue failed")
}
// TestDismissReviews_Execute_Bad_ListFails: a 500 when listing reviews
// makes Execute return a Go error containing "list reviews".
func TestDismissReviews_Execute_Bad_ListFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewDismissReviewsHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "org",
		RepoName:  "repo",
		PRNumber:  1,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "list reviews")
}
// TestDismissReviews_Execute_Good_NothingToDismiss: when no review is a
// stale, non-dismissed REQUEST_CHANGES (approved / already-dismissed /
// not-stale reviews only), Execute is a successful no-op.
func TestDismissReviews_Execute_Good_NothingToDismiss(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodGet {
			// All reviews are either approved or already dismissed.
			reviews := []map[string]any{
				{
					"id": 1, "state": "APPROVED", "dismissed": false, "stale": false,
					"body": "lgtm", "commit_id": "abc123",
				},
				{
					"id": 2, "state": "REQUEST_CHANGES", "dismissed": true, "stale": true,
					"body": "already dismissed", "commit_id": "abc123",
				},
				{
					"id": 3, "state": "REQUEST_CHANGES", "dismissed": false, "stale": false,
					"body": "not stale", "commit_id": "abc123",
				},
			}
			_ = json.NewEncoder(w).Encode(reviews)
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewDismissReviewsHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "org",
		RepoName:  "repo",
		PRNumber:  1,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success, "nothing to dismiss should be success")
}
// TestDismissReviews_Execute_Bad_DismissFails: the list call returns one
// stale REQUEST_CHANGES review, but the dismiss call is rejected with
// 403; Execute reports a failed result containing "failed to dismiss".
func TestDismissReviews_Execute_Bad_DismissFails(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodGet {
			reviews := []map[string]any{
				{
					"id": 1, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true,
					"body": "fix it", "commit_id": "abc123",
				},
			}
			_ = json.NewEncoder(w).Encode(reviews)
			return
		}
		// Dismiss fails.
		w.WriteHeader(http.StatusForbidden)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewDismissReviewsHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner: "org",
		RepoName:  "repo",
		PRNumber:  1,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.False(t, result.Success)
	assert.Contains(t, result.Error, "failed to dismiss")
}
// --- Dispatch Execute edge cases ---
// TestDispatch_Execute_Good_AlreadyInProgress: the target issue already
// carries the "in-progress" label, so dispatch is a no-op that still
// reports success (no duplicate dispatch).
func TestDispatch_Execute_Good_AlreadyInProgress(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
			})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"})
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			// Issue already has in-progress label.
			_ = json.NewEncoder(w).Encode(map[string]any{
				"id":     5,
				"number": 5,
				"labels": []map[string]any{{"name": "in-progress", "id": 1}},
				"title":  "Test",
			})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 5,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success, "already in-progress should be a no-op success")
}
// TestDispatch_Execute_Good_AlreadyCompleted: the issue carries the
// "agent-completed" label, so dispatch succeeds without re-queueing work.
func TestDispatch_Execute_Good_AlreadyCompleted(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 2, "name": "agent-completed", "color": "#0e8a16"},
			})
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"})
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5":
			_ = json.NewEncoder(w).Encode(map[string]any{
				"id":     5,
				"number": 5,
				"labels": []map[string]any{{"name": "agent-completed", "id": 2}},
				"title":  "Done",
			})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "org",
		RepoName:    "repo",
		ChildNumber: 5,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
}
// TestDispatch_Execute_Bad_InvalidRepoOwner: an owner containing an
// illegal character ("org$bad") is rejected before any dispatch happens;
// Execute returns an error containing "invalid repo owner".
func TestDispatch_Execute_Bad_InvalidRepoOwner(t *testing.T) {
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	spinner := newTestSpinner(map[string]agentci.AgentConfig{
		"darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true},
	})
	h := NewDispatchHandler(client, srv.URL, "test-token", spinner)
	sig := &jobrunner.PipelineSignal{
		NeedsCoding: true,
		Assignee:    "darbs-claude",
		RepoOwner:   "org$bad",
		RepoName:    "repo",
		ChildNumber: 1,
	}
	_, err := h.Execute(context.Background(), sig)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "invalid repo owner")
}

View file

@ -1,824 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// --- Integration: full signal -> handler -> result flow ---
// These tests exercise the complete pipeline: a signal is created,
// matched by a handler, executed against a mock Forgejo server,
// and the result is verified.
// mockForgejoServer creates a comprehensive mock Forgejo API server
// for integration testing. It supports issues, PRs, labels, comments,
// and tracks all API calls made.
// apiCall records a single HTTP request received by the mock Forgejo
// server, so tests can assert on the exact sequence of calls made.
type apiCall struct {
	Method string // HTTP method, e.g. "GET"
	Path   string // request URL path
	Body   string // raw request body as received
}
// forgejoMock is a stateful mock Forgejo API server used by the
// integration tests. It serves a configurable epic body and records the
// mutations handlers perform against it.
type forgejoMock struct {
	epicBody    string         // issue body returned for GET /issues/{n}
	calls       []apiCall      // every request received, in order
	srv         *httptest.Server
	closedChild bool           // set when a PATCH closes an issue
	editedBody  string         // last issue body written via PATCH
	comments    []string       // bodies of comments posted via POST /comments
}
// newForgejoMock starts an httptest server emulating the subset of the
// Forgejo API the handlers use (issues, comments, labels, PR merge,
// draft publish, reviews). Every request is recorded in m.calls; issue
// edits, child closure, and posted comments are captured on the mock.
//
// Fix: the original switch placed the generic
// `POST && HasSuffix(path, "/labels")` case before the issue-labels
// case. Paths like /issues/5/labels also end in "/labels", so the
// issue-labels branch was unreachable dead code. The more specific case
// is now ordered first so each branch can actually fire.
func newForgejoMock(t *testing.T, epicBody string) *forgejoMock {
	t.Helper()
	m := &forgejoMock{
		epicBody: epicBody,
	}
	m.srv = httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		bodyBytes, _ := io.ReadAll(r.Body)
		m.calls = append(m.calls, apiCall{
			Method: r.Method,
			Path:   r.URL.Path,
			Body:   string(bodyBytes),
		})
		w.Header().Set("Content-Type", "application/json")
		path := r.URL.Path
		switch {
		// GET epic issue.
		case r.Method == http.MethodGet && strings.Contains(path, "/issues/") && !strings.Contains(path, "/comments"):
			issueNum := path[strings.LastIndex(path, "/")+1:]
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": json.Number(issueNum),
				"body":   m.epicBody,
				"title":  "Epic: Phase 3",
				"state":  "open",
				"labels": []map[string]any{{"name": "epic", "id": 1}},
			})
		// PATCH epic issue (edit body or close child).
		case r.Method == http.MethodPatch && strings.Contains(path, "/issues/"):
			var body map[string]any
			_ = json.Unmarshal(bodyBytes, &body)
			if bodyStr, ok := body["body"].(string); ok {
				m.editedBody = bodyStr
			}
			if state, ok := body["state"].(string); ok && state == "closed" {
				m.closedChild = true
			}
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 1,
				"body":   m.editedBody,
				"state":  "open",
			})
		// POST comment.
		case r.Method == http.MethodPost && strings.Contains(path, "/comments"):
			var body map[string]string
			_ = json.Unmarshal(bodyBytes, &body)
			m.comments = append(m.comments, body["body"])
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]})
		// GET labels.
		case r.Method == http.MethodGet && strings.Contains(path, "/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "epic", "color": "#ff0000"},
				{"id": 2, "name": "in-progress", "color": "#1d76db"},
			})
		// POST issue labels — must precede the generic repo-labels POST
		// case below, otherwise HasSuffix("/labels") would swallow it.
		case r.Method == http.MethodPost && strings.Contains(path, "/issues/") && strings.Contains(path, "/labels"):
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		// POST repo labels.
		case r.Method == http.MethodPost && strings.HasSuffix(path, "/labels"):
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 10, "name": "new-label"})
		// DELETE issue label.
		case r.Method == http.MethodDelete && strings.Contains(path, "/labels/"):
			w.WriteHeader(http.StatusNoContent)
		// POST merge PR.
		case r.Method == http.MethodPost && strings.Contains(path, "/merge"):
			w.WriteHeader(http.StatusOK)
		// PATCH PR (publish draft).
		case r.Method == http.MethodPatch && strings.Contains(path, "/pulls/"):
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(`{}`))
		// GET reviews.
		case r.Method == http.MethodGet && strings.Contains(path, "/reviews"):
			_ = json.NewEncoder(w).Encode([]map[string]any{})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	return m
}
// close shuts down the underlying httptest server.
func (m *forgejoMock) close() {
	m.srv.Close()
}
// client returns a forge.Client pointed at the mock server, failing the
// test if construction errors.
func (m *forgejoMock) client(t *testing.T) *forge.Client {
	t.Helper()
	c, err := forge.New(m.srv.URL, "test-token")
	require.NoError(t, err)
	return c
}
// --- TickParent integration: signal -> execute -> verify epic updated ---
// TestIntegration_TickParent_Good_FullFlow drives a merged-PR signal for
// child #7 through Match and Execute against the mock, then verifies:
// #7 is ticked in the epic body, the other checkboxes are untouched, and
// the child issue is closed.
func TestIntegration_TickParent_Good_FullFlow(t *testing.T) {
	epicBody := "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n- [x] #3\n"
	mock := newForgejoMock(t, epicBody)
	defer mock.close()
	h := NewTickParentHandler(mock.client(t))
	// Create signal representing a merged PR for child #7.
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  42,
		ChildNumber: 7,
		PRNumber:    99,
		RepoOwner:   "host-uk",
		RepoName:    "core-php",
		PRState:     "MERGED",
		CheckStatus: "SUCCESS",
		Mergeable:   "UNKNOWN",
	}
	// Verify the handler matches.
	assert.True(t, h.Match(signal))
	// Execute.
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	// Verify result.
	assert.True(t, result.Success)
	assert.Equal(t, "tick_parent", result.Action)
	assert.Equal(t, "host-uk", result.RepoOwner)
	assert.Equal(t, "core-php", result.RepoName)
	assert.Equal(t, 99, result.PRNumber)
	// Verify the epic body was updated: #7 should now be checked.
	assert.Contains(t, mock.editedBody, "- [x] #7")
	// #8 should still be unchecked.
	assert.Contains(t, mock.editedBody, "- [ ] #8")
	// #1 and #3 should remain checked.
	assert.Contains(t, mock.editedBody, "- [x] #1")
	assert.Contains(t, mock.editedBody, "- [x] #3")
	// Verify the child issue was closed.
	assert.True(t, mock.closedChild)
}
// --- TickParent integration: epic progress tracking ---
// TestIntegration_TickParent_Good_TrackEpicProgress ticks one child in a
// four-task epic and asserts the checked/unchecked counts afterwards
// (2 checked, 2 unchecked).
func TestIntegration_TickParent_Good_TrackEpicProgress(t *testing.T) {
	// Start with 4 tasks, 1 checked.
	epicBody := "## Tasks\n- [x] #1\n- [ ] #2\n- [ ] #3\n- [ ] #4\n"
	mock := newForgejoMock(t, epicBody)
	defer mock.close()
	h := NewTickParentHandler(mock.client(t))
	// Tick child #2.
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  10,
		ChildNumber: 2,
		PRNumber:    20,
		RepoOwner:   "org",
		RepoName:    "repo",
		PRState:     "MERGED",
	}
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success)
	// Verify #2 is now checked.
	assert.Contains(t, mock.editedBody, "- [x] #2")
	// #3 and #4 should still be unchecked.
	assert.Contains(t, mock.editedBody, "- [ ] #3")
	assert.Contains(t, mock.editedBody, "- [ ] #4")
	// Count progress: 2 out of 4 now checked.
	checked := strings.Count(mock.editedBody, "- [x]")
	unchecked := strings.Count(mock.editedBody, "- [ ]")
	assert.Equal(t, 2, checked)
	assert.Equal(t, 2, unchecked)
}
// --- EnableAutoMerge integration: full flow ---
// TestIntegration_EnableAutoMerge_Good_FullFlow: a clean, open, mergeable
// PR matches enable_auto_merge; Execute POSTs to the merge endpoint with
// Do="squash", which is captured and asserted.
func TestIntegration_EnableAutoMerge_Good_FullFlow(t *testing.T) {
	var mergeMethod string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/merge") {
			bodyBytes, _ := io.ReadAll(r.Body)
			var body map[string]any
			_ = json.Unmarshal(bodyBytes, &body)
			if do, ok := body["Do"].(string); ok {
				mergeMethod = do
			}
			w.WriteHeader(http.StatusOK)
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewEnableAutoMergeHandler(client)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 5,
		PRNumber:    42,
		RepoOwner:   "host-uk",
		RepoName:    "core-tenant",
		PRState:     "OPEN",
		IsDraft:     false,
		Mergeable:   "MERGEABLE",
		CheckStatus: "SUCCESS",
	}
	// Verify match.
	assert.True(t, h.Match(signal))
	// Execute.
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "enable_auto_merge", result.Action)
	assert.Equal(t, "host-uk", result.RepoOwner)
	assert.Equal(t, "core-tenant", result.RepoName)
	assert.Equal(t, 42, result.PRNumber)
	assert.Equal(t, "squash", mergeMethod)
}
// --- PublishDraft integration: full flow ---
// TestIntegration_PublishDraft_Good_FullFlow: a green draft PR matches
// publish_draft; Execute PATCHes the PR with "draft":false, which the
// mock detects and the test asserts.
func TestIntegration_PublishDraft_Good_FullFlow(t *testing.T) {
	var patchedDraft bool
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/pulls/") {
			bodyBytes, _ := io.ReadAll(r.Body)
			if strings.Contains(string(bodyBytes), `"draft":false`) {
				patchedDraft = true
			}
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(`{}`))
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewPublishDraftHandler(client)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  3,
		ChildNumber: 8,
		PRNumber:    15,
		RepoOwner:   "org",
		RepoName:    "repo",
		PRState:     "OPEN",
		IsDraft:     true,
		CheckStatus: "SUCCESS",
		Mergeable:   "MERGEABLE",
	}
	// Verify match.
	assert.True(t, h.Match(signal))
	// Execute.
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "publish_draft", result.Action)
	assert.True(t, patchedDraft)
}
// --- SendFixCommand integration: conflict message ---
// TestIntegration_SendFixCommand_Good_ConflictFlow: a CONFLICTING PR
// matches send_fix_command; the posted comment body contains
// "fix the merge conflict".
func TestIntegration_SendFixCommand_Good_ConflictFlow(t *testing.T) {
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments") {
			bodyBytes, _ := io.ReadAll(r.Body)
			var body map[string]string
			_ = json.Unmarshal(bodyBytes, &body)
			commentBody = body["body"]
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewSendFixCommandHandler(client)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 3,
		PRNumber:    10,
		RepoOwner:   "org",
		RepoName:    "repo",
		PRState:     "OPEN",
		Mergeable:   "CONFLICTING",
		CheckStatus: "SUCCESS",
	}
	assert.True(t, h.Match(signal))
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "send_fix_command", result.Action)
	assert.Contains(t, commentBody, "fix the merge conflict")
}
// --- SendFixCommand integration: code review message ---
// TestIntegration_SendFixCommand_Good_ReviewFlow: failing checks plus
// unresolved threads match send_fix_command; the posted comment body
// contains "fix the code reviews".
func TestIntegration_SendFixCommand_Good_ReviewFlow(t *testing.T) {
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments") {
			bodyBytes, _ := io.ReadAll(r.Body)
			var body map[string]string
			_ = json.Unmarshal(bodyBytes, &body)
			commentBody = body["body"]
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
			return
		}
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewSendFixCommandHandler(client)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:      1,
		ChildNumber:     3,
		PRNumber:        10,
		RepoOwner:       "org",
		RepoName:        "repo",
		PRState:         "OPEN",
		Mergeable:       "MERGEABLE",
		CheckStatus:     "FAILURE",
		ThreadsTotal:    3,
		ThreadsResolved: 1,
	}
	assert.True(t, h.Match(signal))
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Contains(t, commentBody, "fix the code reviews")
}
// --- Completion integration: success flow ---
// TestIntegration_Completion_Good_SuccessFlow: a successful
// agent_completion signal removes the in-progress label, adds
// agent-completed, and posts a comment echoing the agent's message.
// The mock tracks each of those API calls for assertion.
func TestIntegration_Completion_Good_SuccessFlow(t *testing.T) {
	var labelAdded bool
	var labelRemoved bool
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// GetLabelByName — GET repo labels.
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/core/go-scm/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
			})
		// RemoveIssueLabel.
		case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"):
			labelRemoved = true
			w.WriteHeader(http.StatusNoContent)
		// EnsureLabel — POST to create repo label.
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"})
		// AddIssueLabels — POST to issue labels.
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/issues/12/labels":
			labelAdded = true
			_ = json.NewEncoder(w).Encode([]map[string]any{{"id": 2, "name": "agent-completed"}})
		// CreateIssueComment.
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments"):
			bodyBytes, _ := io.ReadAll(r.Body)
			var body map[string]string
			_ = json.Unmarshal(bodyBytes, &body)
			commentBody = body["body"]
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	signal := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		EpicNumber:  5,
		ChildNumber: 12,
		RepoOwner:   "core",
		RepoName:    "go-scm",
		Success:     true,
		Message:     "PR created and tests passing",
	}
	assert.True(t, h.Match(signal))
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "completion", result.Action)
	assert.Equal(t, "core", result.RepoOwner)
	assert.Equal(t, "go-scm", result.RepoName)
	assert.Equal(t, 5, result.EpicNumber)
	assert.Equal(t, 12, result.ChildNumber)
	assert.True(t, labelRemoved, "in-progress label should be removed")
	assert.True(t, labelAdded, "agent-completed label should be added")
	assert.Contains(t, commentBody, "PR created and tests passing")
}
// --- Full pipeline integration: signal -> match -> execute -> journal ---
// TestIntegration_FullPipeline_Good_TickParentWithJournal runs the whole
// pipeline end to end: match, execute against the mock, append the
// result to a jobrunner.Journal, then decode the JSONL entry from the
// dated per-repo file and assert its fields — finally re-checking the
// epic body and child closure on the mock.
func TestIntegration_FullPipeline_Good_TickParentWithJournal(t *testing.T) {
	epicBody := "## Tasks\n- [ ] #7\n- [ ] #8\n"
	mock := newForgejoMock(t, epicBody)
	defer mock.close()
	dir := t.TempDir()
	journal, err := jobrunner.NewJournal(dir)
	require.NoError(t, err)
	client := mock.client(t)
	h := NewTickParentHandler(client)
	signal := &jobrunner.PipelineSignal{
		EpicNumber:  10,
		ChildNumber: 7,
		PRNumber:    55,
		RepoOwner:   "host-uk",
		RepoName:    "core-tenant",
		PRState:     "MERGED",
		CheckStatus: "SUCCESS",
		Mergeable:   "UNKNOWN",
	}
	// Verify match.
	assert.True(t, h.Match(signal))
	// Execute.
	start := time.Now()
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success)
	// Write to journal (simulating what the poller does).
	result.EpicNumber = signal.EpicNumber
	result.ChildNumber = signal.ChildNumber
	result.Cycle = 1
	result.Duration = time.Since(start)
	err = journal.Append(signal, result)
	require.NoError(t, err)
	// Verify the journal file exists and contains the entry.
	date := time.Now().UTC().Format("2006-01-02")
	journalPath := filepath.Join(dir, "host-uk", "core-tenant", date+".jsonl")
	_, statErr := os.Stat(journalPath)
	require.NoError(t, statErr)
	f, err := os.Open(journalPath)
	require.NoError(t, err)
	defer func() { _ = f.Close() }()
	var entry jobrunner.JournalEntry
	err = json.NewDecoder(f).Decode(&entry)
	require.NoError(t, err)
	assert.Equal(t, "tick_parent", entry.Action)
	assert.Equal(t, "host-uk/core-tenant", entry.Repo)
	assert.Equal(t, 10, entry.Epic)
	assert.Equal(t, 7, entry.Child)
	assert.Equal(t, 55, entry.PR)
	assert.Equal(t, 1, entry.Cycle)
	assert.True(t, entry.Result.Success)
	assert.Equal(t, "MERGED", entry.Signals.PRState)
	// Verify the epic was properly updated.
	assert.Contains(t, mock.editedBody, "- [x] #7")
	assert.Contains(t, mock.editedBody, "- [ ] #8")
	assert.True(t, mock.closedChild)
}
// --- Handler matching priority: first match wins ---
// TestIntegration_HandlerPriority_Good_FirstMatchWins exercises the
// poller's first-match-wins selection: an open, non-draft, mergeable PR
// with green checks and no threads is claimed by enable_auto_merge and
// rejected by publish_draft and send_fix_command.
func TestIntegration_HandlerPriority_Good_FirstMatchWins(t *testing.T) {
	signal := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		IsDraft:         false,
		Mergeable:       "MERGEABLE",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    0,
		ThreadsResolved: 0,
	}
	// Only the auto-merge handler should claim this signal.
	assert.True(t, NewEnableAutoMergeHandler(nil).Match(signal))
	// Not a draft, so publish_draft stays out.
	assert.False(t, NewPublishDraftHandler(nil).Match(signal))
	// Mergeable and passing, so there is nothing to fix.
	assert.False(t, NewSendFixCommandHandler(nil).Match(signal))
}
// --- Handler matching: draft PR path ---
// TestIntegration_HandlerPriority_Good_DraftPRPath: a green, open draft
// is claimed by publish_draft and by no other PR-state handler.
func TestIntegration_HandlerPriority_Good_DraftPRPath(t *testing.T) {
	signal := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		IsDraft:         true,
		Mergeable:       "MERGEABLE",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    0,
		ThreadsResolved: 0,
	}
	// Draft status excludes auto-merge.
	assert.False(t, NewEnableAutoMergeHandler(nil).Match(signal))
	// Draft + open + green checks is exactly publish_draft's case.
	assert.True(t, NewPublishDraftHandler(nil).Match(signal))
	// Nothing is broken, so no fix command.
	assert.False(t, NewSendFixCommandHandler(nil).Match(signal))
}
// --- Handler matching: merged PR only matches tick_parent ---
// TestIntegration_HandlerPriority_Good_MergedPRPath: a MERGED PR is
// claimed only by tick_parent; every other handler declines.
func TestIntegration_HandlerPriority_Good_MergedPRPath(t *testing.T) {
	signal := &jobrunner.PipelineSignal{
		PRState:         "MERGED",
		IsDraft:         false,
		Mergeable:       "UNKNOWN",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    0,
		ThreadsResolved: 0,
	}
	assert.False(t, NewEnableAutoMergeHandler(nil).Match(signal))
	assert.False(t, NewPublishDraftHandler(nil).Match(signal))
	assert.False(t, NewSendFixCommandHandler(nil).Match(signal))
	assert.True(t, NewTickParentHandler(nil).Match(signal))
}
// --- Handler matching: conflicting PR matches send_fix_command ---
// TestIntegration_HandlerPriority_Good_ConflictingPRPath: a CONFLICTING
// open PR is claimed by send_fix_command, never by enable_auto_merge.
func TestIntegration_HandlerPriority_Good_ConflictingPRPath(t *testing.T) {
	signal := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		IsDraft:         false,
		Mergeable:       "CONFLICTING",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    0,
		ThreadsResolved: 0,
	}
	// Conflicts disqualify auto-merge.
	assert.False(t, NewEnableAutoMergeHandler(nil).Match(signal))
	// Conflicts are exactly what send_fix_command exists for.
	assert.True(t, NewSendFixCommandHandler(nil).Match(signal))
}
// --- Completion integration: failure flow ---
// TestIntegration_Completion_Good_FailureFlow: a failed agent_completion
// signal still resolves successfully from the handler's perspective; the
// posted comment contains "Agent reported failure" plus the agent's
// error text, and the mock serves the agent-failed label flow.
func TestIntegration_Completion_Good_FailureFlow(t *testing.T) {
	var commentBody string
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// GetLabelByName — GET repo labels.
		case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/core/go-scm/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{
				{"id": 1, "name": "in-progress", "color": "#1d76db"},
			})
		// RemoveIssueLabel.
		case r.Method == http.MethodDelete:
			w.WriteHeader(http.StatusNoContent)
		// EnsureLabel — POST to create repo label.
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/labels":
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"})
		// AddIssueLabels — POST to issue labels.
		case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/issues/12/labels":
			_ = json.NewEncoder(w).Encode([]map[string]any{{"id": 3, "name": "agent-failed"}})
		// CreateIssueComment.
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments"):
			bodyBytes, _ := io.ReadAll(r.Body)
			var body map[string]string
			_ = json.Unmarshal(bodyBytes, &body)
			commentBody = body["body"]
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewCompletionHandler(client)
	signal := &jobrunner.PipelineSignal{
		Type:        "agent_completion",
		EpicNumber:  5,
		ChildNumber: 12,
		RepoOwner:   "core",
		RepoName:    "go-scm",
		Success:     false,
		Error:       "tests failed: 3 assertions",
	}
	result, err := h.Execute(context.Background(), signal)
	require.NoError(t, err)
	assert.True(t, result.Success) // The handler itself succeeded.
	assert.Contains(t, commentBody, "Agent reported failure")
	assert.Contains(t, commentBody, "tests failed: 3 assertions")
}
// --- Multiple handlers execute in sequence for different signals ---
//
// Verifies that two independent signals are dispatched to the correct
// handlers (auto-merge for a clean PR, fix-command for a conflicting one)
// and that each handler performs the expected API call.
func TestIntegration_MultipleHandlers_Good_DifferentSignals(t *testing.T) {
	// Recorded side effects of the stub server.
	var commentBodies []string
	var mergedPRs []int64
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch {
		// Merge endpoint: record which PR numbers were merged.
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/merge"):
			// Extract PR number from path.
			parts := strings.Split(r.URL.Path, "/")
			for i, p := range parts {
				if p == "pulls" && i+1 < len(parts) {
					var prNum int64
					// The path segment is a bare integer, which is also
					// valid JSON, so Unmarshal parses it into int64.
					_ = json.Unmarshal([]byte(parts[i+1]), &prNum)
					mergedPRs = append(mergedPRs, prNum)
				}
			}
			w.WriteHeader(http.StatusOK)
		// Comment endpoint: record every posted comment body.
		case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments"):
			bodyBytes, _ := io.ReadAll(r.Body)
			var body map[string]string
			_ = json.Unmarshal(bodyBytes, &body)
			commentBodies = append(commentBodies, body["body"])
			w.WriteHeader(http.StatusCreated)
			_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
		// Issue fetch: return a minimal epic with two unchecked tasks.
		case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/"):
			_ = json.NewEncoder(w).Encode(map[string]any{
				"number": 42,
				"body":   "## Tasks\n- [ ] #7\n- [ ] #8\n",
				"title":  "Epic",
			})
		// Issue edit.
		case r.Method == http.MethodPatch:
			_ = json.NewEncoder(w).Encode(map[string]any{"number": 1, "body": "", "state": "open"})
		default:
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]any{})
		}
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	autoMergeHandler := NewEnableAutoMergeHandler(client)
	fixCommandHandler := NewSendFixCommandHandler(client)
	// Signal 1: should trigger auto merge.
	sig1 := &jobrunner.PipelineSignal{
		PRState: "OPEN", IsDraft: false, Mergeable: "MERGEABLE",
		CheckStatus: "SUCCESS", PRNumber: 10,
		RepoOwner: "org", RepoName: "repo",
	}
	// Signal 2: should trigger fix command.
	sig2 := &jobrunner.PipelineSignal{
		PRState: "OPEN", Mergeable: "CONFLICTING",
		CheckStatus: "SUCCESS", PRNumber: 20,
		RepoOwner: "org", RepoName: "repo",
	}
	// Each signal must match exactly one of the two handlers.
	assert.True(t, autoMergeHandler.Match(sig1))
	assert.False(t, autoMergeHandler.Match(sig2))
	assert.False(t, fixCommandHandler.Match(sig1))
	assert.True(t, fixCommandHandler.Match(sig2))
	// Execute both.
	result1, err := autoMergeHandler.Execute(context.Background(), sig1)
	require.NoError(t, err)
	assert.True(t, result1.Success)
	result2, err := fixCommandHandler.Execute(context.Background(), sig2)
	require.NoError(t, err)
	assert.True(t, result2.Success)
	// Verify correct comment was posted for the conflicting PR.
	require.Len(t, commentBodies, 1)
	assert.Contains(t, commentBodies[0], "fix the merge conflict")
}

View file

@ -1,55 +0,0 @@
package handlers
import (
"context"
"fmt"
"time"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// PublishDraftHandler marks a draft PR as ready for review once its checks pass.
type PublishDraftHandler struct {
	forge *forge.Client
}

// NewPublishDraftHandler creates a handler that publishes draft PRs.
func NewPublishDraftHandler(f *forge.Client) *PublishDraftHandler {
	return &PublishDraftHandler{forge: f}
}

// Name returns the handler identifier.
func (h *PublishDraftHandler) Name() string {
	return "publish_draft"
}

// Match returns true when the PR is a draft, open, and all checks have passed.
func (h *PublishDraftHandler) Match(signal *jobrunner.PipelineSignal) bool {
	if !signal.IsDraft {
		return false
	}
	if signal.PRState != "OPEN" {
		return false
	}
	return signal.CheckStatus == "SUCCESS"
}

// Execute marks the PR as no longer a draft. The API error, if any, is
// recorded on the ActionResult rather than returned.
func (h *PublishDraftHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
	began := time.Now()
	setErr := h.forge.SetPRDraft(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber), false)

	res := &jobrunner.ActionResult{
		Action:    "publish_draft",
		RepoOwner: signal.RepoOwner,
		RepoName:  signal.RepoName,
		PRNumber:  signal.PRNumber,
		Success:   setErr == nil,
		Timestamp: time.Now(),
		Duration:  time.Since(began),
	}
	if setErr != nil {
		res.Error = fmt.Sprintf("publish draft failed: %v", setErr)
	}
	return res, nil
}

View file

@ -1,84 +0,0 @@
package handlers
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// TestPublishDraft_Match_Good: an open draft PR with green checks matches.
func TestPublishDraft_Match_Good(t *testing.T) {
	handler := NewPublishDraftHandler(nil)
	assert.True(t, handler.Match(&jobrunner.PipelineSignal{
		IsDraft:     true,
		PRState:     "OPEN",
		CheckStatus: "SUCCESS",
	}))
}

// TestPublishDraft_Match_Bad_NotDraft: a non-draft PR does not match.
func TestPublishDraft_Match_Bad_NotDraft(t *testing.T) {
	handler := NewPublishDraftHandler(nil)
	assert.False(t, handler.Match(&jobrunner.PipelineSignal{
		IsDraft:     false,
		PRState:     "OPEN",
		CheckStatus: "SUCCESS",
	}))
}

// TestPublishDraft_Match_Bad_ChecksFailing: failing checks block publishing.
func TestPublishDraft_Match_Bad_ChecksFailing(t *testing.T) {
	handler := NewPublishDraftHandler(nil)
	assert.False(t, handler.Match(&jobrunner.PipelineSignal{
		IsDraft:     true,
		PRState:     "OPEN",
		CheckStatus: "FAILURE",
	}))
}

// TestPublishDraft_Execute_Good exercises the full PATCH round-trip against a
// stub server and checks both the outgoing request and the reported result.
func TestPublishDraft_Execute_Good(t *testing.T) {
	var (
		gotMethod string
		gotPath   string
		gotBody   string
	)
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotMethod = r.Method
		gotPath = r.URL.Path
		raw, _ := io.ReadAll(r.Body)
		gotBody = string(raw)
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{}`))
	})))
	defer srv.Close()

	handler := NewPublishDraftHandler(newTestForgeClient(t, srv.URL))
	result, err := handler.Execute(context.Background(), &jobrunner.PipelineSignal{
		RepoOwner: "host-uk",
		RepoName:  "core-php",
		PRNumber:  42,
		IsDraft:   true,
		PRState:   "OPEN",
	})
	require.NoError(t, err)

	assert.Equal(t, http.MethodPatch, gotMethod)
	assert.Equal(t, "/api/v1/repos/host-uk/core-php/pulls/42", gotPath)
	assert.Contains(t, gotBody, `"draft":false`)
	assert.True(t, result.Success)
	assert.Equal(t, "publish_draft", result.Action)
	assert.Equal(t, "host-uk", result.RepoOwner)
	assert.Equal(t, "core-php", result.RepoName)
	assert.Equal(t, 42, result.PRNumber)
}

View file

@ -1,80 +0,0 @@
package handlers
import (
"context"
"fmt"
"time"
forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// DismissReviewsHandler dismisses stale "request changes" reviews on a PR.
// This replaces the GitHub-only ResolveThreadsHandler because Forgejo does
// not have a thread resolution API.
type DismissReviewsHandler struct {
	forge *forge.Client
}

// NewDismissReviewsHandler creates a handler that dismisses stale reviews.
func NewDismissReviewsHandler(f *forge.Client) *DismissReviewsHandler {
	return &DismissReviewsHandler{forge: f}
}

// Name returns the handler identifier.
func (h *DismissReviewsHandler) Name() string {
	return "dismiss_reviews"
}

// Match returns true when the PR is open and has unresolved review threads.
func (h *DismissReviewsHandler) Match(signal *jobrunner.PipelineSignal) bool {
	if signal.PRState != "OPEN" {
		return false
	}
	return signal.HasUnresolvedThreads()
}

// Execute dismisses stale "request changes" reviews on the PR. Per-review
// dismissal failures are collected; the first one is reported on the result.
func (h *DismissReviewsHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
	began := time.Now()
	reviews, listErr := h.forge.ListPRReviews(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber))
	if listErr != nil {
		return nil, coreerr.E("dismiss_reviews.Execute", "list reviews", listErr)
	}

	var failures []string
	dismissedCount := 0
	for _, rv := range reviews {
		// Only stale, not-yet-dismissed "request changes" reviews are eligible.
		eligible := rv.State == forgejosdk.ReviewStateRequestChanges && !rv.Dismissed && rv.Stale
		if !eligible {
			continue
		}
		dismissErr := h.forge.DismissReview(
			signal.RepoOwner, signal.RepoName,
			int64(signal.PRNumber), rv.ID,
			"Automatically dismissed: review is stale after new commits",
		)
		if dismissErr != nil {
			failures = append(failures, dismissErr.Error())
			continue
		}
		dismissedCount++
	}

	result := &jobrunner.ActionResult{
		Action:    "dismiss_reviews",
		RepoOwner: signal.RepoOwner,
		RepoName:  signal.RepoName,
		PRNumber:  signal.PRNumber,
		Success:   len(failures) == 0,
		Timestamp: time.Now(),
		Duration:  time.Since(began),
	}
	if len(failures) > 0 {
		result.Error = fmt.Sprintf("failed to dismiss %d review(s): %s",
			len(failures), failures[0])
	}
	return result, nil
}

View file

@ -1,91 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// TestDismissReviews_Match_Good: an open PR with unresolved threads matches.
func TestDismissReviews_Match_Good(t *testing.T) {
	h := NewDismissReviewsHandler(nil)
	sig := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		ThreadsTotal:    4,
		ThreadsResolved: 2,
	}
	assert.True(t, h.Match(sig))
}

// TestDismissReviews_Match_Bad_AllResolved: fully resolved threads do not match.
func TestDismissReviews_Match_Bad_AllResolved(t *testing.T) {
	h := NewDismissReviewsHandler(nil)
	sig := &jobrunner.PipelineSignal{
		PRState:         "OPEN",
		ThreadsTotal:    3,
		ThreadsResolved: 3,
	}
	assert.False(t, h.Match(sig))
}

// TestDismissReviews_Execute_Good serves three reviews (two stale
// REQUEST_CHANGES, one APPROVED) and verifies only the stale ones are
// dismissed, by counting the total HTTP requests made.
func TestDismissReviews_Execute_Good(t *testing.T) {
	// callCount tracks every request to the stub: 1 list + N dismissals.
	callCount := 0
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		callCount++
		w.Header().Set("Content-Type", "application/json")
		// ListPullReviews (GET)
		if r.Method == http.MethodGet {
			reviews := []map[string]any{
				{
					"id": 1, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true,
					"body": "fix this", "commit_id": "abc123",
				},
				{
					"id": 2, "state": "APPROVED", "dismissed": false, "stale": false,
					"body": "looks good", "commit_id": "abc123",
				},
				{
					"id": 3, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true,
					"body": "needs work", "commit_id": "abc123",
				},
			}
			_ = json.NewEncoder(w).Encode(reviews)
			return
		}
		// DismissPullReview (POST to dismissals endpoint)
		w.WriteHeader(http.StatusOK)
	})))
	defer srv.Close()
	client := newTestForgeClient(t, srv.URL)
	h := NewDismissReviewsHandler(client)
	sig := &jobrunner.PipelineSignal{
		RepoOwner:       "host-uk",
		RepoName:        "core-admin",
		PRNumber:        33,
		PRState:         "OPEN",
		ThreadsTotal:    3,
		ThreadsResolved: 1,
	}
	result, err := h.Execute(context.Background(), sig)
	require.NoError(t, err)
	assert.True(t, result.Success)
	assert.Equal(t, "dismiss_reviews", result.Action)
	assert.Equal(t, "host-uk", result.RepoOwner)
	assert.Equal(t, "core-admin", result.RepoName)
	assert.Equal(t, 33, result.PRNumber)
	// 1 list + 2 dismiss (reviews #1 and #3 are stale REQUEST_CHANGES)
	assert.Equal(t, 3, callCount)
}

View file

@ -1,74 +0,0 @@
package handlers
import (
"context"
"fmt"
"time"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// SendFixCommandHandler posts a comment on a PR asking for conflict or
// review fixes.
type SendFixCommandHandler struct {
	forge *forge.Client
}

// NewSendFixCommandHandler creates a handler that posts fix commands.
func NewSendFixCommandHandler(f *forge.Client) *SendFixCommandHandler {
	return &SendFixCommandHandler{forge: f}
}

// Name returns the handler identifier.
func (h *SendFixCommandHandler) Name() string {
	return "send_fix_command"
}

// Match returns true when the PR is open and either has merge conflicts or
// has unresolved threads with failing checks.
func (h *SendFixCommandHandler) Match(signal *jobrunner.PipelineSignal) bool {
	switch {
	case signal.PRState != "OPEN":
		return false
	case signal.Mergeable == "CONFLICTING":
		return true
	case signal.HasUnresolvedThreads() && signal.CheckStatus == "FAILURE":
		return true
	default:
		return false
	}
}

// Execute posts a comment on the PR asking for a fix. The message depends on
// whether the PR is conflicting or has outstanding review feedback.
func (h *SendFixCommandHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
	began := time.Now()

	message := "Can you fix the code reviews?"
	if signal.Mergeable == "CONFLICTING" {
		message = "Can you fix the merge conflict?"
	}

	commentErr := h.forge.CreateIssueComment(
		signal.RepoOwner, signal.RepoName,
		int64(signal.PRNumber), message,
	)
	res := &jobrunner.ActionResult{
		Action:    "send_fix_command",
		RepoOwner: signal.RepoOwner,
		RepoName:  signal.RepoName,
		PRNumber:  signal.PRNumber,
		Success:   commentErr == nil,
		Timestamp: time.Now(),
		Duration:  time.Since(began),
	}
	if commentErr != nil {
		res.Error = fmt.Sprintf("post comment failed: %v", commentErr)
	}
	return res, nil
}

View file

@ -1,87 +0,0 @@
package handlers
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// TestSendFixCommand_Match_Good_Conflicting: a conflicting open PR matches.
func TestSendFixCommand_Match_Good_Conflicting(t *testing.T) {
	handler := NewSendFixCommandHandler(nil)
	assert.True(t, handler.Match(&jobrunner.PipelineSignal{
		PRState:   "OPEN",
		Mergeable: "CONFLICTING",
	}))
}

// TestSendFixCommand_Match_Good_UnresolvedThreads: failing checks plus
// unresolved review threads also match.
func TestSendFixCommand_Match_Good_UnresolvedThreads(t *testing.T) {
	handler := NewSendFixCommandHandler(nil)
	assert.True(t, handler.Match(&jobrunner.PipelineSignal{
		PRState:         "OPEN",
		Mergeable:       "MERGEABLE",
		CheckStatus:     "FAILURE",
		ThreadsTotal:    3,
		ThreadsResolved: 1,
	}))
}

// TestSendFixCommand_Match_Bad_Clean: a healthy PR does not match.
func TestSendFixCommand_Match_Bad_Clean(t *testing.T) {
	handler := NewSendFixCommandHandler(nil)
	assert.False(t, handler.Match(&jobrunner.PipelineSignal{
		PRState:         "OPEN",
		Mergeable:       "MERGEABLE",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    2,
		ThreadsResolved: 2,
	}))
}

// TestSendFixCommand_Execute_Good_Conflict posts the conflict-fix comment and
// verifies the exact endpoint and payload sent to the stub server.
func TestSendFixCommand_Execute_Good_Conflict(t *testing.T) {
	var (
		gotMethod string
		gotPath   string
		gotBody   string
	)
	srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotMethod = r.Method
		gotPath = r.URL.Path
		raw, _ := io.ReadAll(r.Body)
		gotBody = string(raw)
		w.WriteHeader(http.StatusCreated)
		_, _ = w.Write([]byte(`{"id":1}`))
	})))
	defer srv.Close()

	handler := NewSendFixCommandHandler(newTestForgeClient(t, srv.URL))
	result, err := handler.Execute(context.Background(), &jobrunner.PipelineSignal{
		RepoOwner: "host-uk",
		RepoName:  "core-tenant",
		PRNumber:  17,
		PRState:   "OPEN",
		Mergeable: "CONFLICTING",
	})
	require.NoError(t, err)

	assert.Equal(t, http.MethodPost, gotMethod)
	assert.Equal(t, "/api/v1/repos/host-uk/core-tenant/issues/17/comments", gotPath)
	assert.Contains(t, gotBody, "fix the merge conflict")
	assert.True(t, result.Success)
	assert.Equal(t, "send_fix_command", result.Action)
	assert.Equal(t, "host-uk", result.RepoOwner)
	assert.Equal(t, "core-tenant", result.RepoName)
	assert.Equal(t, 17, result.PRNumber)
}

View file

@ -1,35 +0,0 @@
package handlers
import (
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/go-scm/forge"
)
// forgejoVersionResponse is the JSON response for /api/v1/version.
const forgejoVersionResponse = `{"version":"9.0.0"}`
// withVersion wraps an HTTP handler to also serve the Forgejo version endpoint
// that the SDK calls during NewClient initialization.
func withVersion(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/version") {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(forgejoVersionResponse))
return
}
next.ServeHTTP(w, r)
})
}
// newTestForgeClient creates a forge.Client pointing at the given test server
// URL, failing the test immediately if construction errors.
func newTestForgeClient(t *testing.T, url string) *forge.Client {
	t.Helper()
	c, newErr := forge.New(url, "test-token")
	require.NoError(t, newErr)
	return c
}

View file

@ -1,101 +0,0 @@
package handlers
import (
"context"
"fmt"
"strings"
"time"
forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2"
coreerr "forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-scm/forge"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
// TickParentHandler ticks a child checkbox in the parent epic issue body
// after the child's PR has been merged.
type TickParentHandler struct {
forge *forge.Client
}
// NewTickParentHandler creates a handler that ticks parent epic checkboxes.
func NewTickParentHandler(f *forge.Client) *TickParentHandler {
return &TickParentHandler{forge: f}
}
// Name returns the handler identifier.
func (h *TickParentHandler) Name() string {
return "tick_parent"
}
// Match returns true when the child PR has been merged.
func (h *TickParentHandler) Match(signal *jobrunner.PipelineSignal) bool {
return signal.PRState == "MERGED"
}
// Execute fetches the epic body, replaces the unchecked checkbox for the
// child issue with a checked one, updates the epic, and closes the child issue.
func (h *TickParentHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
start := time.Now()
// Fetch the epic issue body.
epic, err := h.forge.GetIssue(signal.RepoOwner, signal.RepoName, int64(signal.EpicNumber))
if err != nil {
return nil, coreerr.E("tick_parent.Execute", "fetch epic", err)
}
oldBody := epic.Body
unchecked := fmt.Sprintf("- [ ] #%d", signal.ChildNumber)
checked := fmt.Sprintf("- [x] #%d", signal.ChildNumber)
if !strings.Contains(oldBody, unchecked) {
// Already ticked or not found -- nothing to do.
return &jobrunner.ActionResult{
Action: "tick_parent",
RepoOwner: signal.RepoOwner,
RepoName: signal.RepoName,
PRNumber: signal.PRNumber,
Success: true,
Timestamp: time.Now(),
Duration: time.Since(start),
}, nil
}
newBody := strings.Replace(oldBody, unchecked, checked, 1)
// Update the epic body.
_, err = h.forge.EditIssue(signal.RepoOwner, signal.RepoName, int64(signal.EpicNumber), forgejosdk.EditIssueOption{
Body: &newBody,
})
if err != nil {
return &jobrunner.ActionResult{
Action: "tick_parent",
RepoOwner: signal.RepoOwner,
RepoName: signal.RepoName,
PRNumber: signal.PRNumber,
Error: fmt.Sprintf("edit epic failed: %v", err),
Timestamp: time.Now(),
Duration: time.Since(start),
}, nil
}
// Close the child issue.
err = h.forge.CloseIssue(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber))
result := &jobrunner.ActionResult{
Action: "tick_parent",
RepoOwner: signal.RepoOwner,
RepoName: signal.RepoName,
PRNumber: signal.PRNumber,
Success: err == nil,
Timestamp: time.Now(),
Duration: time.Since(start),
}
if err != nil {
result.Error = fmt.Sprintf("close child issue failed: %v", err)
}
return result, nil
}

View file

@ -1,98 +0,0 @@
package handlers
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"forge.lthn.ai/core/agent/pkg/jobrunner"
)
func TestTickParent_Match_Good(t *testing.T) {
h := NewTickParentHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "MERGED",
}
assert.True(t, h.Match(sig))
}
func TestTickParent_Match_Bad_Open(t *testing.T) {
h := NewTickParentHandler(nil)
sig := &jobrunner.PipelineSignal{
PRState: "OPEN",
}
assert.False(t, h.Match(sig))
}
func TestTickParent_Execute_Good(t *testing.T) {
epicBody := "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n"
var editBody string
var closeCalled bool
srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
method := r.Method
w.Header().Set("Content-Type", "application/json")
switch {
// GET issue (fetch epic)
case method == http.MethodGet && strings.Contains(path, "/issues/42"):
_ = json.NewEncoder(w).Encode(map[string]any{
"number": 42,
"body": epicBody,
"title": "Epic",
})
// PATCH issue (edit epic body)
case method == http.MethodPatch && strings.Contains(path, "/issues/42"):
b, _ := io.ReadAll(r.Body)
editBody = string(b)
_ = json.NewEncoder(w).Encode(map[string]any{
"number": 42,
"body": editBody,
"title": "Epic",
})
// PATCH issue (close child — state: closed)
case method == http.MethodPatch && strings.Contains(path, "/issues/7"):
closeCalled = true
_ = json.NewEncoder(w).Encode(map[string]any{
"number": 7,
"state": "closed",
})
default:
w.WriteHeader(http.StatusNotFound)
}
})))
defer srv.Close()
client := newTestForgeClient(t, srv.URL)
h := NewTickParentHandler(client)
sig := &jobrunner.PipelineSignal{
RepoOwner: "host-uk",
RepoName: "core-php",
EpicNumber: 42,
ChildNumber: 7,
PRNumber: 99,
PRState: "MERGED",
}
result, err := h.Execute(context.Background(), sig)
require.NoError(t, err)
assert.True(t, result.Success)
assert.Equal(t, "tick_parent", result.Action)
// Verify the edit body contains the checked checkbox.
assert.Contains(t, editBody, "- [x] #7")
assert.True(t, closeCalled, "expected child issue to be closed")
}

View file

@ -1,203 +0,0 @@
package jobrunner
import (
"bufio"
"encoding/json"
"iter"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
coreio "forge.lthn.ai/core/go-io"
coreerr "forge.lthn.ai/core/go-log"
)
// validPathComponent matches safe repo owner/name characters (alphanumeric, hyphen, underscore, dot).
// The first character must be alphanumeric, so components cannot begin with "." or "-".
var validPathComponent = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`)

// JournalEntry is a single line in the JSONL audit log.
type JournalEntry struct {
	Timestamp string         `json:"ts"`      // UTC, written as "2006-01-02T15:04:05Z" by Append
	Epic      int            `json:"epic"`    // epic issue number from the signal
	Child     int            `json:"child"`   // child issue number from the signal
	PR        int            `json:"pr"`      // pull request number from the signal
	Repo      string         `json:"repo"`    // repo identifier from signal.RepoFullName()
	Action    string         `json:"action"`  // name of the handler action that ran
	Signals   SignalSnapshot `json:"signals"` // PR state at the time of action
	Result    ResultSnapshot `json:"result"`  // outcome of the action
	Cycle     int            `json:"cycle"`   // cycle counter copied from the result
}

// SignalSnapshot captures the structural state of a PR at the time of action.
type SignalSnapshot struct {
	PRState         string `json:"pr_state"`
	IsDraft         bool   `json:"is_draft"`
	CheckStatus     string `json:"check_status"`
	Mergeable       string `json:"mergeable"`
	ThreadsTotal    int    `json:"threads_total"`
	ThreadsResolved int    `json:"threads_resolved"`
}

// ResultSnapshot captures the outcome of an action.
type ResultSnapshot struct {
	Success    bool   `json:"success"`
	Error      string `json:"error,omitempty"` // omitted on success
	DurationMs int64  `json:"duration_ms"`
}

// Journal writes ActionResult entries to date-partitioned JSONL files.
// Writes are serialized by mu, so a Journal is safe for concurrent Append.
type Journal struct {
	baseDir string     // root directory for all journal files
	mu      sync.Mutex // guards file creation and appends
}

// NewJournal creates a new Journal rooted at baseDir. baseDir must be
// non-empty; the directory tree is created lazily on first Append.
func NewJournal(baseDir string) (*Journal, error) {
	if baseDir == "" {
		return nil, coreerr.E("journal.NewJournal", "base directory is required", nil)
	}
	return &Journal{baseDir: baseDir}, nil
}
// sanitizePathComponent validates a single path component (owner or repo name)
// to prevent path traversal attacks. It rejects "..", empty strings, paths
// containing separators, and any value outside the safe character set.
// On success it returns the filepath.Clean-normalized component.
func sanitizePathComponent(name string) (string, error) {
	// Empty or whitespace-only values are invalid (covers "" as well).
	if strings.TrimSpace(name) == "" {
		return "", coreerr.E("journal.sanitizePathComponent", "invalid path component: "+name, nil)
	}
	// Any path separator indicates a directory traversal attempt.
	if strings.ContainsAny(name, `/\`) {
		return "", coreerr.E("journal.sanitizePathComponent", "path component contains directory separator: "+name, nil)
	}
	// Normalize redundant dots before validating.
	cleaned := filepath.Clean(name)
	switch cleaned {
	case ".", "..":
		return "", coreerr.E("journal.sanitizePathComponent", "invalid path component: "+name, nil)
	}
	// Everything else must match the safe character set exactly.
	if !validPathComponent.MatchString(cleaned) {
		return "", coreerr.E("journal.sanitizePathComponent", "path component contains invalid characters: "+name, nil)
	}
	return cleaned, nil
}
// ReadEntries returns an iterator over JournalEntry lines in a date-partitioned
// file. File-open and scan errors are yielded as (zero entry, error); a
// malformed JSON line is yielded as an error and scanning continues unless
// the consumer stops the iteration.
func (j *Journal) ReadEntries(path string) iter.Seq2[JournalEntry, error] {
	return func(yield func(JournalEntry, error) bool) {
		file, openErr := os.Open(path)
		if openErr != nil {
			yield(JournalEntry{}, openErr)
			return
		}
		defer func() { _ = file.Close() }()

		sc := bufio.NewScanner(file)
		for sc.Scan() {
			var rec JournalEntry
			if unmarshalErr := json.Unmarshal(sc.Bytes(), &rec); unmarshalErr != nil {
				// Report the bad line, then keep going if the consumer allows.
				if !yield(JournalEntry{}, unmarshalErr) {
					return
				}
				continue
			}
			if !yield(rec, nil) {
				return
			}
		}
		if scanErr := sc.Err(); scanErr != nil {
			yield(JournalEntry{}, scanErr)
		}
	}
}
// Append writes a journal entry for the given signal and result.
//
// The entry is serialized as one JSON line and appended to
// baseDir/<owner>/<repo>/<YYYY-MM-DD>.jsonl, where the date comes from the
// result's timestamp (UTC). Both owner and repo are sanitized and the final
// directory is verified to stay inside baseDir before any write occurs.
// Appends are serialized under j.mu.
func (j *Journal) Append(signal *PipelineSignal, result *ActionResult) error {
	if signal == nil {
		return coreerr.E("journal.Append", "signal is required", nil)
	}
	if result == nil {
		return coreerr.E("journal.Append", "result is required", nil)
	}
	// Build the flat snapshot that goes on disk.
	entry := JournalEntry{
		Timestamp: result.Timestamp.UTC().Format("2006-01-02T15:04:05Z"),
		Epic:      signal.EpicNumber,
		Child:     signal.ChildNumber,
		PR:        signal.PRNumber,
		Repo:      signal.RepoFullName(),
		Action:    result.Action,
		Signals: SignalSnapshot{
			PRState:         signal.PRState,
			IsDraft:         signal.IsDraft,
			CheckStatus:     signal.CheckStatus,
			Mergeable:       signal.Mergeable,
			ThreadsTotal:    signal.ThreadsTotal,
			ThreadsResolved: signal.ThreadsResolved,
		},
		Result: ResultSnapshot{
			Success:    result.Success,
			Error:      result.Error,
			DurationMs: result.Duration.Milliseconds(),
		},
		Cycle: result.Cycle,
	}
	data, err := json.Marshal(entry)
	if err != nil {
		return coreerr.E("journal.Append", "marshal entry", err)
	}
	// JSONL: one record per line.
	data = append(data, '\n')
	// Sanitize path components to prevent path traversal (CVE: issue #46).
	owner, err := sanitizePathComponent(signal.RepoOwner)
	if err != nil {
		return coreerr.E("journal.Append", "invalid repo owner", err)
	}
	repo, err := sanitizePathComponent(signal.RepoName)
	if err != nil {
		return coreerr.E("journal.Append", "invalid repo name", err)
	}
	date := result.Timestamp.UTC().Format("2006-01-02")
	dir := filepath.Join(j.baseDir, owner, repo)
	// Resolve to absolute path and verify it stays within baseDir.
	// This is a belt-and-braces check on top of sanitizePathComponent.
	absBase, err := filepath.Abs(j.baseDir)
	if err != nil {
		return coreerr.E("journal.Append", "resolve base directory", err)
	}
	absDir, err := filepath.Abs(dir)
	if err != nil {
		return coreerr.E("journal.Append", "resolve journal directory", err)
	}
	if !strings.HasPrefix(absDir, absBase+string(filepath.Separator)) {
		return coreerr.E("journal.Append", "path escapes base directory: "+absDir, nil)
	}
	// Serialize directory creation and the append itself.
	j.mu.Lock()
	defer j.mu.Unlock()
	if err := coreio.Local.EnsureDir(dir); err != nil {
		return coreerr.E("journal.Append", "create directory", err)
	}
	path := filepath.Join(dir, date+".jsonl")
	// O_APPEND keeps concurrent processes from interleaving partial lines.
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
	if err != nil {
		return coreerr.E("journal.Append", "open file", err)
	}
	defer func() { _ = f.Close() }()
	_, err = f.Write(data)
	return err
}

View file

@ -1,540 +0,0 @@
package jobrunner
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// readJournalEntries reads all JSONL entries from a given file path, failing
// the test on any open, parse, or scan error.
func readJournalEntries(t *testing.T, path string) []JournalEntry {
	t.Helper()
	file, openErr := os.Open(path)
	require.NoError(t, openErr)
	defer func() { _ = file.Close() }()

	var parsed []JournalEntry
	sc := bufio.NewScanner(file)
	for sc.Scan() {
		var rec JournalEntry
		require.NoError(t, json.Unmarshal(sc.Bytes(), &rec))
		parsed = append(parsed, rec)
	}
	require.NoError(t, sc.Err())
	return parsed
}
// readAllJournalFiles reads all .jsonl files recursively under a base
// directory and returns their entries in filepath.Walk (lexical path) order.
func readAllJournalFiles(t *testing.T, baseDir string) []JournalEntry {
	t.Helper()
	var collected []JournalEntry
	walkErr := filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if filepath.Ext(path) != ".jsonl" {
			return nil
		}
		collected = append(collected, readJournalEntries(t, path)...)
		return nil
	})
	require.NoError(t, walkErr)
	return collected
}
// --- Journal replay: write multiple entries, read back, verify round-trip ---
//
// Writes five entries spanning two repos and several actions, then reads
// every .jsonl file back and checks that all fields survive the round-trip.
func TestJournal_Replay_Good_WriteAndReadBack(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	baseTime := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC)
	// Write 5 entries with different actions, times, and repos.
	entries := []struct {
		signal *PipelineSignal
		result *ActionResult
	}{
		{
			signal: &PipelineSignal{
				EpicNumber: 1, ChildNumber: 2, PRNumber: 10,
				RepoOwner: "org-a", RepoName: "repo-1",
				PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE",
			},
			result: &ActionResult{
				Action:    "enable_auto_merge",
				RepoOwner: "org-a", RepoName: "repo-1",
				Success: true, Timestamp: baseTime, Duration: 100 * time.Millisecond, Cycle: 1,
			},
		},
		{
			signal: &PipelineSignal{
				EpicNumber: 1, ChildNumber: 3, PRNumber: 11,
				RepoOwner: "org-a", RepoName: "repo-1",
				PRState: "OPEN", CheckStatus: "FAILURE", Mergeable: "CONFLICTING",
			},
			result: &ActionResult{
				Action:    "send_fix_command",
				RepoOwner: "org-a", RepoName: "repo-1",
				Success: true, Timestamp: baseTime.Add(5 * time.Minute), Duration: 50 * time.Millisecond, Cycle: 1,
			},
		},
		{
			signal: &PipelineSignal{
				EpicNumber: 5, ChildNumber: 10, PRNumber: 20,
				RepoOwner: "org-b", RepoName: "repo-2",
				PRState: "MERGED", CheckStatus: "SUCCESS", Mergeable: "UNKNOWN",
			},
			result: &ActionResult{
				Action:    "tick_parent",
				RepoOwner: "org-b", RepoName: "repo-2",
				Success: true, Timestamp: baseTime.Add(10 * time.Minute), Duration: 200 * time.Millisecond, Cycle: 2,
			},
		},
		{
			// The one failing entry: Error must survive the round-trip.
			signal: &PipelineSignal{
				EpicNumber: 5, ChildNumber: 11, PRNumber: 21,
				RepoOwner: "org-b", RepoName: "repo-2",
				PRState: "OPEN", CheckStatus: "PENDING", Mergeable: "MERGEABLE",
				IsDraft: true,
			},
			result: &ActionResult{
				Action:    "publish_draft",
				RepoOwner: "org-b", RepoName: "repo-2",
				Success: false, Error: "API error", Timestamp: baseTime.Add(15 * time.Minute),
				Duration: 300 * time.Millisecond, Cycle: 2,
			},
		},
		{
			// Entry with thread counters set, to verify the signal snapshot.
			signal: &PipelineSignal{
				EpicNumber: 1, ChildNumber: 4, PRNumber: 12,
				RepoOwner: "org-a", RepoName: "repo-1",
				PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE",
				ThreadsTotal: 3, ThreadsResolved: 1,
			},
			result: &ActionResult{
				Action:    "dismiss_reviews",
				RepoOwner: "org-a", RepoName: "repo-1",
				Success: true, Timestamp: baseTime.Add(20 * time.Minute), Duration: 150 * time.Millisecond, Cycle: 3,
			},
		},
	}
	for _, e := range entries {
		err := j.Append(e.signal, e.result)
		require.NoError(t, err)
	}
	// Read back all entries.
	all := readAllJournalFiles(t, dir)
	require.Len(t, all, 5)
	// Build a map by action for flexible lookup (filepath.Walk order is by path, not insertion).
	byAction := make(map[string][]JournalEntry)
	for _, e := range all {
		byAction[e.Action] = append(byAction[e.Action], e)
	}
	// Verify enable_auto_merge entry (org-a/repo-1).
	require.Len(t, byAction["enable_auto_merge"], 1)
	eam := byAction["enable_auto_merge"][0]
	assert.Equal(t, "org-a/repo-1", eam.Repo)
	assert.Equal(t, 1, eam.Epic)
	assert.Equal(t, 2, eam.Child)
	assert.Equal(t, 10, eam.PR)
	assert.Equal(t, 1, eam.Cycle)
	assert.True(t, eam.Result.Success)
	assert.Equal(t, int64(100), eam.Result.DurationMs)
	// Verify publish_draft (failed entry has error).
	require.Len(t, byAction["publish_draft"], 1)
	pd := byAction["publish_draft"][0]
	assert.Equal(t, "publish_draft", pd.Action)
	assert.False(t, pd.Result.Success)
	assert.Equal(t, "API error", pd.Result.Error)
	// Verify signal snapshot preserves state.
	assert.True(t, pd.Signals.IsDraft)
	assert.Equal(t, "PENDING", pd.Signals.CheckStatus)
	// Verify dismiss_reviews has thread counts preserved.
	require.Len(t, byAction["dismiss_reviews"], 1)
	dr := byAction["dismiss_reviews"][0]
	assert.Equal(t, 3, dr.Signals.ThreadsTotal)
	assert.Equal(t, 1, dr.Signals.ThreadsResolved)
}
// --- Journal replay: filter by action ---
// TestJournal_Replay_Good_FilterByAction writes five entries with mixed
// actions, reads the whole journal back, and verifies an in-memory filter
// on the Action field selects exactly the expected entries.
func TestJournal_Replay_Good_FilterByAction(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC)
	actions := []string{"enable_auto_merge", "tick_parent", "send_fix_command", "tick_parent", "publish_draft"}
	for i, action := range actions {
		signal := &PipelineSignal{
			EpicNumber: 1, ChildNumber: i + 1, PRNumber: 10 + i,
			RepoOwner: "org", RepoName: "repo",
			PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE",
		}
		result := &ActionResult{
			Action:    action,
			RepoOwner: "org", RepoName: "repo",
			Success:   true,
			Timestamp: ts.Add(time.Duration(i) * time.Minute),
			Duration:  100 * time.Millisecond,
			Cycle:     i + 1,
		}
		require.NoError(t, j.Append(signal, result))
	}
	all := readAllJournalFiles(t, dir)
	require.Len(t, all, 5)
	// Filter by action=tick_parent.
	var tickParentEntries []JournalEntry
	for _, e := range all {
		if e.Action == "tick_parent" {
			tickParentEntries = append(tickParentEntries, e)
		}
	}
	// tick_parent was written at loop indexes 1 and 3, i.e. children 2 and 4.
	assert.Len(t, tickParentEntries, 2)
	assert.Equal(t, 2, tickParentEntries[0].Child)
	assert.Equal(t, 4, tickParentEntries[1].Child)
}
// --- Journal replay: filter by repo ---
// TestJournal_Replay_Good_FilterByRepo verifies that entries are partitioned
// on disk by owner/repo directory, so reading a single repo's date file
// yields only that repo's entries.
func TestJournal_Replay_Good_FilterByRepo(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC)
	repos := []struct {
		owner string
		name  string
	}{
		{"host-uk", "core-php"},
		{"host-uk", "core-tenant"},
		{"host-uk", "core-php"},
		{"lethean", "go-scm"},
		{"host-uk", "core-tenant"},
	}
	for i, r := range repos {
		signal := &PipelineSignal{
			EpicNumber: 1, ChildNumber: i + 1, PRNumber: 10 + i,
			RepoOwner: r.owner, RepoName: r.name,
			PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE",
		}
		result := &ActionResult{
			Action:    "tick_parent",
			RepoOwner: r.owner, RepoName: r.name,
			Success:   true,
			Timestamp: ts.Add(time.Duration(i) * time.Minute),
			Duration:  50 * time.Millisecond,
			Cycle:     i + 1,
		}
		require.NoError(t, j.Append(signal, result))
	}
	// Read entries for host-uk/core-php.
	phpPath := filepath.Join(dir, "host-uk", "core-php", "2026-02-10.jsonl")
	phpEntries := readJournalEntries(t, phpPath)
	assert.Len(t, phpEntries, 2)
	for _, e := range phpEntries {
		assert.Equal(t, "host-uk/core-php", e.Repo)
	}
	// Read entries for host-uk/core-tenant.
	tenantPath := filepath.Join(dir, "host-uk", "core-tenant", "2026-02-10.jsonl")
	tenantEntries := readJournalEntries(t, tenantPath)
	assert.Len(t, tenantEntries, 2)
	for _, e := range tenantEntries {
		assert.Equal(t, "host-uk/core-tenant", e.Repo)
	}
	// Read entries for lethean/go-scm.
	scmPath := filepath.Join(dir, "lethean", "go-scm", "2026-02-10.jsonl")
	scmEntries := readJournalEntries(t, scmPath)
	assert.Len(t, scmEntries, 1)
	assert.Equal(t, "lethean/go-scm", scmEntries[0].Repo)
}
// --- Journal replay: filter by time range (date partitioning) ---
// TestJournal_Replay_Good_FilterByTimeRange writes entries across three
// days and verifies both the on-disk date partitioning (one JSONL file per
// day) and an in-memory half-open [start, end) timestamp filter.
func TestJournal_Replay_Good_FilterByTimeRange(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	// Write entries across three different days.
	dates := []time.Time{
		time.Date(2026, 2, 8, 9, 0, 0, 0, time.UTC),
		time.Date(2026, 2, 9, 10, 0, 0, 0, time.UTC),
		time.Date(2026, 2, 9, 14, 0, 0, 0, time.UTC),
		time.Date(2026, 2, 10, 8, 0, 0, 0, time.UTC),
		time.Date(2026, 2, 10, 16, 0, 0, 0, time.UTC),
	}
	for i, ts := range dates {
		signal := &PipelineSignal{
			EpicNumber: 1, ChildNumber: i + 1, PRNumber: 10 + i,
			RepoOwner: "org", RepoName: "repo",
			PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE",
		}
		result := &ActionResult{
			Action:    "merge",
			RepoOwner: "org", RepoName: "repo",
			Success:   true,
			Timestamp: ts,
			Duration:  100 * time.Millisecond,
			Cycle:     i + 1,
		}
		require.NoError(t, j.Append(signal, result))
	}
	// Verify each date file has the correct number of entries.
	day8Path := filepath.Join(dir, "org", "repo", "2026-02-08.jsonl")
	day8Entries := readJournalEntries(t, day8Path)
	assert.Len(t, day8Entries, 1)
	assert.Equal(t, "2026-02-08T09:00:00Z", day8Entries[0].Timestamp)
	day9Path := filepath.Join(dir, "org", "repo", "2026-02-09.jsonl")
	day9Entries := readJournalEntries(t, day9Path)
	assert.Len(t, day9Entries, 2)
	assert.Equal(t, "2026-02-09T10:00:00Z", day9Entries[0].Timestamp)
	assert.Equal(t, "2026-02-09T14:00:00Z", day9Entries[1].Timestamp)
	day10Path := filepath.Join(dir, "org", "repo", "2026-02-10.jsonl")
	day10Entries := readJournalEntries(t, day10Path)
	assert.Len(t, day10Entries, 2)
	// Simulate a time range query: get entries for Feb 9 only.
	// In a real system, you'd list files matching the date range.
	// Here we verify the date partitioning is correct.
	rangeStart := time.Date(2026, 2, 9, 0, 0, 0, 0, time.UTC)
	rangeEnd := time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC) // exclusive
	var filtered []JournalEntry
	all := readAllJournalFiles(t, dir)
	for _, e := range all {
		// Timestamps are serialized in RFC 3339 UTC form; re-parse for comparison.
		ts, err := time.Parse("2006-01-02T15:04:05Z", e.Timestamp)
		require.NoError(t, err)
		if !ts.Before(rangeStart) && ts.Before(rangeEnd) {
			filtered = append(filtered, e)
		}
	}
	// Children 2 and 3 were the two Feb 9 writes.
	assert.Len(t, filtered, 2)
	assert.Equal(t, 2, filtered[0].Child)
	assert.Equal(t, 3, filtered[1].Child)
}
// --- Journal replay: combined filter (action + repo + time) ---
// TestJournal_Replay_Good_CombinedFilter combines a directory-level repo
// filter (walking only org/repo-a) with an in-memory action filter, and
// verifies the intersection selects the expected children.
func TestJournal_Replay_Good_CombinedFilter(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts1 := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC)
	ts2 := time.Date(2026, 2, 10, 11, 0, 0, 0, time.UTC)
	ts3 := time.Date(2026, 2, 11, 9, 0, 0, 0, time.UTC)
	testData := []struct {
		owner  string
		name   string
		action string
		ts     time.Time
	}{
		{"org", "repo-a", "tick_parent", ts1},
		{"org", "repo-a", "enable_auto_merge", ts1},
		{"org", "repo-b", "tick_parent", ts2},
		{"org", "repo-a", "tick_parent", ts3},
		{"org", "repo-b", "send_fix_command", ts3},
	}
	for i, td := range testData {
		signal := &PipelineSignal{
			EpicNumber: 1, ChildNumber: i + 1, PRNumber: 100 + i,
			RepoOwner: td.owner, RepoName: td.name,
			PRState: "MERGED", CheckStatus: "SUCCESS", Mergeable: "UNKNOWN",
		}
		result := &ActionResult{
			Action:    td.action,
			RepoOwner: td.owner, RepoName: td.name,
			Success:   true,
			Timestamp: td.ts,
			Duration:  50 * time.Millisecond,
			Cycle:     i + 1,
		}
		require.NoError(t, j.Append(signal, result))
	}
	// Filter: action=tick_parent AND repo=org/repo-a.
	repoAPath := filepath.Join(dir, "org", "repo-a")
	var repoAEntries []JournalEntry
	// Walk only repo-a's subtree, collecting every .jsonl date file.
	err = filepath.Walk(repoAPath, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if filepath.Ext(path) == ".jsonl" {
			entries := readJournalEntries(t, path)
			repoAEntries = append(repoAEntries, entries...)
		}
		return nil
	})
	require.NoError(t, err)
	var tickParentRepoA []JournalEntry
	for _, e := range repoAEntries {
		if e.Action == "tick_parent" && e.Repo == "org/repo-a" {
			tickParentRepoA = append(tickParentRepoA, e)
		}
	}
	// repo-a tick_parent writes were rows 0 and 3, i.e. children 1 and 4.
	assert.Len(t, tickParentRepoA, 2)
	assert.Equal(t, 1, tickParentRepoA[0].Child)
	assert.Equal(t, 4, tickParentRepoA[1].Child)
}
// --- Journal replay: empty journal returns no entries ---
// TestJournal_Replay_Good_EmptyJournal verifies that reading an empty
// journal directory yields no entries.
func TestJournal_Replay_Good_EmptyJournal(t *testing.T) {
	dir := t.TempDir()
	entries := readAllJournalFiles(t, dir)
	assert.Empty(t, entries)
}
// --- Journal replay: single entry round-trip preserves all fields ---
// TestJournal_Replay_Good_FullFieldRoundTrip appends one fully-populated
// (including a failed result with an error string) entry and asserts every
// field survives the write/read round-trip.
func TestJournal_Replay_Good_FullFieldRoundTrip(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts := time.Date(2026, 2, 15, 14, 30, 45, 0, time.UTC)
	signal := &PipelineSignal{
		EpicNumber:      42,
		ChildNumber:     7,
		PRNumber:        99,
		RepoOwner:       "host-uk",
		RepoName:        "core-admin",
		PRState:         "OPEN",
		IsDraft:         true,
		Mergeable:       "CONFLICTING",
		CheckStatus:     "FAILURE",
		ThreadsTotal:    5,
		ThreadsResolved: 2,
	}
	result := &ActionResult{
		Action:    "send_fix_command",
		RepoOwner: "host-uk",
		RepoName:  "core-admin",
		Success:   false,
		Error:     "comment API returned 503",
		Timestamp: ts,
		Duration:  1500 * time.Millisecond,
		Cycle:     7,
	}
	require.NoError(t, j.Append(signal, result))
	path := filepath.Join(dir, "host-uk", "core-admin", "2026-02-15.jsonl")
	entries := readJournalEntries(t, path)
	require.Len(t, entries, 1)
	e := entries[0]
	assert.Equal(t, "2026-02-15T14:30:45Z", e.Timestamp)
	assert.Equal(t, 42, e.Epic)
	assert.Equal(t, 7, e.Child)
	assert.Equal(t, 99, e.PR)
	assert.Equal(t, "host-uk/core-admin", e.Repo)
	assert.Equal(t, "send_fix_command", e.Action)
	assert.Equal(t, 7, e.Cycle)
	// Signal snapshot.
	assert.Equal(t, "OPEN", e.Signals.PRState)
	assert.True(t, e.Signals.IsDraft)
	assert.Equal(t, "CONFLICTING", e.Signals.Mergeable)
	assert.Equal(t, "FAILURE", e.Signals.CheckStatus)
	assert.Equal(t, 5, e.Signals.ThreadsTotal)
	assert.Equal(t, 2, e.Signals.ThreadsResolved)
	// Result snapshot (failure plus duration in milliseconds).
	assert.False(t, e.Result.Success)
	assert.Equal(t, "comment API returned 503", e.Result.Error)
	assert.Equal(t, int64(1500), e.Result.DurationMs)
}
// --- Journal replay: concurrent writes produce valid JSONL ---
// TestJournal_Replay_Good_ConcurrentWrites fires 20 goroutines at Append
// targeting the same date file and verifies no lines are lost or
// interleaved — i.e. every line parses and all 20 are present.
func TestJournal_Replay_Good_ConcurrentWrites(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC)
	// Write 20 entries concurrently.
	done := make(chan struct{}, 20)
	for i := range 20 {
		go func(idx int) {
			signal := &PipelineSignal{
				EpicNumber: 1, ChildNumber: idx, PRNumber: idx,
				RepoOwner: "org", RepoName: "repo",
				PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE",
			}
			result := &ActionResult{
				Action:    "test",
				RepoOwner: "org", RepoName: "repo",
				Success:   true,
				Timestamp: ts,
				Duration:  10 * time.Millisecond,
				Cycle:     idx,
			}
			// Error intentionally dropped: line-count assertion below is the check.
			_ = j.Append(signal, result)
			done <- struct{}{}
		}(i)
	}
	// Wait for all writers before reading the file back.
	for range 20 {
		<-done
	}
	// All entries should be parseable and present.
	path := filepath.Join(dir, "org", "repo", "2026-02-10.jsonl")
	entries := readJournalEntries(t, path)
	assert.Len(t, entries, 20)
	// Each entry should have valid JSON (no corruption from concurrent writes).
	for _, e := range entries {
		assert.NotEmpty(t, e.Action)
		assert.Equal(t, "org/repo", e.Repo)
	}
}

View file

@ -1,263 +0,0 @@
package jobrunner
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestJournal_Append_Good verifies the happy path of Journal.Append: the
// entry lands in <dir>/<owner>/<repo>/<date>.jsonl, every field of the
// signal and result snapshot round-trips, and a second Append appends a
// second line rather than truncating the file.
func TestJournal_Append_Good(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts := time.Date(2026, 2, 5, 14, 30, 0, 0, time.UTC)
	signal := &PipelineSignal{
		EpicNumber:      10,
		ChildNumber:     3,
		PRNumber:        55,
		RepoOwner:       "host-uk",
		RepoName:        "core-tenant",
		PRState:         "OPEN",
		IsDraft:         false,
		Mergeable:       "MERGEABLE",
		CheckStatus:     "SUCCESS",
		ThreadsTotal:    2,
		ThreadsResolved: 1,
		LastCommitSHA:   "abc123",
		LastCommitAt:    ts,
		LastReviewAt:    ts,
	}
	result := &ActionResult{
		Action:      "merge",
		RepoOwner:   "host-uk",
		RepoName:    "core-tenant",
		EpicNumber:  10,
		ChildNumber: 3,
		PRNumber:    55,
		Success:     true,
		Timestamp:   ts,
		Duration:    1200 * time.Millisecond,
		Cycle:       1,
	}
	err = j.Append(signal, result)
	require.NoError(t, err)
	// Read the file back.
	expectedPath := filepath.Join(dir, "host-uk", "core-tenant", "2026-02-05.jsonl")
	f, err := os.Open(expectedPath)
	require.NoError(t, err)
	defer func() { _ = f.Close() }()
	scanner := bufio.NewScanner(f)
	require.True(t, scanner.Scan(), "expected at least one line in JSONL file")
	var entry JournalEntry
	err = json.Unmarshal(scanner.Bytes(), &entry)
	require.NoError(t, err)
	assert.Equal(t, "2026-02-05T14:30:00Z", entry.Timestamp)
	assert.Equal(t, 10, entry.Epic)
	assert.Equal(t, 3, entry.Child)
	assert.Equal(t, 55, entry.PR)
	assert.Equal(t, "host-uk/core-tenant", entry.Repo)
	assert.Equal(t, "merge", entry.Action)
	assert.Equal(t, 1, entry.Cycle)
	// Verify signal snapshot.
	assert.Equal(t, "OPEN", entry.Signals.PRState)
	assert.Equal(t, false, entry.Signals.IsDraft)
	assert.Equal(t, "SUCCESS", entry.Signals.CheckStatus)
	assert.Equal(t, "MERGEABLE", entry.Signals.Mergeable)
	assert.Equal(t, 2, entry.Signals.ThreadsTotal)
	assert.Equal(t, 1, entry.Signals.ThreadsResolved)
	// Verify result snapshot.
	assert.Equal(t, true, entry.Result.Success)
	assert.Equal(t, "", entry.Result.Error)
	assert.Equal(t, int64(1200), entry.Result.DurationMs)
	// Append a second entry and verify two lines exist.
	result2 := &ActionResult{
		Action:    "comment",
		RepoOwner: "host-uk",
		RepoName:  "core-tenant",
		Success:   false,
		Error:     "rate limited",
		Timestamp: ts,
		Duration:  50 * time.Millisecond,
		Cycle:     2,
	}
	err = j.Append(signal, result2)
	require.NoError(t, err)
	data, err := os.ReadFile(expectedPath)
	require.NoError(t, err)
	// Count lines via a scanner so a missing trailing newline still counts.
	lines := 0
	sc := bufio.NewScanner(strings.NewReader(string(data)))
	for sc.Scan() {
		lines++
	}
	assert.Equal(t, 2, lines, "expected two JSONL lines after two appends")
}
// TestJournal_Append_Bad_PathTraversal verifies that Append rejects repo
// owner/name values that could escape the journal root (dot-dot segments,
// absolute paths, slashes) or that are empty/whitespace, each with a
// descriptive error.
func TestJournal_Append_Bad_PathTraversal(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts := time.Now()
	tests := []struct {
		name      string
		repoOwner string
		repoName  string
		wantErr   string
	}{
		{
			name:      "dotdot owner",
			repoOwner: "..",
			repoName:  "core",
			wantErr:   "invalid repo owner",
		},
		{
			name:      "dotdot repo",
			repoOwner: "host-uk",
			repoName:  "../../etc/cron.d",
			wantErr:   "invalid repo name",
		},
		{
			name:      "slash in owner",
			repoOwner: "../etc",
			repoName:  "core",
			wantErr:   "invalid repo owner",
		},
		{
			name:      "absolute path in repo",
			repoOwner: "host-uk",
			repoName:  "/etc/passwd",
			wantErr:   "invalid repo name",
		},
		{
			name:      "empty owner",
			repoOwner: "",
			repoName:  "core",
			wantErr:   "invalid repo owner",
		},
		{
			name:      "empty repo",
			repoOwner: "host-uk",
			repoName:  "",
			wantErr:   "invalid repo name",
		},
		{
			name:      "dot only owner",
			repoOwner: ".",
			repoName:  "core",
			wantErr:   "invalid repo owner",
		},
		{
			name:      "spaces only owner",
			repoOwner: " ",
			repoName:  "core",
			wantErr:   "invalid repo owner",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			signal := &PipelineSignal{
				RepoOwner: tc.repoOwner,
				RepoName:  tc.repoName,
			}
			result := &ActionResult{
				Action:    "merge",
				Timestamp: ts,
			}
			err := j.Append(signal, result)
			require.Error(t, err)
			assert.Contains(t, err.Error(), tc.wantErr)
		})
	}
}
// TestJournal_Append_Good_ValidNames verifies the complement of the
// path-traversal test: owner/repo names containing dots, hyphens,
// underscores, and mixed case are all accepted by Append.
func TestJournal_Append_Good_ValidNames(t *testing.T) {
	dir := t.TempDir()
	j, err := NewJournal(dir)
	require.NoError(t, err)
	ts := time.Date(2026, 2, 5, 14, 30, 0, 0, time.UTC)
	// Verify valid names with dots, hyphens, underscores all work.
	validNames := []struct {
		owner string
		repo  string
	}{
		{"host-uk", "core"},
		{"my_org", "my_repo"},
		{"org.name", "repo.v2"},
		{"a", "b"},
		{"Org-123", "Repo_456.go"},
	}
	for _, vn := range validNames {
		signal := &PipelineSignal{
			RepoOwner: vn.owner,
			RepoName:  vn.repo,
		}
		result := &ActionResult{
			Action:    "test",
			Timestamp: ts,
		}
		err := j.Append(signal, result)
		assert.NoError(t, err, "expected valid name pair %s/%s to succeed", vn.owner, vn.repo)
	}
}
// TestJournal_Append_Bad_NilSignal verifies Append rejects a nil signal
// with a descriptive error.
func TestJournal_Append_Bad_NilSignal(t *testing.T) {
	j, err := NewJournal(t.TempDir())
	require.NoError(t, err)
	res := &ActionResult{
		Action:    "merge",
		Timestamp: time.Now(),
	}
	appendErr := j.Append(nil, res)
	require.Error(t, appendErr)
	assert.Contains(t, appendErr.Error(), "signal is required")
}
// TestJournal_Append_Bad_NilResult verifies Append rejects a nil result
// with a descriptive error.
func TestJournal_Append_Bad_NilResult(t *testing.T) {
	j, err := NewJournal(t.TempDir())
	require.NoError(t, err)
	sig := &PipelineSignal{
		RepoOwner: "host-uk",
		RepoName:  "core-php",
	}
	appendErr := j.Append(sig, nil)
	require.Error(t, appendErr)
	assert.Contains(t, appendErr.Error(), "result is required")
}

View file

@ -1,224 +0,0 @@
package jobrunner
import (
"context"
"iter"
"sync"
"time"
"forge.lthn.ai/core/go-log"
)
// PollerConfig configures a Poller.
type PollerConfig struct {
	// Sources to poll for actionable signals each cycle.
	Sources []JobSource
	// Handlers tried in order; the first whose Match returns true runs.
	Handlers []JobHandler
	// Journal, when non-nil, records every executed action.
	Journal *Journal
	// PollInterval between cycles; values <= 0 default to 60s in NewPoller.
	PollInterval time.Duration
	// DryRun logs what would run without executing handlers.
	DryRun bool
}
// Poller discovers signals from sources and dispatches them to handlers.
// All mutable fields are guarded by mu; the struct must not be copied.
type Poller struct {
	mu       sync.RWMutex
	sources  []JobSource
	handlers []JobHandler
	journal  *Journal
	interval time.Duration // fixed after construction; read without lock in Run
	dryRun   bool
	cycle    int // completed poll-dispatch cycles
}
// NewPoller creates a Poller from the given config.
func NewPoller(cfg PollerConfig) *Poller {
interval := cfg.PollInterval
if interval <= 0 {
interval = 60 * time.Second
}
return &Poller{
sources: cfg.Sources,
handlers: cfg.Handlers,
journal: cfg.Journal,
interval: interval,
dryRun: cfg.DryRun,
}
}
// Cycle returns the number of completed poll-dispatch cycles.
func (p *Poller) Cycle() int {
p.mu.RLock()
defer p.mu.RUnlock()
return p.cycle
}
// Sources returns an iterator over the poller's sources.
func (p *Poller) Sources() iter.Seq[JobSource] {
return func(yield func(JobSource) bool) {
p.mu.RLock()
sources := make([]JobSource, len(p.sources))
copy(sources, p.sources)
p.mu.RUnlock()
for _, s := range sources {
if !yield(s) {
return
}
}
}
}
// Handlers returns an iterator over the poller's handlers.
func (p *Poller) Handlers() iter.Seq[JobHandler] {
return func(yield func(JobHandler) bool) {
p.mu.RLock()
handlers := make([]JobHandler, len(p.handlers))
copy(handlers, p.handlers)
p.mu.RUnlock()
for _, h := range handlers {
if !yield(h) {
return
}
}
}
}
// DryRun returns whether dry-run mode is enabled.
func (p *Poller) DryRun() bool {
p.mu.RLock()
defer p.mu.RUnlock()
return p.dryRun
}
// SetDryRun enables or disables dry-run mode.
func (p *Poller) SetDryRun(v bool) {
p.mu.Lock()
p.dryRun = v
p.mu.Unlock()
}
// AddSource appends a source to the poller.
func (p *Poller) AddSource(s JobSource) {
p.mu.Lock()
p.sources = append(p.sources, s)
p.mu.Unlock()
}
// AddHandler appends a handler to the poller.
func (p *Poller) AddHandler(h JobHandler) {
p.mu.Lock()
p.handlers = append(p.handlers, h)
p.mu.Unlock()
}
// Run starts a blocking poll-dispatch loop. It runs one cycle immediately,
// then repeats on each tick of the configured interval until the context
// is cancelled.
// Run starts a blocking poll-dispatch loop. One cycle executes
// immediately; further cycles fire on each tick of the configured
// interval until ctx is cancelled, at which point ctx.Err() is returned.
// Any error from RunOnce also terminates the loop.
func (p *Poller) Run(ctx context.Context) error {
	if err := p.RunOnce(ctx); err != nil {
		return err
	}
	ticker := time.NewTicker(p.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if err := p.RunOnce(ctx); err != nil {
				return err
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// RunOnce performs a single poll-dispatch cycle: iterate sources, poll each,
// find the first matching handler for each signal, and execute it.
func (p *Poller) RunOnce(ctx context.Context) error {
p.mu.Lock()
p.cycle++
cycle := p.cycle
dryRun := p.dryRun
p.mu.Unlock()
log.Info("poller cycle starting", "cycle", cycle)
for src := range p.Sources() {
signals, err := src.Poll(ctx)
if err != nil {
log.Error("poll failed", "source", src.Name(), "err", err)
continue
}
log.Info("polled source", "source", src.Name(), "signals", len(signals))
for _, sig := range signals {
handler := p.findHandler(p.Handlers(), sig)
if handler == nil {
log.Debug("no matching handler", "epic", sig.EpicNumber, "child", sig.ChildNumber)
continue
}
if dryRun {
log.Info("dry-run: would execute",
"handler", handler.Name(),
"epic", sig.EpicNumber,
"child", sig.ChildNumber,
"pr", sig.PRNumber,
)
continue
}
start := time.Now()
result, err := handler.Execute(ctx, sig)
elapsed := time.Since(start)
if err != nil {
log.Error("handler execution failed",
"handler", handler.Name(),
"epic", sig.EpicNumber,
"child", sig.ChildNumber,
"err", err,
)
continue
}
result.Cycle = cycle
result.EpicNumber = sig.EpicNumber
result.ChildNumber = sig.ChildNumber
result.Duration = elapsed
if p.journal != nil {
if jErr := p.journal.Append(sig, result); jErr != nil {
log.Error("journal append failed", "err", jErr)
}
}
if rErr := src.Report(ctx, result); rErr != nil {
log.Error("source report failed", "source", src.Name(), "err", rErr)
}
log.Info("handler executed",
"handler", handler.Name(),
"action", result.Action,
"success", result.Success,
"duration", elapsed,
)
}
}
return nil
}
// findHandler returns the first handler that matches the signal, or nil.
func (p *Poller) findHandler(handlers iter.Seq[JobHandler], sig *PipelineSignal) JobHandler {
for h := range handlers {
if h.Match(sig) {
return h
}
}
return nil
}

View file

@ -1,307 +0,0 @@
package jobrunner
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Mock source ---
// mockSource is a JobSource test double: Poll returns a fixed signal
// slice and Report records every result it receives. mu guards both
// slices for use from Poller goroutines.
type mockSource struct {
	name    string
	signals []*PipelineSignal
	reports []*ActionResult
	mu      sync.Mutex
}

// Name returns the source's configured name.
func (m *mockSource) Name() string { return m.name }

// Poll returns the canned signal slice; it never errors.
func (m *mockSource) Poll(_ context.Context) ([]*PipelineSignal, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.signals, nil
}

// Report records the result for later inspection; it never errors.
func (m *mockSource) Report(_ context.Context, result *ActionResult) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.reports = append(m.reports, result)
	return nil
}
// --- Mock handler ---
// mockHandler is a JobHandler test double: Match delegates to matchFn
// (defaulting to "match everything" when nil) and Execute records each
// signal it handles, returning a successful ActionResult.
type mockHandler struct {
	name     string
	matchFn  func(*PipelineSignal) bool
	executed []*PipelineSignal
	mu       sync.Mutex
}

// Name returns the handler's configured name.
func (m *mockHandler) Name() string { return m.name }

// Match applies matchFn when set; a nil matchFn accepts every signal.
func (m *mockHandler) Match(sig *PipelineSignal) bool {
	if m.matchFn != nil {
		return m.matchFn(sig)
	}
	return true
}

// Execute records the signal and returns a success result echoing its
// repo and PR identity.
func (m *mockHandler) Execute(_ context.Context, sig *PipelineSignal) (*ActionResult, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.executed = append(m.executed, sig)
	return &ActionResult{
		Action:    m.name,
		RepoOwner: sig.RepoOwner,
		RepoName:  sig.RepoName,
		PRNumber:  sig.PRNumber,
		Success:   true,
		Timestamp: time.Now(),
	}, nil
}
// TestPoller_RunOnce_Good covers the full happy path of a single cycle:
// a matching handler executes, the result is stamped with cycle and
// epic/child identity, the source receives the report, and the cycle
// counter increments.
func TestPoller_RunOnce_Good(t *testing.T) {
	sig := &PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 2,
		PRNumber:    10,
		RepoOwner:   "host-uk",
		RepoName:    "core-php",
		PRState:     "OPEN",
		CheckStatus: "SUCCESS",
		Mergeable:   "MERGEABLE",
	}
	src := &mockSource{
		name:    "test-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &mockHandler{
		name: "test-handler",
		matchFn: func(s *PipelineSignal) bool {
			return s.PRNumber == 10
		},
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
	})
	err := p.RunOnce(context.Background())
	require.NoError(t, err)
	// Handler should have been called with our signal.
	handler.mu.Lock()
	defer handler.mu.Unlock()
	require.Len(t, handler.executed, 1)
	assert.Equal(t, 10, handler.executed[0].PRNumber)
	// Source should have received a report.
	src.mu.Lock()
	defer src.mu.Unlock()
	require.Len(t, src.reports, 1)
	assert.Equal(t, "test-handler", src.reports[0].Action)
	assert.True(t, src.reports[0].Success)
	assert.Equal(t, 1, src.reports[0].Cycle)
	assert.Equal(t, 1, src.reports[0].EpicNumber)
	assert.Equal(t, 2, src.reports[0].ChildNumber)
	// Cycle counter should have incremented.
	assert.Equal(t, 1, p.Cycle())
}
// TestPoller_RunOnce_Good_NoSignals verifies that a cycle over a source
// producing no signals executes no handlers, sends no reports, and still
// counts as a completed cycle.
func TestPoller_RunOnce_Good_NoSignals(t *testing.T) {
	src := &mockSource{
		name:    "empty-source",
		signals: nil,
	}
	handler := &mockHandler{
		name: "unused-handler",
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
	})
	err := p.RunOnce(context.Background())
	require.NoError(t, err)
	// Handler should not have been called.
	handler.mu.Lock()
	defer handler.mu.Unlock()
	assert.Empty(t, handler.executed)
	// Source should not have received reports.
	src.mu.Lock()
	defer src.mu.Unlock()
	assert.Empty(t, src.reports)
	// An empty cycle still increments the counter.
	assert.Equal(t, 1, p.Cycle())
}
// TestPoller_RunOnce_Good_NoMatchingHandler verifies a signal with no
// matching handler is silently skipped: nothing executes and the source
// receives no report.
func TestPoller_RunOnce_Good_NoMatchingHandler(t *testing.T) {
	sig := &PipelineSignal{
		EpicNumber:  5,
		ChildNumber: 8,
		PRNumber:    42,
		RepoOwner:   "host-uk",
		RepoName:    "core-tenant",
		PRState:     "OPEN",
	}
	src := &mockSource{
		name:    "test-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &mockHandler{
		name: "picky-handler",
		matchFn: func(s *PipelineSignal) bool {
			return false // never matches
		},
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
	})
	err := p.RunOnce(context.Background())
	require.NoError(t, err)
	// Handler should not have been called.
	handler.mu.Lock()
	defer handler.mu.Unlock()
	assert.Empty(t, handler.executed)
	// Source should not have received reports (no action taken).
	src.mu.Lock()
	defer src.mu.Unlock()
	assert.Empty(t, src.reports)
}
// TestPoller_RunOnce_Good_DryRun verifies that with DryRun enabled a
// matching handler is identified but never executed, and no report is
// sent to the source.
func TestPoller_RunOnce_Good_DryRun(t *testing.T) {
	sig := &PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 3,
		PRNumber:    20,
		RepoOwner:   "host-uk",
		RepoName:    "core-admin",
		PRState:     "OPEN",
		CheckStatus: "SUCCESS",
		Mergeable:   "MERGEABLE",
	}
	src := &mockSource{
		name:    "test-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &mockHandler{
		name: "merge-handler",
		matchFn: func(s *PipelineSignal) bool {
			return true
		},
	}
	p := NewPoller(PollerConfig{
		Sources:  []JobSource{src},
		Handlers: []JobHandler{handler},
		DryRun:   true,
	})
	assert.True(t, p.DryRun())
	err := p.RunOnce(context.Background())
	require.NoError(t, err)
	// Handler should NOT have been called in dry-run mode.
	handler.mu.Lock()
	defer handler.mu.Unlock()
	assert.Empty(t, handler.executed)
	// Source should not have received reports.
	src.mu.Lock()
	defer src.mu.Unlock()
	assert.Empty(t, src.reports)
}
// TestPoller_SetDryRun_Good verifies the dry-run flag defaults to false
// and round-trips through SetDryRun/DryRun.
func TestPoller_SetDryRun_Good(t *testing.T) {
	p := NewPoller(PollerConfig{})
	assert.False(t, p.DryRun())
	for _, want := range []bool{true, false} {
		p.SetDryRun(want)
		assert.Equal(t, want, p.DryRun())
	}
}
// TestPoller_AddSourceAndHandler_Good verifies that sources and handlers
// registered after construction (via AddSource/AddHandler) participate in
// the next cycle.
func TestPoller_AddSourceAndHandler_Good(t *testing.T) {
	p := NewPoller(PollerConfig{})
	sig := &PipelineSignal{
		EpicNumber:  1,
		ChildNumber: 1,
		PRNumber:    5,
		RepoOwner:   "host-uk",
		RepoName:    "core-php",
		PRState:     "OPEN",
	}
	src := &mockSource{
		name:    "added-source",
		signals: []*PipelineSignal{sig},
	}
	handler := &mockHandler{
		name:    "added-handler",
		matchFn: func(s *PipelineSignal) bool { return true },
	}
	p.AddSource(src)
	p.AddHandler(handler)
	err := p.RunOnce(context.Background())
	require.NoError(t, err)
	handler.mu.Lock()
	defer handler.mu.Unlock()
	require.Len(t, handler.executed, 1)
	assert.Equal(t, 5, handler.executed[0].PRNumber)
}
// TestPoller_Run_Good verifies the blocking loop: Run executes one
// immediate cycle, ticks at the configured interval, and returns the
// context's error on cancellation. With a 50ms interval and a 180ms
// deadline, at least two cycles must complete.
func TestPoller_Run_Good(t *testing.T) {
	src := &mockSource{
		name:    "tick-source",
		signals: nil,
	}
	p := NewPoller(PollerConfig{
		Sources:      []JobSource{src},
		PollInterval: 50 * time.Millisecond,
	})
	ctx, cancel := context.WithTimeout(context.Background(), 180*time.Millisecond)
	defer cancel()
	err := p.Run(ctx)
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	// Should have completed at least 2 cycles (one immediate + at least one tick).
	assert.GreaterOrEqual(t, p.Cycle(), 2)
}
// TestPoller_DefaultInterval_Good verifies a zero PollInterval falls back
// to the one-minute default.
func TestPoller_DefaultInterval_Good(t *testing.T) {
	p := NewPoller(PollerConfig{})
	assert.Equal(t, time.Minute, p.interval)
}

View file

@ -1,72 +0,0 @@
package jobrunner
import (
"context"
"time"
)
// PipelineSignal is the structural snapshot of a child issue/PR.
// Carries structural state plus issue title/body for dispatch prompts.
type PipelineSignal struct {
	// Epic/child issue numbers and the PR (0 when no PR exists yet).
	EpicNumber  int
	ChildNumber int
	PRNumber    int
	// Repository identity, split so callers can build paths/API calls.
	RepoOwner       string
	RepoName        string
	PRState         string // OPEN, MERGED, CLOSED
	IsDraft         bool
	Mergeable       string // MERGEABLE, CONFLICTING, UNKNOWN
	CheckStatus     string // SUCCESS, FAILURE, PENDING
	ThreadsTotal    int
	ThreadsResolved int
	LastCommitSHA   string
	LastCommitAt    time.Time
	LastReviewAt    time.Time
	NeedsCoding     bool   // true if child has no PR (work not started)
	Assignee        string // issue assignee username (for dispatch)
	IssueTitle      string // child issue title (for dispatch prompt)
	IssueBody       string // child issue body (for dispatch prompt)
	Type            string // signal type (e.g., "agent_completion")
	Success         bool   // agent completion success flag
	Error           string // agent error message
	Message         string // agent completion message
}
// RepoFullName returns the repository identity in "owner/repo" form.
func (s *PipelineSignal) RepoFullName() string {
	full := s.RepoOwner + "/" + s.RepoName
	return full
}
// HasUnresolvedThreads reports whether any review threads remain
// unresolved (resolved count below the total).
func (s *PipelineSignal) HasUnresolvedThreads() bool {
	return s.ThreadsResolved < s.ThreadsTotal
}
// ActionResult carries the outcome of a handler execution.
type ActionResult struct {
	Action      string `json:"action"`
	RepoOwner   string `json:"repo_owner"`
	RepoName    string `json:"repo_name"`
	EpicNumber  int    `json:"epic"`
	ChildNumber int    `json:"child"`
	PRNumber    int    `json:"pr"`
	Success     bool   `json:"success"`
	Error       string `json:"error,omitempty"`
	Timestamp   time.Time `json:"ts"`
	// NOTE(review): the tag says "duration_ms", but encoding/json marshals
	// time.Duration as integer nanoseconds — confirm consumers expect
	// nanoseconds under this key before relying on the name.
	Duration time.Duration `json:"duration_ms"`
	Cycle    int           `json:"cycle"`
}
// JobSource discovers actionable work from an external system.
type JobSource interface {
	// Name identifies the source in logs.
	Name() string
	// Poll returns the current batch of actionable signals.
	Poll(ctx context.Context) ([]*PipelineSignal, error)
	// Report delivers the outcome of an executed action back to the source.
	Report(ctx context.Context, result *ActionResult) error
}
// JobHandler processes a single pipeline signal.
type JobHandler interface {
	// Name identifies the handler in logs and results.
	Name() string
	// Match reports whether this handler should process the signal.
	Match(signal *PipelineSignal) bool
	// Execute performs the action and returns its result.
	Execute(ctx context.Context, signal *PipelineSignal) (*ActionResult, error)
}

View file

@ -1,98 +0,0 @@
package jobrunner
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPipelineSignal_RepoFullName_Good verifies owner and name are joined
// with a slash.
func TestPipelineSignal_RepoFullName_Good(t *testing.T) {
	s := &PipelineSignal{RepoOwner: "host-uk", RepoName: "core-php"}
	assert.Equal(t, "host-uk/core-php", s.RepoFullName())
}
// TestPipelineSignal_HasUnresolvedThreads_Good verifies the positive case:
// fewer resolved threads than total.
func TestPipelineSignal_HasUnresolvedThreads_Good(t *testing.T) {
	s := &PipelineSignal{ThreadsTotal: 5, ThreadsResolved: 3}
	assert.True(t, s.HasUnresolvedThreads())
}
// TestPipelineSignal_HasUnresolvedThreads_Bad_AllResolved verifies the
// negative cases: all threads resolved, and the zero-thread signal.
func TestPipelineSignal_HasUnresolvedThreads_Bad_AllResolved(t *testing.T) {
	cases := []struct {
		total    int
		resolved int
	}{
		{4, 4},
		{0, 0},
	}
	for _, tc := range cases {
		s := &PipelineSignal{ThreadsTotal: tc.total, ThreadsResolved: tc.resolved}
		assert.False(t, s.HasUnresolvedThreads())
	}
}
// TestActionResult_JSON_Good verifies the JSON wire shape of ActionResult:
// tag names, omitempty behavior of the error field when empty, and a
// second round-trip with the error field populated.
func TestActionResult_JSON_Good(t *testing.T) {
	ts := time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC)
	result := &ActionResult{
		Action:      "merge",
		RepoOwner:   "host-uk",
		RepoName:    "core-tenant",
		EpicNumber:  42,
		ChildNumber: 7,
		PRNumber:    99,
		Success:     true,
		Timestamp:   ts,
		Duration:    1500 * time.Millisecond,
		Cycle:       3,
	}
	data, err := json.Marshal(result)
	require.NoError(t, err)
	// Decode into a map; JSON numbers come back as float64.
	var decoded map[string]any
	err = json.Unmarshal(data, &decoded)
	require.NoError(t, err)
	assert.Equal(t, "merge", decoded["action"])
	assert.Equal(t, "host-uk", decoded["repo_owner"])
	assert.Equal(t, "core-tenant", decoded["repo_name"])
	assert.Equal(t, float64(42), decoded["epic"])
	assert.Equal(t, float64(7), decoded["child"])
	assert.Equal(t, float64(99), decoded["pr"])
	assert.Equal(t, true, decoded["success"])
	assert.Equal(t, float64(3), decoded["cycle"])
	// Error field should be omitted when empty.
	_, hasError := decoded["error"]
	assert.False(t, hasError, "error field should be omitted when empty")
	// Verify round-trip with error field present.
	resultWithErr := &ActionResult{
		Action:    "merge",
		RepoOwner: "host-uk",
		RepoName:  "core-tenant",
		Success:   false,
		Error:     "checks failing",
		Timestamp: ts,
		Duration:  200 * time.Millisecond,
		Cycle:     1,
	}
	data2, err := json.Marshal(resultWithErr)
	require.NoError(t, err)
	var decoded2 map[string]any
	err = json.Unmarshal(data2, &decoded2)
	require.NoError(t, err)
	assert.Equal(t, "checks failing", decoded2["error"])
	assert.Equal(t, false, decoded2["success"])
}

View file

@ -1,335 +0,0 @@
package lifecycle
import (
"iter"
"sync"
"time"
)
// AllowanceStatus indicates the current state of an agent's quota.
// Values are plain strings so they serialize directly in JSON/YAML.
type AllowanceStatus string

const (
	// AllowanceOK indicates the agent has remaining quota.
	AllowanceOK AllowanceStatus = "ok"
	// AllowanceWarning indicates the agent is at 80%+ usage.
	AllowanceWarning AllowanceStatus = "warning"
	// AllowanceExceeded indicates the agent has exceeded its quota.
	AllowanceExceeded AllowanceStatus = "exceeded"
)
// AgentAllowance defines the quota limits for a single agent.
// Throughout, a zero limit means "unlimited" rather than "none".
type AgentAllowance struct {
	// AgentID is the unique identifier for the agent.
	AgentID string `json:"agent_id" yaml:"agent_id"`
	// DailyTokenLimit is the maximum tokens (in+out) per 24h. 0 means unlimited.
	DailyTokenLimit int64 `json:"daily_token_limit" yaml:"daily_token_limit"`
	// DailyJobLimit is the maximum jobs per 24h. 0 means unlimited.
	DailyJobLimit int `json:"daily_job_limit" yaml:"daily_job_limit"`
	// ConcurrentJobs is the maximum simultaneous jobs. 0 means unlimited.
	ConcurrentJobs int `json:"concurrent_jobs" yaml:"concurrent_jobs"`
	// MaxJobDuration is the maximum job duration before kill. 0 means unlimited.
	MaxJobDuration time.Duration `json:"max_job_duration" yaml:"max_job_duration"`
	// ModelAllowlist restricts which models this agent can use. Empty means all.
	ModelAllowlist []string `json:"model_allowlist,omitempty" yaml:"model_allowlist"`
}
// ModelQuota defines global per-model limits across all agents.
// Two of the fields are declared but not yet enforced — see their notes.
type ModelQuota struct {
	// Model is the model identifier (e.g. "claude-sonnet-4-5-20250929").
	Model string `json:"model" yaml:"model"`
	// DailyTokenBudget is the total tokens across all agents per 24h.
	DailyTokenBudget int64 `json:"daily_token_budget" yaml:"daily_token_budget"`
	// HourlyRateLimit is the max requests per hour.
	// Reserved: stored but not yet enforced in AllowanceService.Check.
	// Enforcement requires AllowanceStore.GetHourlyUsage (sliding window).
	HourlyRateLimit int `json:"hourly_rate_limit" yaml:"hourly_rate_limit"`
	// CostCeiling stops all usage if cumulative cost exceeds this (in cents).
	// Reserved: stored but not yet enforced in AllowanceService.Check.
	CostCeiling int64 `json:"cost_ceiling" yaml:"cost_ceiling"`
}
// RepoLimit defines per-repository rate limits.
type RepoLimit struct {
// Repo is the repository identifier (e.g. "owner/repo").
Repo string `json:"repo" yaml:"repo"`
// MaxDailyPRs is the maximum PRs per day. 0 means unlimited.
MaxDailyPRs int `json:"max_daily_prs" yaml:"max_daily_prs"`
// MaxDailyIssues is the maximum issues per day. 0 means unlimited.
MaxDailyIssues int `json:"max_daily_issues" yaml:"max_daily_issues"`
// CooldownAfterFailure is the wait time after a failure before retrying.
CooldownAfterFailure time.Duration `json:"cooldown_after_failure" yaml:"cooldown_after_failure"`
}
// UsageRecord tracks an agent's current usage within a quota period.
type UsageRecord struct {
	// AgentID is the agent this record belongs to.
	AgentID string `json:"agent_id"`
	// TokensUsed is the total tokens consumed in the current period.
	TokensUsed int64 `json:"tokens_used"`
	// JobsStarted is the total jobs started in the current period.
	JobsStarted int `json:"jobs_started"`
	// ActiveJobs is the number of currently running jobs.
	ActiveJobs int `json:"active_jobs"`
	// PeriodStart is when the current quota period began.
	PeriodStart time.Time `json:"period_start"`
}

// QuotaCheckResult is the outcome of a pre-dispatch allowance check.
type QuotaCheckResult struct {
	// Allowed indicates whether the agent may proceed.
	Allowed bool `json:"allowed"`
	// Status is the current allowance state.
	Status AllowanceStatus `json:"status"`
	// RemainingTokens is the number of tokens remaining in the period.
	// Unlimited allowances report -1 (see the exhaustion tests).
	RemainingTokens int64 `json:"remaining_tokens"`
	// RemainingJobs is the number of jobs remaining in the period.
	// Unlimited allowances report -1.
	RemainingJobs int `json:"remaining_jobs"`
	// Reason explains why the check failed (if !Allowed).
	Reason string `json:"reason,omitempty"`
}
// QuotaEvent represents a change in quota usage, used for recovery.
type QuotaEvent string

const (
	// QuotaEventJobStarted deducts quota when a job begins.
	QuotaEventJobStarted QuotaEvent = "job_started"
	// QuotaEventJobCompleted deducts nothing (already counted).
	QuotaEventJobCompleted QuotaEvent = "job_completed"
	// QuotaEventJobFailed returns 50% of token quota.
	QuotaEventJobFailed QuotaEvent = "job_failed"
	// QuotaEventJobCancelled returns 100% of token quota.
	QuotaEventJobCancelled QuotaEvent = "job_cancelled"
)

// UsageReport is emitted by the agent runner to report token consumption.
type UsageReport struct {
	// AgentID is the agent that consumed tokens.
	AgentID string `json:"agent_id"`
	// JobID identifies the specific job.
	JobID string `json:"job_id"`
	// Model is the model used. May be empty, in which case model-level
	// accounting is skipped (see the no-model tests).
	Model string `json:"model"`
	// TokensIn is the number of input tokens consumed.
	TokensIn int64 `json:"tokens_in"`
	// TokensOut is the number of output tokens consumed.
	TokensOut int64 `json:"tokens_out"`
	// Event is the type of quota event.
	Event QuotaEvent `json:"event"`
	// Timestamp is when the usage occurred.
	Timestamp time.Time `json:"timestamp"`
}
// AllowanceStore is the interface for persisting and querying allowance data.
// Implementations may use Redis, SQLite, or any backing store.
//
// Implementations must be safe for concurrent use; the Increment*/
// Decrement*/ReturnTokens methods are documented as atomic.
type AllowanceStore interface {
	// GetAllowance returns the quota limits for an agent.
	GetAllowance(agentID string) (*AgentAllowance, error)
	// SetAllowance persists quota limits for an agent.
	SetAllowance(a *AgentAllowance) error
	// Allowances returns an iterator over all agent allowances.
	Allowances() iter.Seq[*AgentAllowance]
	// GetUsage returns the current usage record for an agent.
	GetUsage(agentID string) (*UsageRecord, error)
	// Usages returns an iterator over all usage records.
	Usages() iter.Seq[*UsageRecord]
	// IncrementUsage atomically adds to an agent's usage counters.
	IncrementUsage(agentID string, tokens int64, jobs int) error
	// DecrementActiveJobs reduces the active job count by 1.
	DecrementActiveJobs(agentID string) error
	// ReturnTokens adds tokens back to the agent's remaining quota.
	ReturnTokens(agentID string, tokens int64) error
	// ResetUsage clears usage counters for an agent (daily reset).
	ResetUsage(agentID string) error
	// GetModelQuota returns global limits for a model.
	GetModelQuota(model string) (*ModelQuota, error)
	// GetModelUsage returns current token usage for a model.
	GetModelUsage(model string) (int64, error)
	// IncrementModelUsage atomically adds to a model's usage counter.
	IncrementModelUsage(model string, tokens int64) error
}
// MemoryStore is an in-memory AllowanceStore for testing and single-node use.
// All maps are guarded by mu; stored values are deep-copied on the way in
// and out so callers can never alias internal state.
type MemoryStore struct {
	mu          sync.RWMutex
	allowances  map[string]*AgentAllowance
	usage       map[string]*UsageRecord
	modelQuotas map[string]*ModelQuota
	modelUsage  map[string]int64
}
// NewMemoryStore creates a new in-memory allowance store with all
// internal maps initialised and ready for use.
func NewMemoryStore() *MemoryStore {
	s := &MemoryStore{}
	s.allowances = make(map[string]*AgentAllowance)
	s.usage = make(map[string]*UsageRecord)
	s.modelQuotas = make(map[string]*ModelQuota)
	s.modelUsage = make(map[string]int64)
	return s
}
// GetAllowance returns a copy of the quota limits stored for agentID.
// It returns a 404 APIError when no allowance is configured.
func (m *MemoryStore) GetAllowance(agentID string) (*AgentAllowance, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	stored, found := m.allowances[agentID]
	if !found {
		return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID}
	}
	// Hand back a copy so the caller cannot mutate store state.
	out := *stored
	return &out, nil
}
// SetAllowance persists quota limits for an agent, overwriting any
// existing entry. The store keeps its own copy of *a.
func (m *MemoryStore) SetAllowance(a *AgentAllowance) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	stored := *a
	m.allowances[a.AgentID] = &stored
	return nil
}
// Allowances returns an iterator over copies of all agent allowances.
//
// The snapshot is taken under the read lock and the lock is released
// before any element is yielded. The previous implementation held
// RLock across yield, so a consumer that called a write method
// (SetAllowance, IncrementUsage, ...) from inside the loop would
// self-deadlock, since sync.RWMutex is not reentrant.
func (m *MemoryStore) Allowances() iter.Seq[*AgentAllowance] {
	return func(yield func(*AgentAllowance) bool) {
		m.mu.RLock()
		snapshot := make([]*AgentAllowance, 0, len(m.allowances))
		for _, a := range m.allowances {
			cp := *a // copy so the consumer cannot alias store state
			snapshot = append(snapshot, &cp)
		}
		m.mu.RUnlock()
		for _, a := range snapshot {
			if !yield(a) {
				return
			}
		}
	}
}
// GetUsage returns a copy of the current usage record for an agent.
// Agents with no recorded usage get a fresh zero record whose period
// starts at midnight UTC today, rather than an error.
func (m *MemoryStore) GetUsage(agentID string) (*UsageRecord, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	rec, found := m.usage[agentID]
	if !found {
		return &UsageRecord{
			AgentID:     agentID,
			PeriodStart: startOfDay(time.Now().UTC()),
		}, nil
	}
	out := *rec
	return &out, nil
}
// Usages returns an iterator over copies of all usage records.
//
// As with Allowances, records are snapshotted under the read lock and
// yielded only after the lock is released, so consumers may safely
// call mutating store methods while iterating. Holding RLock across
// yield (the previous behaviour) deadlocks in that case because
// sync.RWMutex is not reentrant.
func (m *MemoryStore) Usages() iter.Seq[*UsageRecord] {
	return func(yield func(*UsageRecord) bool) {
		m.mu.RLock()
		snapshot := make([]*UsageRecord, 0, len(m.usage))
		for _, u := range m.usage {
			cp := *u // copy so the consumer cannot alias store state
			snapshot = append(snapshot, &cp)
		}
		m.mu.RUnlock()
		for _, u := range snapshot {
			if !yield(u) {
				return
			}
		}
	}
}
// IncrementUsage atomically adds to an agent's usage counters,
// lazily creating the record (period starting midnight UTC today)
// on first use. Positive job counts also raise ActiveJobs.
func (m *MemoryStore) IncrementUsage(agentID string, tokens int64, jobs int) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	rec := m.usage[agentID]
	if rec == nil {
		rec = &UsageRecord{
			AgentID:     agentID,
			PeriodStart: startOfDay(time.Now().UTC()),
		}
		m.usage[agentID] = rec
	}
	rec.TokensUsed += tokens
	rec.JobsStarted += jobs
	if jobs > 0 {
		rec.ActiveJobs += jobs
	}
	return nil
}
// DecrementActiveJobs reduces the active job count by 1, never going
// below zero. Unknown agents are a no-op.
func (m *MemoryStore) DecrementActiveJobs(agentID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if rec, ok := m.usage[agentID]; ok && rec.ActiveJobs > 0 {
		rec.ActiveJobs--
	}
	return nil
}
// ReturnTokens adds tokens back to the agent's remaining quota by
// reducing TokensUsed, clamped at zero. Unknown agents are a no-op.
func (m *MemoryStore) ReturnTokens(agentID string, tokens int64) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	rec, ok := m.usage[agentID]
	if !ok {
		return nil // nothing recorded, nothing to return
	}
	rec.TokensUsed = max(rec.TokensUsed-tokens, 0)
	return nil
}
// ResetUsage clears usage counters for an agent, installing a fresh
// record whose period starts at midnight UTC today (daily reset).
func (m *MemoryStore) ResetUsage(agentID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	fresh := &UsageRecord{
		AgentID:     agentID,
		PeriodStart: startOfDay(time.Now().UTC()),
	}
	m.usage[agentID] = fresh
	return nil
}
// GetModelQuota returns a copy of the global limits for a model, or a
// 404 APIError when no quota has been configured for it.
func (m *MemoryStore) GetModelQuota(model string) (*ModelQuota, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	stored, found := m.modelQuotas[model]
	if !found {
		return nil, &APIError{Code: 404, Message: "model quota not found: " + model}
	}
	out := *stored
	return &out, nil
}
// GetModelUsage returns the tokens consumed so far for a model.
// Unknown models read as zero (the map zero value).
func (m *MemoryStore) GetModelUsage(model string) (int64, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	used := m.modelUsage[model]
	return used, nil
}
// IncrementModelUsage atomically adds tokens to a model's usage
// counter, creating the entry on first use.
func (m *MemoryStore) IncrementModelUsage(model string, tokens int64) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.modelUsage[model] = m.modelUsage[model] + tokens
	return nil
}
// SetModelQuota sets global limits for a model (used in testing).
// The store keeps its own copy of *q.
func (m *MemoryStore) SetModelQuota(q *ModelQuota) {
	m.mu.Lock()
	defer m.mu.Unlock()
	stored := *q
	m.modelQuotas[q.Model] = &stored
}
// startOfDay returns midnight UTC for the given time.
func startOfDay(t time.Time) time.Time {
y, mo, d := t.Date()
return time.Date(y, mo, d, 0, 0, 0, 0, time.UTC)
}

View file

@ -1,662 +0,0 @@
package lifecycle
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Allowance exhaustion edge cases ---
//
// These tests pin the boundary semantics of AllowanceService.Check:
// usage == limit denies (limits are exclusive upper bounds on usage),
// usage == limit-1 allows, and zero-valued limits mean unlimited.

func TestAllowanceExhaustion_ExactlyAtTokenLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "edge-agent",
		DailyTokenLimit: 10000,
	})
	// Use exactly the limit.
	_ = store.IncrementUsage("edge-agent", 10000, 0)
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed, "should be denied at exactly the limit")
	assert.Equal(t, AllowanceExceeded, result.Status)
	assert.Equal(t, int64(0), result.RemainingTokens)
	assert.Contains(t, result.Reason, "daily token limit exceeded")
}

func TestAllowanceExhaustion_OneOverTokenLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "edge-agent",
		DailyTokenLimit: 10000,
	})
	_ = store.IncrementUsage("edge-agent", 10001, 0)
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	assert.Equal(t, AllowanceExceeded, result.Status)
	assert.True(t, result.RemainingTokens < 0, "remaining should be negative")
}

func TestAllowanceExhaustion_OneUnderTokenLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "edge-agent",
		DailyTokenLimit: 10000,
	})
	_ = store.IncrementUsage("edge-agent", 9999, 0)
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed, "should be allowed with 1 token remaining")
	assert.Equal(t, AllowanceWarning, result.Status, "99.99% usage should be warning")
	assert.Equal(t, int64(1), result.RemainingTokens)
}

func TestAllowanceExhaustion_ZeroAllowance(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	// DailyTokenLimit=0 means unlimited.
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "unlimited-agent",
		DailyTokenLimit: 0,
		DailyJobLimit:   0,
		ConcurrentJobs:  0,
	})
	_ = store.IncrementUsage("unlimited-agent", 999999999, 999)
	result, err := svc.Check("unlimited-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed, "unlimited agent should always be allowed")
	assert.Equal(t, AllowanceOK, result.Status)
	assert.Equal(t, int64(-1), result.RemainingTokens, "unlimited should show -1")
	assert.Equal(t, -1, result.RemainingJobs, "unlimited should show -1")
}

func TestAllowanceExhaustion_ExactlyAtJobLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:       "edge-agent",
		DailyJobLimit: 5,
	})
	_ = store.IncrementUsage("edge-agent", 0, 5)
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed, "should be denied at exactly the job limit")
	assert.Equal(t, AllowanceExceeded, result.Status)
	assert.Equal(t, 0, result.RemainingJobs)
	assert.Contains(t, result.Reason, "daily job limit exceeded")
}

func TestAllowanceExhaustion_OneUnderJobLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:       "edge-agent",
		DailyJobLimit: 5,
	})
	_ = store.IncrementUsage("edge-agent", 0, 4)
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed, "should be allowed with 1 job remaining")
	assert.Equal(t, 1, result.RemainingJobs)
}

func TestAllowanceExhaustion_ConcurrentJobsExactlyAtLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "edge-agent",
		ConcurrentJobs: 2,
	})
	// Start 2 concurrent jobs.
	_ = store.IncrementUsage("edge-agent", 0, 2)
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed, "should be denied at concurrent limit")
	assert.Contains(t, result.Reason, "concurrent job limit reached")
}

func TestAllowanceExhaustion_ConcurrentJobsOneUnderLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "edge-agent",
		ConcurrentJobs: 3,
	})
	_ = store.IncrementUsage("edge-agent", 0, 2)
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed, "should be allowed with 1 concurrent slot remaining")
}

// Completion must free the concurrency slot taken at job start.
func TestAllowanceExhaustion_ConcurrentJobsFreedByCompletion(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "edge-agent",
		ConcurrentJobs: 1,
	})
	// Start a job - fills the slot.
	_ = svc.RecordUsage(UsageReport{
		AgentID: "edge-agent",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	result, err := svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed, "should be denied, 1 active job")
	// Complete the job - frees the slot.
	_ = svc.RecordUsage(UsageReport{
		AgentID:   "edge-agent",
		JobID:     "job-1",
		TokensIn:  100,
		TokensOut: 50,
		Event:     QuotaEventJobCompleted,
	})
	result, err = svc.Check("edge-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed, "should be allowed after job completes")
}
// Table test for the status thresholds: <80% usage is OK, [80%, 100%)
// is warning, and 100%+ is exceeded (and denied).
func TestAllowanceExhaustion_TokenWarningThreshold(t *testing.T) {
	tests := []struct {
		name           string
		limit          int64
		used           int64
		expectedStatus AllowanceStatus
		expectedAllow  bool
	}{
		{
			name:           "79% usage is OK",
			limit:          10000,
			used:           7900,
			expectedStatus: AllowanceOK,
			expectedAllow:  true,
		},
		{
			name:           "80% usage is warning",
			limit:          10000,
			used:           8000,
			expectedStatus: AllowanceWarning,
			expectedAllow:  true,
		},
		{
			name:           "90% usage is warning",
			limit:          10000,
			used:           9000,
			expectedStatus: AllowanceWarning,
			expectedAllow:  true,
		},
		{
			name:           "99% usage is warning",
			limit:          10000,
			used:           9999,
			expectedStatus: AllowanceWarning,
			expectedAllow:  true,
		},
		{
			name:           "100% usage is exceeded",
			limit:          10000,
			used:           10000,
			expectedStatus: AllowanceExceeded,
			expectedAllow:  false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Fresh store per case keeps the cases independent.
			store := NewMemoryStore()
			svc := NewAllowanceService(store)
			_ = store.SetAllowance(&AgentAllowance{
				AgentID:         "threshold-agent",
				DailyTokenLimit: tt.limit,
			})
			_ = store.IncrementUsage("threshold-agent", tt.used, 0)
			result, err := svc.Check("threshold-agent", "")
			require.NoError(t, err)
			assert.Equal(t, tt.expectedAllow, result.Allowed)
			assert.Equal(t, tt.expectedStatus, result.Status)
		})
	}
}
// Daily reset must restore both token and job capacity in full.
func TestAllowanceExhaustion_ResetRestoresCapacity(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "reset-agent",
		DailyTokenLimit: 10000,
		DailyJobLimit:   5,
	})
	// Exhaust all limits.
	_ = store.IncrementUsage("reset-agent", 10000, 5)
	result, err := svc.Check("reset-agent", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed, "should be denied when exhausted")
	// Reset the agent (simulates midnight reset).
	err = svc.ResetAgent("reset-agent")
	require.NoError(t, err)
	result, err = svc.Check("reset-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed, "should be allowed after reset")
	assert.Equal(t, int64(10000), result.RemainingTokens)
	assert.Equal(t, 5, result.RemainingJobs)
}

// Global model budget is exclusive, same as per-agent limits.
func TestAllowanceExhaustion_GlobalModelBudgetExactlyAtLimit(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID: "model-edge-agent",
	})
	store.SetModelQuota(&ModelQuota{
		Model:            "claude-opus-4-6",
		DailyTokenBudget: 50000,
	})
	_ = store.IncrementModelUsage("claude-opus-4-6", 50000)
	result, err := svc.Check("model-edge-agent", "claude-opus-4-6")
	require.NoError(t, err)
	assert.False(t, result.Allowed, "should be denied at exact model budget")
	assert.Contains(t, result.Reason, "global model token budget exceeded")
}

func TestAllowanceExhaustion_GlobalModelBudgetOneUnder(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID: "model-edge-agent",
	})
	store.SetModelQuota(&ModelQuota{
		Model:            "claude-opus-4-6",
		DailyTokenBudget: 50000,
	})
	_ = store.IncrementModelUsage("claude-opus-4-6", 49999)
	result, err := svc.Check("model-edge-agent", "claude-opus-4-6")
	require.NoError(t, err)
	assert.True(t, result.Allowed, "should be allowed with 1 token remaining in model budget")
}

// Failure returns 50% of tokens; 50% of zero is zero, and the
// concurrency slot is still released.
func TestAllowanceExhaustion_FailedJobWithZeroTokens(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = svc.RecordUsage(UsageReport{
		AgentID: "zero-token-agent",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	// Job fails but consumed zero tokens.
	err := svc.RecordUsage(UsageReport{
		AgentID:   "zero-token-agent",
		JobID:     "job-1",
		Model:     "claude-sonnet",
		TokensIn:  0,
		TokensOut: 0,
		Event:     QuotaEventJobFailed,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("zero-token-agent")
	assert.Equal(t, int64(0), usage.TokensUsed, "no tokens should be charged")
	assert.Equal(t, 0, usage.ActiveJobs)
	// Model usage should be zero too (50% of 0 = 0).
	modelUsage, _ := store.GetModelUsage("claude-sonnet")
	assert.Equal(t, int64(0), modelUsage)
}

func TestAllowanceExhaustion_CancelledJobWithZeroTokens(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = svc.RecordUsage(UsageReport{
		AgentID: "zero-token-agent",
		JobID:   "job-2",
		Event:   QuotaEventJobStarted,
	})
	// Job cancelled with zero tokens.
	err := svc.RecordUsage(UsageReport{
		AgentID:   "zero-token-agent",
		JobID:     "job-2",
		TokensIn:  0,
		TokensOut: 0,
		Event:     QuotaEventJobCancelled,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("zero-token-agent")
	assert.Equal(t, int64(0), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)
}

// An empty Model must skip model-level accounting but still charge the
// agent and release the slot.
func TestAllowanceExhaustion_CompletedJobWithNoModel(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = svc.RecordUsage(UsageReport{
		AgentID: "no-model-agent",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	// Complete with empty model -- should skip model-level usage recording.
	err := svc.RecordUsage(UsageReport{
		AgentID:   "no-model-agent",
		JobID:     "job-1",
		Model:     "",
		TokensIn:  500,
		TokensOut: 200,
		Event:     QuotaEventJobCompleted,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("no-model-agent")
	assert.Equal(t, int64(700), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)
}

func TestAllowanceExhaustion_FailedJobWithNoModel(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = svc.RecordUsage(UsageReport{
		AgentID: "no-model-fail-agent",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	// Fail with empty model.
	err := svc.RecordUsage(UsageReport{
		AgentID:   "no-model-fail-agent",
		JobID:     "job-1",
		Model:     "",
		TokensIn:  600,
		TokensOut: 400,
		Event:     QuotaEventJobFailed,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("no-model-fail-agent")
	// 1000 tokens used, 500 returned = 500 net.
	assert.Equal(t, int64(500), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)
}
// Walks one agent through OK -> warning -> exceeded as usage climbs,
// verifying RemainingTokens tracks each step.
func TestAllowanceExhaustion_MultipleChecksWithIncrementalUsage(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "incremental-agent",
		DailyTokenLimit: 1000,
	})
	// First check: fresh agent.
	result, err := svc.Check("incremental-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceOK, result.Status)
	assert.Equal(t, int64(1000), result.RemainingTokens)
	// Use 500 tokens.
	_ = store.IncrementUsage("incremental-agent", 500, 0)
	result, err = svc.Check("incremental-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceOK, result.Status)
	assert.Equal(t, int64(500), result.RemainingTokens)
	// Use another 300 tokens (total 800, at 80% threshold).
	_ = store.IncrementUsage("incremental-agent", 300, 0)
	result, err = svc.Check("incremental-agent", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceWarning, result.Status)
	assert.Equal(t, int64(200), result.RemainingTokens)
	// Use remaining 200 tokens (total 1000, at 100%).
	_ = store.IncrementUsage("incremental-agent", 200, 0)
	result, err = svc.Check("incremental-agent", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	assert.Equal(t, AllowanceExceeded, result.Status)
	assert.Equal(t, int64(0), result.RemainingTokens)
}
// --- MemoryStore additional edge cases ---
//
// These exercise the store directly (no service layer): defaulting for
// unknown agents, no-op behaviour on missing records, and defensive
// copying on both Set and Get.

func TestMemoryStore_GetUsage_NewAgentReturnsDefaults(t *testing.T) {
	store := NewMemoryStore()
	usage, err := store.GetUsage("brand-new-agent")
	require.NoError(t, err)
	assert.Equal(t, "brand-new-agent", usage.AgentID)
	assert.Equal(t, int64(0), usage.TokensUsed)
	assert.Equal(t, 0, usage.JobsStarted)
	assert.Equal(t, 0, usage.ActiveJobs)
	assert.Equal(t, startOfDay(time.Now().UTC()), usage.PeriodStart)
}

func TestMemoryStore_ReturnTokens_NonexistentAgent(t *testing.T) {
	store := NewMemoryStore()
	// ReturnTokens on a nonexistent agent should be a no-op.
	err := store.ReturnTokens("ghost-agent", 5000)
	require.NoError(t, err)
}

func TestMemoryStore_DecrementActiveJobs_NonexistentAgent(t *testing.T) {
	store := NewMemoryStore()
	// DecrementActiveJobs on a nonexistent agent should be a no-op.
	err := store.DecrementActiveJobs("ghost-agent")
	require.NoError(t, err)
}

func TestMemoryStore_GetModelQuota_NotFound(t *testing.T) {
	store := NewMemoryStore()
	_, err := store.GetModelQuota("nonexistent-model")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "model quota not found")
}

func TestMemoryStore_GetModelUsage_NewModelReturnsZero(t *testing.T) {
	store := NewMemoryStore()
	usage, err := store.GetModelUsage("brand-new-model")
	require.NoError(t, err)
	assert.Equal(t, int64(0), usage)
}

func TestMemoryStore_SetAllowance_Overwrite(t *testing.T) {
	store := NewMemoryStore()
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "overwrite-agent",
		DailyTokenLimit: 5000,
	})
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "overwrite-agent",
		DailyTokenLimit: 9000,
	})
	a, err := store.GetAllowance("overwrite-agent")
	require.NoError(t, err)
	assert.Equal(t, int64(9000), a.DailyTokenLimit, "should have overwritten the old allowance")
}

// Mutating the caller's struct after Set must not leak into the store.
func TestMemoryStore_SetAllowance_IsolatesOriginal(t *testing.T) {
	store := NewMemoryStore()
	original := &AgentAllowance{
		AgentID:         "isolated-agent",
		DailyTokenLimit: 5000,
	}
	_ = store.SetAllowance(original)
	// Mutate the original.
	original.DailyTokenLimit = 99999
	got, err := store.GetAllowance("isolated-agent")
	require.NoError(t, err)
	assert.Equal(t, int64(5000), got.DailyTokenLimit, "store should hold a copy, not the original")
}

// Mutating a value returned by Get must not leak back into the store.
func TestMemoryStore_GetAllowance_IsolatesReturn(t *testing.T) {
	store := NewMemoryStore()
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "isolated-agent",
		DailyTokenLimit: 5000,
	})
	got1, _ := store.GetAllowance("isolated-agent")
	got1.DailyTokenLimit = 99999
	got2, _ := store.GetAllowance("isolated-agent")
	assert.Equal(t, int64(5000), got2.DailyTokenLimit, "returned value should be a copy")
}

func TestMemoryStore_IncrementUsage_MultipleIncrements(t *testing.T) {
	store := NewMemoryStore()
	_ = store.IncrementUsage("multi-agent", 100, 1)
	_ = store.IncrementUsage("multi-agent", 200, 1)
	_ = store.IncrementUsage("multi-agent", 300, 0)
	usage, err := store.GetUsage("multi-agent")
	require.NoError(t, err)
	assert.Equal(t, int64(600), usage.TokensUsed)
	assert.Equal(t, 2, usage.JobsStarted)
	assert.Equal(t, 2, usage.ActiveJobs)
}

func TestMemoryStore_IncrementUsage_ZeroJobsDoesNotIncrementActive(t *testing.T) {
	store := NewMemoryStore()
	_ = store.IncrementUsage("token-only-agent", 5000, 0)
	usage, err := store.GetUsage("token-only-agent")
	require.NoError(t, err)
	assert.Equal(t, int64(5000), usage.TokensUsed)
	assert.Equal(t, 0, usage.JobsStarted)
	assert.Equal(t, 0, usage.ActiveJobs, "zero jobs should not increment active count")
}
// --- AllowanceService Check priority ordering ---

// When multiple denial reasons apply, the model allowlist check wins
// because it runs first in Check.
func TestAllowanceServiceCheck_ModelAllowlistCheckedFirst(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	// Agent is over token limit AND using a disallowed model.
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "order-agent",
		DailyTokenLimit: 1000,
		ModelAllowlist:  []string{"claude-haiku"},
	})
	_ = store.IncrementUsage("order-agent", 2000, 0)
	result, err := svc.Check("order-agent", "claude-opus-4-6")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	// Model allowlist is checked first in the code, so it should be the reason.
	assert.Contains(t, result.Reason, "model not in allowlist")
}

func TestAllowanceServiceCheck_EmptyModelAllowlistPermitsAll(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "any-model-agent",
		ModelAllowlist: nil,
	})
	result, err := svc.Check("any-model-agent", "any-model-at-all")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
}

// --- QuotaEvent constants ---

// Pins the wire values of the event constants (they are serialized).
func TestQuotaEvent_Values(t *testing.T) {
	assert.Equal(t, QuotaEvent("job_started"), QuotaEventJobStarted)
	assert.Equal(t, QuotaEvent("job_completed"), QuotaEventJobCompleted)
	assert.Equal(t, QuotaEvent("job_failed"), QuotaEventJobFailed)
	assert.Equal(t, QuotaEvent("job_cancelled"), QuotaEventJobCancelled)
}

// The 50% refund on failure uses integer division; odd totals round down.
func TestAllowanceExhaustion_FailedJobWithOddTokenCount(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = svc.RecordUsage(UsageReport{
		AgentID: "odd-agent",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	// Odd total: 7 tokens. 50% return = 3 (integer division).
	err := svc.RecordUsage(UsageReport{
		AgentID:   "odd-agent",
		JobID:     "job-1",
		Model:     "claude-sonnet",
		TokensIn:  4,
		TokensOut: 3,
		Event:     QuotaEventJobFailed,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("odd-agent")
	// 7 charged - 3 returned = 4 net.
	assert.Equal(t, int64(4), usage.TokensUsed)
	// Model gets 7 - 3 = 4.
	modelUsage, _ := store.GetModelUsage("claude-sonnet")
	assert.Equal(t, int64(4), modelUsage)
}

View file

@ -1,272 +0,0 @@
package lifecycle
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// errorStore is a mock AllowanceStore that returns errors for specific operations.
// It embeds MemoryStore and, per fail* flag, short-circuits the matching
// method with a sentinel error; all other operations delegate to the
// embedded store unchanged.
type errorStore struct {
	*MemoryStore
	failIncrementUsage  bool
	failDecrementActive bool
	failReturnTokens    bool
	failIncrementModel  bool
	failGetAllowance    bool
	failGetUsage        bool
}

// newErrorStore creates an errorStore with no failures armed.
func newErrorStore() *errorStore {
	return &errorStore{MemoryStore: NewMemoryStore()}
}

func (e *errorStore) GetAllowance(agentID string) (*AgentAllowance, error) {
	if e.failGetAllowance {
		return nil, errors.New("simulated GetAllowance error")
	}
	return e.MemoryStore.GetAllowance(agentID)
}

func (e *errorStore) GetUsage(agentID string) (*UsageRecord, error) {
	if e.failGetUsage {
		return nil, errors.New("simulated GetUsage error")
	}
	return e.MemoryStore.GetUsage(agentID)
}

func (e *errorStore) IncrementUsage(agentID string, tokens int64, jobs int) error {
	if e.failIncrementUsage {
		return errors.New("simulated IncrementUsage error")
	}
	return e.MemoryStore.IncrementUsage(agentID, tokens, jobs)
}

func (e *errorStore) DecrementActiveJobs(agentID string) error {
	if e.failDecrementActive {
		return errors.New("simulated DecrementActiveJobs error")
	}
	return e.MemoryStore.DecrementActiveJobs(agentID)
}

func (e *errorStore) ReturnTokens(agentID string, tokens int64) error {
	if e.failReturnTokens {
		return errors.New("simulated ReturnTokens error")
	}
	return e.MemoryStore.ReturnTokens(agentID, tokens)
}

func (e *errorStore) IncrementModelUsage(model string, tokens int64) error {
	if e.failIncrementModel {
		return errors.New("simulated IncrementModelUsage error")
	}
	return e.MemoryStore.IncrementModelUsage(model, tokens)
}
// --- RecordUsage error path tests ---
//
// Each test arms exactly one store failure and verifies RecordUsage
// wraps it with the expected context message for that event's step.

func TestRecordUsage_Bad_JobStarted_IncrementFails(t *testing.T) {
	store := newErrorStore()
	store.failIncrementUsage = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID: "agent-1",
		Event:   QuotaEventJobStarted,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to increment job count")
}

func TestRecordUsage_Bad_JobCompleted_IncrementFails(t *testing.T) {
	store := newErrorStore()
	store.failIncrementUsage = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		TokensIn:  100,
		TokensOut: 50,
		Event:     QuotaEventJobCompleted,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to record token usage")
}

func TestRecordUsage_Bad_JobCompleted_DecrementFails(t *testing.T) {
	store := newErrorStore()
	store.failDecrementActive = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		TokensIn:  100,
		TokensOut: 50,
		Event:     QuotaEventJobCompleted,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to decrement active jobs")
}

func TestRecordUsage_Bad_JobCompleted_ModelUsageFails(t *testing.T) {
	store := newErrorStore()
	store.failIncrementModel = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		Model:     "claude-sonnet",
		TokensIn:  100,
		TokensOut: 50,
		Event:     QuotaEventJobCompleted,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to record model usage")
}

func TestRecordUsage_Bad_JobFailed_IncrementFails(t *testing.T) {
	store := newErrorStore()
	store.failIncrementUsage = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		TokensIn:  100,
		TokensOut: 100,
		Event:     QuotaEventJobFailed,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to record token usage")
}

func TestRecordUsage_Bad_JobFailed_DecrementFails(t *testing.T) {
	store := newErrorStore()
	store.failDecrementActive = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		TokensIn:  100,
		TokensOut: 100,
		Event:     QuotaEventJobFailed,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to decrement active jobs")
}

func TestRecordUsage_Bad_JobFailed_ReturnTokensFails(t *testing.T) {
	store := newErrorStore()
	store.failReturnTokens = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		TokensIn:  100,
		TokensOut: 100,
		Event:     QuotaEventJobFailed,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to return tokens")
}

func TestRecordUsage_Bad_JobFailed_ModelUsageFails(t *testing.T) {
	store := newErrorStore()
	store.failIncrementModel = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		Model:     "claude-sonnet",
		TokensIn:  100,
		TokensOut: 100,
		Event:     QuotaEventJobFailed,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to record model usage")
}

func TestRecordUsage_Bad_JobCancelled_DecrementFails(t *testing.T) {
	store := newErrorStore()
	store.failDecrementActive = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		TokensIn:  100,
		TokensOut: 100,
		Event:     QuotaEventJobCancelled,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to decrement active jobs")
}

func TestRecordUsage_Bad_JobCancelled_ReturnTokensFails(t *testing.T) {
	store := newErrorStore()
	store.failReturnTokens = true
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		TokensIn:  100,
		TokensOut: 100,
		Event:     QuotaEventJobCancelled,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to return tokens")
}
// --- Check error path tests ---

// TestCheck_Bad_GetAllowanceFails verifies Check surfaces allowance-lookup errors.
func TestCheck_Bad_GetAllowanceFails(t *testing.T) {
	failing := newErrorStore()
	failing.failGetAllowance = true
	_, err := NewAllowanceService(failing).Check("agent-1", "")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to get allowance")
}

// TestCheck_Bad_GetUsageFails verifies Check surfaces usage-lookup errors
// once the allowance itself resolves.
func TestCheck_Bad_GetUsageFails(t *testing.T) {
	failing := newErrorStore()
	_ = failing.SetAllowance(&AgentAllowance{
		AgentID: "agent-1",
	})
	failing.failGetUsage = true
	_, err := NewAllowanceService(failing).Check("agent-1", "")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to get usage")
}
// --- ResetAgent error path ---

// TestResetAgent_Bad_ResetFails documents that MemoryStore.ResetUsage cannot
// fail, so the service's error branch is unreachable with this backend; the
// test pins the happy-path contract instead.
func TestResetAgent_Bad_ResetFails(t *testing.T) {
	// MemoryStore.ResetUsage never fails, but we can test the service
	// layer still returns nil for the happy path (already tested).
	// For a true error test, we'd need a mock, but the MemoryStore
	// never errors on ResetUsage. This confirms the pattern.
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	err := svc.ResetAgent("nonexistent-agent")
	require.NoError(t, err, "resetting a nonexistent agent should succeed")
}

// --- RecordUsage with unknown event type ---

// TestRecordUsage_Good_UnknownEvent confirms an unrecognised quota event is
// silently ignored rather than treated as an error.
func TestRecordUsage_Good_UnknownEvent(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	// Unknown event should be a no-op (falls through the switch).
	err := svc.RecordUsage(UsageReport{
		AgentID: "agent-1",
		Event:   QuotaEvent("unknown_event"),
	})
	require.NoError(t, err, "unknown event should not error")
}

View file

@ -1,409 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"errors"
"iter"
"time"
"github.com/redis/go-redis/v9"
)
// RedisStore implements AllowanceStore using Redis as the backing store.
// It provides persistent, network-accessible storage suitable for multi-node
// deployments where agents share quota state.
type RedisStore struct {
	client *redis.Client // shared connection; go-redis clients are safe for concurrent use
	prefix string        // key namespace, e.g. "agentic" -> "agentic:usage:<agentID>"
}
// Allowances returns an iterator over all agent allowances.
//
// Keys are discovered with SCAN (cursor-based, non-blocking) under
// "<prefix>:allowance:*". Entries that vanish between SCAN and GET, or that
// fail to decode, are skipped rather than aborting the iteration.
func (r *RedisStore) Allowances() iter.Seq[*AgentAllowance] {
	return func(yield func(*AgentAllowance) bool) {
		ctx := context.Background()
		pattern := r.prefix + ":allowance:*"
		// Named "scan" so the local does not shadow the imported iter
		// package used in this function's signature.
		scan := r.client.Scan(ctx, 0, pattern, 100).Iterator()
		for scan.Next(ctx) {
			val, err := r.client.Get(ctx, scan.Val()).Result()
			if err != nil {
				continue // key expired/deleted mid-scan
			}
			var aj allowanceJSON
			if err := json.Unmarshal([]byte(val), &aj); err != nil {
				continue // skip corrupt entries
			}
			if !yield(aj.toAgentAllowance()) {
				return
			}
		}
	}
}
// Usages returns an iterator over all usage records.
//
// Same contract as Allowances: SCAN-based discovery under
// "<prefix>:usage:*"; unreadable or undecodable entries are skipped.
func (r *RedisStore) Usages() iter.Seq[*UsageRecord] {
	return func(yield func(*UsageRecord) bool) {
		ctx := context.Background()
		pattern := r.prefix + ":usage:*"
		// Named "scan" so the local does not shadow the imported iter
		// package used in this function's signature.
		scan := r.client.Scan(ctx, 0, pattern, 100).Iterator()
		for scan.Next(ctx) {
			val, err := r.client.Get(ctx, scan.Val()).Result()
			if err != nil {
				continue // key expired/deleted mid-scan
			}
			var u UsageRecord
			if err := json.Unmarshal([]byte(val), &u); err != nil {
				continue // skip corrupt entries
			}
			if !yield(&u) {
				return
			}
		}
	}
}
// redisConfig holds the configuration for a RedisStore.
type redisConfig struct {
	password string // optional AUTH password; empty means no authentication
	db       int    // logical Redis database number; zero value selects DB 0
	prefix   string // key namespace prefix; NewRedisStore defaults it to "agentic"
}

// RedisOption is a functional option for configuring a RedisStore.
type RedisOption func(*redisConfig)

// WithRedisPassword sets the password for authenticating with Redis.
func WithRedisPassword(pw string) RedisOption {
	return func(c *redisConfig) {
		c.password = pw
	}
}

// WithRedisDB selects the Redis database number.
func WithRedisDB(db int) RedisOption {
	return func(c *redisConfig) {
		c.db = db
	}
}

// WithRedisPrefix sets the key prefix for all Redis keys. Default: "agentic".
func WithRedisPrefix(prefix string) RedisOption {
	return func(c *redisConfig) {
		c.prefix = prefix
	}
}
// NewRedisStore creates a new Redis-backed allowance store connecting to the
// given address (host:port). It pings the server to verify connectivity.
func NewRedisStore(addr string, opts ...RedisOption) (*RedisStore, error) {
	cfg := redisConfig{prefix: "agentic"}
	for _, apply := range opts {
		apply(&cfg)
	}

	client := redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: cfg.password,
		DB:       cfg.db,
	})

	// Fail fast on an unreachable server instead of surfacing errors on
	// first use; bound the ping so construction cannot hang.
	pingCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := client.Ping(pingCtx).Err(); err != nil {
		_ = client.Close()
		return nil, &APIError{Code: 500, Message: "failed to connect to Redis: " + err.Error()}
	}

	return &RedisStore{client: client, prefix: cfg.prefix}, nil
}
// Close releases the underlying Redis connection.
func (r *RedisStore) Close() error {
	return r.client.Close()
}

// --- key helpers ---

// allowanceKey returns the Redis key holding an agent's quota limits.
func (r *RedisStore) allowanceKey(agentID string) string {
	return r.prefix + ":allowance:" + agentID
}

// usageKey returns the Redis key holding an agent's usage counters.
func (r *RedisStore) usageKey(agentID string) string {
	return r.prefix + ":usage:" + agentID
}

// modelQuotaKey returns the Redis key holding a model's global quota.
func (r *RedisStore) modelQuotaKey(model string) string {
	return r.prefix + ":model_quota:" + model
}

// modelUsageKey returns the Redis key holding a model's global usage counter.
func (r *RedisStore) modelUsageKey(model string) string {
	return r.prefix + ":model_usage:" + model
}
// --- AllowanceStore interface ---

// GetAllowance returns the quota limits for an agent.
// A missing key maps to a 404 APIError; all other failures map to 500.
func (r *RedisStore) GetAllowance(agentID string) (*AgentAllowance, error) {
	raw, err := r.client.Get(context.Background(), r.allowanceKey(agentID)).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID}
	case err != nil:
		return nil, &APIError{Code: 500, Message: "failed to get allowance: " + err.Error()}
	}
	var decoded allowanceJSON
	if err := json.Unmarshal([]byte(raw), &decoded); err != nil {
		return nil, &APIError{Code: 500, Message: "failed to unmarshal allowance: " + err.Error()}
	}
	return decoded.toAgentAllowance(), nil
}

// SetAllowance persists quota limits for an agent. The record is stored as
// JSON (via allowanceJSON) with no expiry.
func (r *RedisStore) SetAllowance(a *AgentAllowance) error {
	payload, err := json.Marshal(newAllowanceJSON(a))
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal allowance: " + err.Error()}
	}
	if err := r.client.Set(context.Background(), r.allowanceKey(a.AgentID), payload, 0).Err(); err != nil {
		return &APIError{Code: 500, Message: "failed to set allowance: " + err.Error()}
	}
	return nil
}
// GetUsage returns the current usage record for an agent.
//
// A missing key is not an error: it yields a zero-usage record whose period
// starts at the current UTC midnight, so callers never special-case agents
// that have not reported yet.
func (r *RedisStore) GetUsage(agentID string) (*UsageRecord, error) {
	ctx := context.Background()
	val, err := r.client.Get(ctx, r.usageKey(agentID)).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			// No record yet: synthesise an empty one for today.
			return &UsageRecord{
				AgentID:     agentID,
				PeriodStart: startOfDay(time.Now().UTC()),
			}, nil
		}
		return nil, &APIError{Code: 500, Message: "failed to get usage: " + err.Error()}
	}
	var u UsageRecord
	if err := json.Unmarshal([]byte(val), &u); err != nil {
		return nil, &APIError{Code: 500, Message: "failed to unmarshal usage: " + err.Error()}
	}
	return &u, nil
}
// incrementUsageLua atomically reads, increments, and writes back a usage record.
// KEYS[1] = usage key
// ARGV[1] = tokens to add (int64)
// ARGV[2] = jobs to add (int)
// ARGV[3] = agent ID (for creating a new record)
// ARGV[4] = period start ISO string (for creating a new record)
//
// Running the read-modify-write inside Redis (EVAL) makes the update atomic
// under concurrent callers, which a Go-side GET/SET pair would not be.
// Creates the record on first use; only a positive jobs delta bumps the
// active_jobs gauge.
var incrementUsageLua = redis.NewScript(`
local val = redis.call('GET', KEYS[1])
local u
if val then
u = cjson.decode(val)
else
u = {
agent_id = ARGV[3],
tokens_used = 0,
jobs_started = 0,
active_jobs = 0,
period_start = ARGV[4]
}
end
u.tokens_used = u.tokens_used + tonumber(ARGV[1])
u.jobs_started = u.jobs_started + tonumber(ARGV[2])
if tonumber(ARGV[2]) > 0 then
u.active_jobs = u.active_jobs + tonumber(ARGV[2])
end
redis.call('SET', KEYS[1], cjson.encode(u))
return 'OK'
`)

// IncrementUsage atomically adds to an agent's usage counters.
// jobs > 0 also raises the active-jobs gauge inside the script; a token-only
// update (jobs == 0) leaves it untouched.
func (r *RedisStore) IncrementUsage(agentID string, tokens int64, jobs int) error {
	ctx := context.Background()
	periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339)
	err := incrementUsageLua.Run(ctx, r.client,
		[]string{r.usageKey(agentID)},
		tokens, jobs, agentID, periodStart,
	).Err()
	if err != nil {
		return &APIError{Code: 500, Message: "failed to increment usage: " + err.Error()}
	}
	return nil
}
// decrementActiveJobsLua atomically decrements the active jobs counter, flooring at zero.
// KEYS[1] = usage key
// ARGV[1] = agent ID
// ARGV[2] = period start ISO string
//
// Note: the script body never reads its ARGV — they are passed for symmetry
// with incrementUsageLua. A missing record is left missing (no-op) rather
// than created.
var decrementActiveJobsLua = redis.NewScript(`
local val = redis.call('GET', KEYS[1])
if not val then
return 'OK'
end
local u = cjson.decode(val)
if u.active_jobs and u.active_jobs > 0 then
u.active_jobs = u.active_jobs - 1
end
redis.call('SET', KEYS[1], cjson.encode(u))
return 'OK'
`)

// DecrementActiveJobs reduces the active job count by 1.
// Decrementing an agent with no usage record, or one already at zero, is a
// silent no-op (enforced by the Lua script).
func (r *RedisStore) DecrementActiveJobs(agentID string) error {
	ctx := context.Background()
	periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339)
	err := decrementActiveJobsLua.Run(ctx, r.client,
		[]string{r.usageKey(agentID)},
		agentID, periodStart,
	).Err()
	if err != nil {
		return &APIError{Code: 500, Message: "failed to decrement active jobs: " + err.Error()}
	}
	return nil
}
// returnTokensLua atomically subtracts tokens from usage, flooring at zero.
// KEYS[1] = usage key
// ARGV[1] = tokens to return (int64)
// ARGV[2] = agent ID
// ARGV[3] = period start ISO string
//
// Note: only ARGV[1] is read by the script; ARGV[2]/ARGV[3] are passed for
// symmetry with incrementUsageLua. A missing record is left untouched.
var returnTokensLua = redis.NewScript(`
local val = redis.call('GET', KEYS[1])
if not val then
return 'OK'
end
local u = cjson.decode(val)
u.tokens_used = u.tokens_used - tonumber(ARGV[1])
if u.tokens_used < 0 then
u.tokens_used = 0
end
redis.call('SET', KEYS[1], cjson.encode(u))
return 'OK'
`)

// ReturnTokens adds tokens back to the agent's remaining quota by reducing
// tokens_used (never below zero). Used for quota recovery on failed or
// cancelled jobs.
func (r *RedisStore) ReturnTokens(agentID string, tokens int64) error {
	ctx := context.Background()
	periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339)
	err := returnTokensLua.Run(ctx, r.client,
		[]string{r.usageKey(agentID)},
		tokens, agentID, periodStart,
	).Err()
	if err != nil {
		return &APIError{Code: 500, Message: "failed to return tokens: " + err.Error()}
	}
	return nil
}
// ResetUsage clears usage counters for an agent (daily reset).
// The record is overwritten with a zeroed UsageRecord whose period starts at
// the current UTC midnight.
func (r *RedisStore) ResetUsage(agentID string) error {
	fresh := UsageRecord{
		AgentID:     agentID,
		PeriodStart: startOfDay(time.Now().UTC()),
	}
	payload, err := json.Marshal(&fresh)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal usage: " + err.Error()}
	}
	if err := r.client.Set(context.Background(), r.usageKey(agentID), payload, 0).Err(); err != nil {
		return &APIError{Code: 500, Message: "failed to reset usage: " + err.Error()}
	}
	return nil
}
// GetModelQuota returns global limits for a model.
// A missing key maps to a 404 APIError; all other failures map to 500.
func (r *RedisStore) GetModelQuota(model string) (*ModelQuota, error) {
	raw, err := r.client.Get(context.Background(), r.modelQuotaKey(model)).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return nil, &APIError{Code: 404, Message: "model quota not found: " + model}
	case err != nil:
		return nil, &APIError{Code: 500, Message: "failed to get model quota: " + err.Error()}
	}
	quota := new(ModelQuota)
	if err := json.Unmarshal([]byte(raw), quota); err != nil {
		return nil, &APIError{Code: 500, Message: "failed to unmarshal model quota: " + err.Error()}
	}
	return quota, nil
}
// GetModelUsage returns current token usage for a model.
//
// The counter is written by incrementModelUsageLua as a plain decimal string,
// which json.Unmarshal parses directly into an int64. A missing key reads as
// zero usage rather than an error.
func (r *RedisStore) GetModelUsage(model string) (int64, error) {
	ctx := context.Background()
	val, err := r.client.Get(ctx, r.modelUsageKey(model)).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return 0, nil
		}
		return 0, &APIError{Code: 500, Message: "failed to get model usage: " + err.Error()}
	}
	var tokens int64
	if err := json.Unmarshal([]byte(val), &tokens); err != nil {
		return 0, &APIError{Code: 500, Message: "failed to unmarshal model usage: " + err.Error()}
	}
	return tokens, nil
}
// incrementModelUsageLua atomically increments the model usage counter.
// KEYS[1] = model usage key
// ARGV[1] = tokens to add
//
// The counter is stored as a plain decimal string (tostring), which
// GetModelUsage parses back into an int64. EVAL makes the read-add-write
// atomic under concurrent writers.
var incrementModelUsageLua = redis.NewScript(`
local val = redis.call('GET', KEYS[1])
local current = 0
if val then
current = tonumber(val)
end
current = current + tonumber(ARGV[1])
redis.call('SET', KEYS[1], tostring(current))
return current
`)

// IncrementModelUsage atomically adds to a model's usage counter.
func (r *RedisStore) IncrementModelUsage(model string, tokens int64) error {
	ctx := context.Background()
	err := incrementModelUsageLua.Run(ctx, r.client,
		[]string{r.modelUsageKey(model)},
		tokens,
	).Err()
	if err != nil {
		return &APIError{Code: 500, Message: "failed to increment model usage: " + err.Error()}
	}
	return nil
}
// SetModelQuota persists global limits for a model, stored as JSON with no
// expiry under the model-quota key.
func (r *RedisStore) SetModelQuota(q *ModelQuota) error {
	payload, err := json.Marshal(q)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal model quota: " + err.Error()}
	}
	if err := r.client.Set(context.Background(), r.modelQuotaKey(q.Model), payload, 0).Err(); err != nil {
		return &APIError{Code: 500, Message: "failed to set model quota: " + err.Error()}
	}
	return nil
}
// FlushPrefix deletes all keys matching the store's prefix. Useful for testing cleanup.
//
// Keys are removed one at a time while scanning; the first DEL error aborts
// the flush.
func (r *RedisStore) FlushPrefix(ctx context.Context) error {
	// Named "scan" so the local does not shadow the imported iter package
	// used elsewhere in this file.
	scan := r.client.Scan(ctx, 0, r.prefix+":*", 100).Iterator()
	for scan.Next(ctx) {
		if err := r.client.Del(ctx, scan.Val()).Err(); err != nil {
			return err
		}
	}
	return scan.Err()
}

View file

@ -1,454 +0,0 @@
package lifecycle
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testRedisAddr is the Redis instance these integration tests run against.
const testRedisAddr = "10.69.69.87:6379"

// newTestRedisStore creates a RedisStore with a unique prefix for test isolation.
// Skips the test if Redis is unreachable.
func newTestRedisStore(t *testing.T) *RedisStore {
	t.Helper()
	// Nanosecond timestamp keeps parallel/repeated runs from colliding on keys.
	prefix := fmt.Sprintf("test_%d", time.Now().UnixNano())
	s, err := NewRedisStore(testRedisAddr, WithRedisPrefix(prefix))
	if err != nil {
		t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err)
	}
	// Remove every key written under the test prefix, then close the client.
	t.Cleanup(func() {
		ctx := context.Background()
		_ = s.FlushPrefix(ctx)
		_ = s.Close()
	})
	return s
}
// --- SetAllowance / GetAllowance ---

// TestRedisStore_SetGetAllowance_Good round-trips every allowance field.
func TestRedisStore_SetGetAllowance_Good(t *testing.T) {
	s := newTestRedisStore(t)
	a := &AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		DailyJobLimit:   10,
		ConcurrentJobs:  2,
		MaxJobDuration:  30 * time.Minute,
		ModelAllowlist:  []string{"claude-sonnet-4-5-20250929"},
	}
	err := s.SetAllowance(a)
	require.NoError(t, err)
	got, err := s.GetAllowance("agent-1")
	require.NoError(t, err)
	assert.Equal(t, a.AgentID, got.AgentID)
	assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit)
	assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit)
	assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs)
	assert.Equal(t, a.MaxJobDuration, got.MaxJobDuration)
	assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist)
}

// TestRedisStore_GetAllowance_Bad_NotFound expects a typed 404 *APIError.
func TestRedisStore_GetAllowance_Bad_NotFound(t *testing.T) {
	s := newTestRedisStore(t)
	_, err := s.GetAllowance("nonexistent")
	require.Error(t, err)
	apiErr, ok := err.(*APIError)
	require.True(t, ok, "expected *APIError")
	assert.Equal(t, 404, apiErr.Code)
	assert.Contains(t, err.Error(), "allowance not found")
}

// TestRedisStore_SetAllowance_Good_Overwrite verifies last-write-wins semantics.
func TestRedisStore_SetAllowance_Good_Overwrite(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100})
	_ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 200})
	got, err := s.GetAllowance("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(200), got.DailyTokenLimit)
}

// --- GetUsage / IncrementUsage ---

// TestRedisStore_GetUsage_Good_Default: a never-seen agent reads as zero usage.
func TestRedisStore_GetUsage_Good_Default(t *testing.T) {
	s := newTestRedisStore(t)
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "agent-1", u.AgentID)
	assert.Equal(t, int64(0), u.TokensUsed)
	assert.Equal(t, 0, u.JobsStarted)
	assert.Equal(t, 0, u.ActiveJobs)
}

func TestRedisStore_IncrementUsage_Good(t *testing.T) {
	s := newTestRedisStore(t)
	err := s.IncrementUsage("agent-1", 5000, 1)
	require.NoError(t, err)
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(5000), u.TokensUsed)
	assert.Equal(t, 1, u.JobsStarted)
	assert.Equal(t, 1, u.ActiveJobs)
}

func TestRedisStore_IncrementUsage_Good_Accumulates(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.IncrementUsage("agent-1", 1000, 1)
	_ = s.IncrementUsage("agent-1", 2000, 1)
	_ = s.IncrementUsage("agent-1", 3000, 0)
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(6000), u.TokensUsed)
	assert.Equal(t, 2, u.JobsStarted)
	// The jobs==0 increment must not raise the active-jobs gauge.
	assert.Equal(t, 2, u.ActiveJobs)
}

// --- DecrementActiveJobs ---

func TestRedisStore_DecrementActiveJobs_Good(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.IncrementUsage("agent-1", 0, 2)
	_ = s.DecrementActiveJobs("agent-1")
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, 1, u.ActiveJobs)
}

// Decrementing with no record, or at zero, must clamp rather than go negative.
func TestRedisStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.DecrementActiveJobs("agent-1") // no usage record yet
	_ = s.IncrementUsage("agent-1", 0, 0)
	_ = s.DecrementActiveJobs("agent-1") // should stay at 0
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, 0, u.ActiveJobs)
}
// --- ReturnTokens ---

func TestRedisStore_ReturnTokens_Good(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.IncrementUsage("agent-1", 10000, 0)
	err := s.ReturnTokens("agent-1", 5000)
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(5000), u.TokensUsed)
}

// Returning more tokens than were used must clamp tokens_used at zero.
func TestRedisStore_ReturnTokens_Good_FloorAtZero(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.IncrementUsage("agent-1", 1000, 0)
	_ = s.ReturnTokens("agent-1", 5000) // more than used
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(0), u.TokensUsed)
}

func TestRedisStore_ReturnTokens_Good_NoRecord(t *testing.T) {
	s := newTestRedisStore(t)
	// Return tokens for agent with no usage record -- should be a no-op
	err := s.ReturnTokens("agent-1", 500)
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(0), u.TokensUsed)
}

// --- ResetUsage ---

// ResetUsage must zero all three counters in one call.
func TestRedisStore_ResetUsage_Good(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.IncrementUsage("agent-1", 50000, 5)
	err := s.ResetUsage("agent-1")
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(0), u.TokensUsed)
	assert.Equal(t, 0, u.JobsStarted)
	assert.Equal(t, 0, u.ActiveJobs)
}

// --- ModelQuota ---

// A missing model quota surfaces as a typed 404 *APIError.
func TestRedisStore_GetModelQuota_Bad_NotFound(t *testing.T) {
	s := newTestRedisStore(t)
	_, err := s.GetModelQuota("nonexistent")
	require.Error(t, err)
	apiErr, ok := err.(*APIError)
	require.True(t, ok, "expected *APIError")
	assert.Equal(t, 404, apiErr.Code)
	assert.Contains(t, err.Error(), "model quota not found")
}

// TestRedisStore_SetGetModelQuota_Good round-trips every quota field.
func TestRedisStore_SetGetModelQuota_Good(t *testing.T) {
	s := newTestRedisStore(t)
	q := &ModelQuota{
		Model:            "claude-opus-4-6",
		DailyTokenBudget: 500000,
		HourlyRateLimit:  100,
		CostCeiling:      10000,
	}
	err := s.SetModelQuota(q)
	require.NoError(t, err)
	got, err := s.GetModelQuota("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, q.Model, got.Model)
	assert.Equal(t, q.DailyTokenBudget, got.DailyTokenBudget)
	assert.Equal(t, q.HourlyRateLimit, got.HourlyRateLimit)
	assert.Equal(t, q.CostCeiling, got.CostCeiling)
}

// --- ModelUsage ---

// Successive increments must accumulate into one counter.
func TestRedisStore_ModelUsage_Good(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.IncrementModelUsage("claude-sonnet", 10000)
	_ = s.IncrementModelUsage("claude-sonnet", 5000)
	usage, err := s.GetModelUsage("claude-sonnet")
	require.NoError(t, err)
	assert.Equal(t, int64(15000), usage)
}

// An unknown model reads as zero usage, not an error.
func TestRedisStore_GetModelUsage_Good_Default(t *testing.T) {
	s := newTestRedisStore(t)
	usage, err := s.GetModelUsage("unknown-model")
	require.NoError(t, err)
	assert.Equal(t, int64(0), usage)
}
// --- Persistence: set, get, verify ---

// TestRedisStore_Persistence_Good writes one record of every type and reads
// them all back through the same store.
func TestRedisStore_Persistence_Good(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		MaxJobDuration:  15 * time.Minute,
	})
	_ = s.IncrementUsage("agent-1", 25000, 3)
	_ = s.SetModelQuota(&ModelQuota{Model: "claude-opus-4-6", DailyTokenBudget: 500000})
	_ = s.IncrementModelUsage("claude-opus-4-6", 42000)
	// Verify all data persists (same connection, but data is in Redis)
	a, err := s.GetAllowance("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(100000), a.DailyTokenLimit)
	assert.Equal(t, 15*time.Minute, a.MaxJobDuration)
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(25000), u.TokensUsed)
	assert.Equal(t, 3, u.JobsStarted)
	assert.Equal(t, 3, u.ActiveJobs)
	q, err := s.GetModelQuota("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, int64(500000), q.DailyTokenBudget)
	mu, err := s.GetModelUsage("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, int64(42000), mu)
}

// --- Concurrent access ---

// TestRedisStore_ConcurrentIncrementUsage_Good exercises the Lua-script
// atomicity: every one of the parallel increments must be observed.
func TestRedisStore_ConcurrentIncrementUsage_Good(t *testing.T) {
	s := newTestRedisStore(t)
	const goroutines = 10
	const tokensEach = 1000
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for range goroutines {
		go func() {
			defer wg.Done()
			err := s.IncrementUsage("agent-1", tokensEach, 1)
			assert.NoError(t, err)
		}()
	}
	wg.Wait()
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(goroutines*tokensEach), u.TokensUsed)
	assert.Equal(t, goroutines, u.JobsStarted)
	assert.Equal(t, goroutines, u.ActiveJobs)
}

// Same atomicity check for the per-model counter.
func TestRedisStore_ConcurrentModelUsage_Good(t *testing.T) {
	s := newTestRedisStore(t)
	const goroutines = 10
	const tokensEach int64 = 500
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for range goroutines {
		go func() {
			defer wg.Done()
			err := s.IncrementModelUsage("claude-opus-4-6", tokensEach)
			assert.NoError(t, err)
		}()
	}
	wg.Wait()
	usage, err := s.GetModelUsage("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, goroutines*tokensEach, usage)
}

// TestRedisStore_ConcurrentMixed_Good is a smoke test for interleaved
// increment/decrement/return; it asserts consistency invariants, not exact
// values, since the interleaving is nondeterministic.
func TestRedisStore_ConcurrentMixed_Good(t *testing.T) {
	s := newTestRedisStore(t)
	_ = s.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 1000000,
		DailyJobLimit:   100,
		ConcurrentJobs:  50,
	})
	const goroutines = 10
	var wg sync.WaitGroup
	wg.Add(goroutines * 3)
	// Increment usage
	for range goroutines {
		go func() {
			defer wg.Done()
			_ = s.IncrementUsage("agent-1", 100, 1)
		}()
	}
	// Decrement active jobs
	for range goroutines {
		go func() {
			defer wg.Done()
			_ = s.DecrementActiveJobs("agent-1")
		}()
	}
	// Return tokens
	for range goroutines {
		go func() {
			defer wg.Done()
			_ = s.ReturnTokens("agent-1", 10)
		}()
	}
	wg.Wait()
	// Verify no panics and data is consistent
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.GreaterOrEqual(t, u.TokensUsed, int64(0))
	assert.GreaterOrEqual(t, u.ActiveJobs, 0)
}
// --- AllowanceService integration via RedisStore ---

// A fresh agent with generous limits must pass a pre-dispatch check.
func TestRedisStore_AllowanceServiceCheck_Good(t *testing.T) {
	s := newTestRedisStore(t)
	svc := NewAllowanceService(s)
	_ = s.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		DailyJobLimit:   10,
		ConcurrentJobs:  2,
	})
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceOK, result.Status)
}

// TestRedisStore_AllowanceServiceRecordUsage_Good walks a start/complete job
// lifecycle and checks the final counters against the Redis backend.
func TestRedisStore_AllowanceServiceRecordUsage_Good(t *testing.T) {
	s := newTestRedisStore(t)
	svc := NewAllowanceService(s)
	_ = s.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
	})
	// Start job
	err := svc.RecordUsage(UsageReport{
		AgentID: "agent-1",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	require.NoError(t, err)
	// Complete job
	err = svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		JobID:     "job-1",
		Model:     "claude-sonnet",
		TokensIn:  1000,
		TokensOut: 500,
		Event:     QuotaEventJobCompleted,
	})
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(1500), u.TokensUsed)
	assert.Equal(t, 0, u.ActiveJobs)
}

// --- Config-based factory with redis backend ---

// The factory must hand back a *RedisStore for StoreBackend "redis".
func TestNewAllowanceStoreFromConfig_Good_Redis(t *testing.T) {
	cfg := AllowanceConfig{
		StoreBackend: "redis",
		RedisAddr:    testRedisAddr,
	}
	s, err := NewAllowanceStoreFromConfig(cfg)
	if err != nil {
		t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err)
	}
	rs, ok := s.(*RedisStore)
	assert.True(t, ok, "expected RedisStore")
	_ = rs.Close()
}

// --- Constructor error case ---

// An unreachable address must yield a typed 500 *APIError from the ping.
func TestNewRedisStore_Bad_Unreachable(t *testing.T) {
	_, err := NewRedisStore("127.0.0.1:1") // almost certainly unreachable
	require.Error(t, err)
	apiErr, ok := err.(*APIError)
	require.True(t, ok, "expected *APIError")
	assert.Equal(t, 500, apiErr.Code)
	assert.Contains(t, err.Error(), "failed to connect to Redis")
}

View file

@ -1,204 +0,0 @@
package lifecycle
import (
"context"
"slices"
"time"
"forge.lthn.ai/core/go-log"
)
// AllowanceService enforces agent quota limits. It provides pre-dispatch checks,
// runtime usage recording, and quota recovery for failed/cancelled jobs.
type AllowanceService struct {
	store  AllowanceStore // persistence backend for allowances and usage
	events EventEmitter   // optional; nil means no lifecycle notifications
}

// NewAllowanceService creates a new AllowanceService with the given store.
func NewAllowanceService(store AllowanceStore) *AllowanceService {
	return &AllowanceService{store: store}
}

// SetEventEmitter attaches an event emitter for quota lifecycle notifications.
func (s *AllowanceService) SetEventEmitter(em EventEmitter) {
	s.events = em
}

// emitEvent is a convenience helper that publishes an event if an emitter is set.
// Emission failures are deliberately discarded: quota enforcement must not
// fail because a notification could not be delivered.
func (s *AllowanceService) emitEvent(eventType EventType, agentID string, payload any) {
	if s.events != nil {
		_ = s.events.Emit(context.Background(), Event{
			Type:      eventType,
			AgentID:   agentID,
			Timestamp: time.Now().UTC(),
			Payload:   payload,
		})
	}
}
// Check performs a pre-dispatch allowance check for the given agent and model.
// It verifies daily token limits, daily job limits, concurrent job limits, and
// model allowlists. Returns a QuotaCheckResult indicating whether the agent may proceed.
//
// Checks short-circuit on the first violation, emitting EventQuotaExceeded
// with the reason; a violation is returned as (result, nil) — the error
// return is reserved for store failures. A limit of 0 means unlimited.
// Crossing 80% of the daily token limit sets AllowanceWarning (and emits
// EventQuotaWarning) without blocking.
func (s *AllowanceService) Check(agentID, model string) (*QuotaCheckResult, error) {
	const op = "AllowanceService.Check"
	allowance, err := s.store.GetAllowance(agentID)
	if err != nil {
		return nil, log.E(op, "failed to get allowance", err)
	}
	usage, err := s.store.GetUsage(agentID)
	if err != nil {
		return nil, log.E(op, "failed to get usage", err)
	}
	result := &QuotaCheckResult{
		Allowed:         true,
		Status:          AllowanceOK,
		RemainingTokens: -1, // unlimited
		RemainingJobs:   -1, // unlimited
	}
	// Check model allowlist (empty allowlist or empty model = no restriction)
	if len(allowance.ModelAllowlist) > 0 && model != "" {
		if !slices.Contains(allowance.ModelAllowlist, model) {
			result.Allowed = false
			result.Status = AllowanceExceeded
			result.Reason = "model not in allowlist: " + model
			s.emitEvent(EventQuotaExceeded, agentID, result.Reason)
			return result, nil
		}
	}
	// Check daily token limit (0 = unlimited)
	if allowance.DailyTokenLimit > 0 {
		remaining := allowance.DailyTokenLimit - usage.TokensUsed
		result.RemainingTokens = remaining
		if remaining <= 0 {
			result.Allowed = false
			result.Status = AllowanceExceeded
			result.Reason = "daily token limit exceeded"
			s.emitEvent(EventQuotaExceeded, agentID, result.Reason)
			return result, nil
		}
		// Soft warning once 80% of the daily budget is consumed.
		ratio := float64(usage.TokensUsed) / float64(allowance.DailyTokenLimit)
		if ratio >= 0.8 {
			result.Status = AllowanceWarning
			s.emitEvent(EventQuotaWarning, agentID, ratio)
		}
	}
	// Check daily job limit (0 = unlimited)
	if allowance.DailyJobLimit > 0 {
		remaining := allowance.DailyJobLimit - usage.JobsStarted
		result.RemainingJobs = remaining
		if remaining <= 0 {
			result.Allowed = false
			result.Status = AllowanceExceeded
			result.Reason = "daily job limit exceeded"
			s.emitEvent(EventQuotaExceeded, agentID, result.Reason)
			return result, nil
		}
	}
	// Check concurrent jobs (0 = unlimited)
	if allowance.ConcurrentJobs > 0 && usage.ActiveJobs >= allowance.ConcurrentJobs {
		result.Allowed = false
		result.Status = AllowanceExceeded
		result.Reason = "concurrent job limit reached"
		s.emitEvent(EventQuotaExceeded, agentID, result.Reason)
		return result, nil
	}
	// Check global model quota. Lookup failures are deliberately ignored
	// (best-effort): a missing quota or unreachable store must not block dispatch.
	if model != "" {
		modelQuota, err := s.store.GetModelQuota(model)
		if err == nil && modelQuota.DailyTokenBudget > 0 {
			modelUsage, err := s.store.GetModelUsage(model)
			if err == nil && modelUsage >= modelQuota.DailyTokenBudget {
				result.Allowed = false
				result.Status = AllowanceExceeded
				result.Reason = "global model token budget exceeded for: " + model
				s.emitEvent(EventQuotaExceeded, agentID, result.Reason)
				return result, nil
			}
		}
	}
	return result, nil
}
// RecordUsage processes a usage report, updating counters and handling quota recovery.
//
// Accounting per event:
//   - job_started:   jobs_started+1 and active_jobs+1; no tokens charged yet.
//   - job_completed: full token charge, active_jobs-1, model counter charged.
//   - job_failed:    full charge then 50% refunded; model counter charged net of refund.
//   - job_cancelled: active_jobs-1 and 100% of tokens refunded; model counter untouched.
//
// Unknown event types fall through the switch as a deliberate no-op.
// EventUsageRecorded is emitted only for started/completed jobs.
func (s *AllowanceService) RecordUsage(report UsageReport) error {
	const op = "AllowanceService.RecordUsage"
	totalTokens := report.TokensIn + report.TokensOut
	switch report.Event {
	case QuotaEventJobStarted:
		if err := s.store.IncrementUsage(report.AgentID, 0, 1); err != nil {
			return log.E(op, "failed to increment job count", err)
		}
		s.emitEvent(EventUsageRecorded, report.AgentID, report)
	case QuotaEventJobCompleted:
		if err := s.store.IncrementUsage(report.AgentID, totalTokens, 0); err != nil {
			return log.E(op, "failed to record token usage", err)
		}
		if err := s.store.DecrementActiveJobs(report.AgentID); err != nil {
			return log.E(op, "failed to decrement active jobs", err)
		}
		// Record model-level usage
		if report.Model != "" {
			if err := s.store.IncrementModelUsage(report.Model, totalTokens); err != nil {
				return log.E(op, "failed to record model usage", err)
			}
		}
		s.emitEvent(EventUsageRecorded, report.AgentID, report)
	case QuotaEventJobFailed:
		// Record partial usage, return 50% of tokens
		if err := s.store.IncrementUsage(report.AgentID, totalTokens, 0); err != nil {
			return log.E(op, "failed to record token usage", err)
		}
		if err := s.store.DecrementActiveJobs(report.AgentID); err != nil {
			return log.E(op, "failed to decrement active jobs", err)
		}
		returnAmount := totalTokens / 2
		if returnAmount > 0 {
			if err := s.store.ReturnTokens(report.AgentID, returnAmount); err != nil {
				return log.E(op, "failed to return tokens", err)
			}
		}
		// Still record model-level usage (net of return)
		if report.Model != "" {
			if err := s.store.IncrementModelUsage(report.Model, totalTokens-returnAmount); err != nil {
				return log.E(op, "failed to record model usage", err)
			}
		}
	case QuotaEventJobCancelled:
		// Return 100% of tokens
		if err := s.store.DecrementActiveJobs(report.AgentID); err != nil {
			return log.E(op, "failed to decrement active jobs", err)
		}
		if totalTokens > 0 {
			if err := s.store.ReturnTokens(report.AgentID, totalTokens); err != nil {
				return log.E(op, "failed to return tokens", err)
			}
		}
		// No model-level usage for cancelled jobs
	}
	return nil
}
// ResetAgent clears daily usage counters for the given agent (midnight reset).
// Store failures are wrapped with the operation context via log.E.
func (s *AllowanceService) ResetAgent(agentID string) error {
	if err := s.store.ResetUsage(agentID); err != nil {
		return log.E("AllowanceService.ResetAgent", "failed to reset usage", err)
	}
	return nil
}

View file

@ -1,333 +0,0 @@
package lifecycle
import (
"encoding/json"
"errors"
"iter"
"sync"
"time"
"forge.lthn.ai/core/go-store"
)
// SQLite group names for namespacing data in the KV store.
const (
	groupAllowances = "allowances"   // per-agent quota limits
	groupUsage      = "usage"        // per-agent usage counters
	groupModelQuota = "model_quotas" // global per-model budgets
	groupModelUsage = "model_usage"  // global per-model token counters
)

// SQLiteStore implements AllowanceStore using go-store (SQLite KV).
// It provides persistent storage that survives process restarts.
type SQLiteStore struct {
	db *store.Store
	mu sync.Mutex // serialises read-modify-write operations
}
// Allowances returns an iterator over all agent allowances.
// Rows that fail to load or decode are skipped, so one corrupt entry cannot
// abort a full enumeration.
func (s *SQLiteStore) Allowances() iter.Seq[*AgentAllowance] {
	return func(yield func(*AgentAllowance) bool) {
		for kv, err := range s.db.All(groupAllowances) {
			if err != nil {
				continue
			}
			var a allowanceJSON
			if err := json.Unmarshal([]byte(kv.Value), &a); err != nil {
				continue
			}
			if !yield(a.toAgentAllowance()) {
				return
			}
		}
	}
}

// Usages returns an iterator over all usage records, with the same
// skip-on-error behaviour as Allowances above.
func (s *SQLiteStore) Usages() iter.Seq[*UsageRecord] {
	return func(yield func(*UsageRecord) bool) {
		for kv, err := range s.db.All(groupUsage) {
			if err != nil {
				continue
			}
			var u UsageRecord
			if err := json.Unmarshal([]byte(kv.Value), &u); err != nil {
				continue
			}
			if !yield(&u) {
				return
			}
		}
	}
}
// NewSQLiteStore creates a new SQLite-backed allowance store at the given path.
// Use ":memory:" for tests that do not need persistence.
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
	kv, err := store.New(dbPath)
	if err != nil {
		return nil, &APIError{Code: 500, Message: "failed to open SQLite store: " + err.Error()}
	}
	return &SQLiteStore{db: kv}, nil
}

// Close releases the underlying SQLite database.
func (s *SQLiteStore) Close() error {
	return s.db.Close()
}
// GetAllowance returns the quota limits configured for agentID.
// A missing record is reported as a 404 APIError; storage and decode
// failures surface as 500 APIErrors.
func (s *SQLiteStore) GetAllowance(agentID string) (*AgentAllowance, error) {
	raw, err := s.db.Get(groupAllowances, agentID)
	switch {
	case errors.Is(err, store.ErrNotFound):
		return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID}
	case err != nil:
		return nil, &APIError{Code: 500, Message: "failed to get allowance: " + err.Error()}
	}
	var rec allowanceJSON
	if err := json.Unmarshal([]byte(raw), &rec); err != nil {
		return nil, &APIError{Code: 500, Message: "failed to unmarshal allowance: " + err.Error()}
	}
	return rec.toAgentAllowance(), nil
}
// SetAllowance persists quota limits for a.AgentID, replacing any previous
// record. The duration field is flattened to nanoseconds via allowanceJSON
// before storage.
func (s *SQLiteStore) SetAllowance(a *AgentAllowance) error {
	payload, err := json.Marshal(newAllowanceJSON(a))
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal allowance: " + err.Error()}
	}
	if err := s.db.Set(groupAllowances, a.AgentID, string(payload)); err != nil {
		return &APIError{Code: 500, Message: "failed to set allowance: " + err.Error()}
	}
	return nil
}
// GetUsage returns the current usage record for an agent.
//
// When no record exists yet, a zero-valued record for the current UTC day
// is returned instead of an error, so callers can treat "no usage" and
// "some usage" uniformly. Storage and decode failures surface as 500
// APIErrors.
//
// The read is taken under s.mu and delegates to getUsageLocked; the decode
// logic previously lived here as a verbatim duplicate of that helper.
func (s *SQLiteStore) GetUsage(agentID string) (*UsageRecord, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.getUsageLocked(agentID)
}
// IncrementUsage atomically adds to an agent's usage counters.
//
// "Atomically" means with respect to other methods on this SQLiteStore in
// the same process: the read-modify-write cycle is serialised by s.mu, not
// by a database transaction. tokens is added to TokensUsed and jobs to
// JobsStarted; a positive jobs delta also raises ActiveJobs by the same
// amount.
func (s *SQLiteStore) IncrementUsage(agentID string, tokens int64, jobs int) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	u, err := s.getUsageLocked(agentID)
	if err != nil {
		return err
	}
	u.TokensUsed += tokens
	u.JobsStarted += jobs
	if jobs > 0 {
		// Only a positive delta counts toward concurrency; zero or
		// negative jobs values leave ActiveJobs untouched.
		u.ActiveJobs += jobs
	}
	return s.putUsageLocked(u)
}
// DecrementActiveJobs reduces the agent's active job count by 1, flooring
// the counter at zero so over-decrementing cannot drive it negative.
func (s *SQLiteStore) DecrementActiveJobs(agentID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	rec, err := s.getUsageLocked(agentID)
	if err != nil {
		return err
	}
	rec.ActiveJobs = max(rec.ActiveJobs-1, 0)
	return s.putUsageLocked(rec)
}
// ReturnTokens credits tokens back to the agent's remaining quota by
// lowering TokensUsed, flooring the counter at zero.
func (s *SQLiteStore) ReturnTokens(agentID string, tokens int64) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	rec, err := s.getUsageLocked(agentID)
	if err != nil {
		return err
	}
	rec.TokensUsed = max(rec.TokensUsed-tokens, 0)
	return s.putUsageLocked(rec)
}
// ResetUsage replaces the agent's usage record with a zeroed one whose
// period starts at the beginning of the current UTC day.
func (s *SQLiteStore) ResetUsage(agentID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	fresh := &UsageRecord{
		AgentID:     agentID,
		PeriodStart: startOfDay(time.Now().UTC()),
	}
	return s.putUsageLocked(fresh)
}
// GetModelQuota returns the global limits configured for a model.
// A missing record is reported as a 404 APIError; storage and decode
// failures surface as 500 APIErrors.
func (s *SQLiteStore) GetModelQuota(model string) (*ModelQuota, error) {
	raw, err := s.db.Get(groupModelQuota, model)
	switch {
	case errors.Is(err, store.ErrNotFound):
		return nil, &APIError{Code: 404, Message: "model quota not found: " + model}
	case err != nil:
		return nil, &APIError{Code: 500, Message: "failed to get model quota: " + err.Error()}
	}
	var quota ModelQuota
	if err := json.Unmarshal([]byte(raw), &quota); err != nil {
		return nil, &APIError{Code: 500, Message: "failed to unmarshal model quota: " + err.Error()}
	}
	return &quota, nil
}
// GetModelUsage returns the current token usage recorded for a model.
//
// A model with no recorded usage yields 0 rather than an error; storage
// and decode failures surface as 500 APIErrors.
//
// The read is taken under s.mu and delegates to getModelUsageLocked; the
// decode logic previously lived here as a verbatim duplicate of that
// helper.
func (s *SQLiteStore) GetModelUsage(model string) (int64, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.getModelUsageLocked(model)
}
// IncrementModelUsage adds tokens to a model's running usage counter. The
// read-modify-write is serialised against other store methods by s.mu.
func (s *SQLiteStore) IncrementModelUsage(model string, tokens int64) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	total, err := s.getModelUsageLocked(model)
	if err != nil {
		return err
	}
	payload, err := json.Marshal(total + tokens)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal model usage: " + err.Error()}
	}
	if err := s.db.Set(groupModelUsage, model, string(payload)); err != nil {
		return &APIError{Code: 500, Message: "failed to set model usage: " + err.Error()}
	}
	return nil
}
// SetModelQuota persists the global limits for q.Model, replacing any
// previous record.
func (s *SQLiteStore) SetModelQuota(q *ModelQuota) error {
	payload, err := json.Marshal(q)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal model quota: " + err.Error()}
	}
	if err := s.db.Set(groupModelQuota, q.Model, string(payload)); err != nil {
		return &APIError{Code: 500, Message: "failed to set model quota: " + err.Error()}
	}
	return nil
}
// --- internal helpers (must be called with mu held) ---

// getUsageLocked reads a UsageRecord from the store. Caller must hold s.mu.
//
// A missing record is not an error: a zero-valued record dated to the start
// of the current UTC day is synthesised so callers can mutate and persist
// it directly. Decode failures return a 500 APIError.
func (s *SQLiteStore) getUsageLocked(agentID string) (*UsageRecord, error) {
	val, err := s.db.Get(groupUsage, agentID)
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			// First touch for this agent: start a fresh daily record.
			return &UsageRecord{
				AgentID:     agentID,
				PeriodStart: startOfDay(time.Now().UTC()),
			}, nil
		}
		return nil, &APIError{Code: 500, Message: "failed to get usage: " + err.Error()}
	}
	var u UsageRecord
	if err := json.Unmarshal([]byte(val), &u); err != nil {
		return nil, &APIError{Code: 500, Message: "failed to unmarshal usage: " + err.Error()}
	}
	return &u, nil
}
// putUsageLocked serialises u and writes it under the usage group, keyed
// by agent ID. Caller must hold s.mu.
func (s *SQLiteStore) putUsageLocked(u *UsageRecord) error {
	payload, err := json.Marshal(u)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal usage: " + err.Error()}
	}
	if err := s.db.Set(groupUsage, u.AgentID, string(payload)); err != nil {
		return &APIError{Code: 500, Message: "failed to set usage: " + err.Error()}
	}
	return nil
}
// getModelUsageLocked reads model usage from the store. Caller must hold s.mu.
//
// A model with no stored usage yields 0 rather than an error; decode
// failures return a 500 APIError.
func (s *SQLiteStore) getModelUsageLocked(model string) (int64, error) {
	val, err := s.db.Get(groupModelUsage, model)
	if err != nil {
		if errors.Is(err, store.ErrNotFound) {
			return 0, nil
		}
		return 0, &APIError{Code: 500, Message: "failed to get model usage: " + err.Error()}
	}
	var tokens int64
	if err := json.Unmarshal([]byte(val), &tokens); err != nil {
		return 0, &APIError{Code: 500, Message: "failed to unmarshal model usage: " + err.Error()}
	}
	return tokens, nil
}
// --- JSON serialisation helper for AgentAllowance ---

// time.Duration does not have a stable JSON representation. We serialise it
// as an int64 (nanoseconds) to avoid locale-dependent string parsing.
type allowanceJSON struct {
	AgentID          string   `json:"agent_id"`
	DailyTokenLimit  int64    `json:"daily_token_limit"`
	DailyJobLimit    int      `json:"daily_job_limit"`
	ConcurrentJobs   int      `json:"concurrent_jobs"`
	MaxJobDurationNs int64    `json:"max_job_duration_ns"`
	ModelAllowlist   []string `json:"model_allowlist,omitempty"`
}

// newAllowanceJSON converts the domain type to its storable form,
// flattening MaxJobDuration into a nanosecond count.
func newAllowanceJSON(a *AgentAllowance) *allowanceJSON {
	return &allowanceJSON{
		AgentID:          a.AgentID,
		DailyTokenLimit:  a.DailyTokenLimit,
		DailyJobLimit:    a.DailyJobLimit,
		ConcurrentJobs:   a.ConcurrentJobs,
		MaxJobDurationNs: int64(a.MaxJobDuration),
		ModelAllowlist:   a.ModelAllowlist,
	}
}

// toAgentAllowance is the inverse of newAllowanceJSON, restoring the
// stored nanosecond count to a time.Duration.
func (aj *allowanceJSON) toAgentAllowance() *AgentAllowance {
	return &AgentAllowance{
		AgentID:         aj.AgentID,
		DailyTokenLimit: aj.DailyTokenLimit,
		DailyJobLimit:   aj.DailyJobLimit,
		ConcurrentJobs:  aj.ConcurrentJobs,
		MaxJobDuration:  time.Duration(aj.MaxJobDurationNs),
		ModelAllowlist:  aj.ModelAllowlist,
	}
}

View file

@ -1,465 +0,0 @@
package lifecycle
import (
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestSQLiteStore creates a SQLiteStore in a temp directory.
// The database is closed automatically via t.Cleanup when the test ends.
func newTestSQLiteStore(t *testing.T) *SQLiteStore {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "test.db")
	s, err := NewSQLiteStore(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = s.Close() })
	return s
}
// --- SetAllowance / GetAllowance ---

// Round-trips a fully-populated allowance, checking every field survives,
// including the nanosecond-encoded duration and the allowlist slice.
func TestSQLiteStore_SetGetAllowance_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	a := &AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		DailyJobLimit:   10,
		ConcurrentJobs:  2,
		MaxJobDuration:  30 * time.Minute,
		ModelAllowlist:  []string{"claude-sonnet-4-5-20250929"},
	}
	err := s.SetAllowance(a)
	require.NoError(t, err)
	got, err := s.GetAllowance("agent-1")
	require.NoError(t, err)
	assert.Equal(t, a.AgentID, got.AgentID)
	assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit)
	assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit)
	assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs)
	assert.Equal(t, a.MaxJobDuration, got.MaxJobDuration)
	assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist)
}

// A missing allowance must surface as an error, not a zero value.
func TestSQLiteStore_GetAllowance_Bad_NotFound(t *testing.T) {
	s := newTestSQLiteStore(t)
	_, err := s.GetAllowance("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "allowance not found")
}

// Setting the same agent twice keeps only the most recent record.
func TestSQLiteStore_SetAllowance_Good_Overwrite(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100})
	_ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 200})
	got, err := s.GetAllowance("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(200), got.DailyTokenLimit)
}

// --- GetUsage / IncrementUsage ---

// An agent with no usage yet gets a synthesised zero record, not an error.
func TestSQLiteStore_GetUsage_Good_Default(t *testing.T) {
	s := newTestSQLiteStore(t)
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "agent-1", u.AgentID)
	assert.Equal(t, int64(0), u.TokensUsed)
	assert.Equal(t, 0, u.JobsStarted)
	assert.Equal(t, 0, u.ActiveJobs)
}

// A single increment updates tokens, jobs started, and active jobs together.
func TestSQLiteStore_IncrementUsage_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	err := s.IncrementUsage("agent-1", 5000, 1)
	require.NoError(t, err)
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(5000), u.TokensUsed)
	assert.Equal(t, 1, u.JobsStarted)
	assert.Equal(t, 1, u.ActiveJobs)
}

// Increments accumulate; a zero-job increment adds tokens only.
func TestSQLiteStore_IncrementUsage_Good_Accumulates(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.IncrementUsage("agent-1", 1000, 1)
	_ = s.IncrementUsage("agent-1", 2000, 1)
	_ = s.IncrementUsage("agent-1", 3000, 0)
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(6000), u.TokensUsed)
	assert.Equal(t, 2, u.JobsStarted)
	assert.Equal(t, 2, u.ActiveJobs)
}

// --- DecrementActiveJobs ---

// Decrementing reduces the active count by exactly one.
func TestSQLiteStore_DecrementActiveJobs_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.IncrementUsage("agent-1", 0, 2)
	_ = s.DecrementActiveJobs("agent-1")
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, 1, u.ActiveJobs)
}

// The active-job counter must never go negative, even with no prior record.
func TestSQLiteStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.DecrementActiveJobs("agent-1") // no usage record yet
	_ = s.IncrementUsage("agent-1", 0, 0)
	_ = s.DecrementActiveJobs("agent-1") // should stay at 0
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, 0, u.ActiveJobs)
}

// --- ReturnTokens ---

// Returning tokens lowers the used counter by the returned amount.
func TestSQLiteStore_ReturnTokens_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.IncrementUsage("agent-1", 10000, 0)
	err := s.ReturnTokens("agent-1", 5000)
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(5000), u.TokensUsed)
}

// Over-returning must floor TokensUsed at zero.
func TestSQLiteStore_ReturnTokens_Good_FloorAtZero(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.IncrementUsage("agent-1", 1000, 0)
	_ = s.ReturnTokens("agent-1", 5000) // more than used
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(0), u.TokensUsed)
}

// Returning tokens for an unknown agent creates a zeroed record, not an error.
func TestSQLiteStore_ReturnTokens_Good_NoRecord(t *testing.T) {
	s := newTestSQLiteStore(t)
	// Return tokens for agent with no usage record -- should create one at 0
	err := s.ReturnTokens("agent-1", 500)
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(0), u.TokensUsed)
}

// --- ResetUsage ---

// Resetting clears all counters back to zero.
func TestSQLiteStore_ResetUsage_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.IncrementUsage("agent-1", 50000, 5)
	err := s.ResetUsage("agent-1")
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(0), u.TokensUsed)
	assert.Equal(t, 0, u.JobsStarted)
	assert.Equal(t, 0, u.ActiveJobs)
}
// --- ModelQuota ---

// A missing quota must surface as an error.
func TestSQLiteStore_GetModelQuota_Bad_NotFound(t *testing.T) {
	s := newTestSQLiteStore(t)
	_, err := s.GetModelQuota("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "model quota not found")
}

// Round-trips a model quota and checks all fields survive storage.
func TestSQLiteStore_SetGetModelQuota_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	q := &ModelQuota{
		Model:            "claude-opus-4-6",
		DailyTokenBudget: 500000,
		HourlyRateLimit:  100,
		CostCeiling:      10000,
	}
	err := s.SetModelQuota(q)
	require.NoError(t, err)
	got, err := s.GetModelQuota("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, q.Model, got.Model)
	assert.Equal(t, q.DailyTokenBudget, got.DailyTokenBudget)
	assert.Equal(t, q.HourlyRateLimit, got.HourlyRateLimit)
	assert.Equal(t, q.CostCeiling, got.CostCeiling)
}

// --- ModelUsage ---

// Model usage increments accumulate across calls.
func TestSQLiteStore_ModelUsage_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.IncrementModelUsage("claude-sonnet", 10000)
	_ = s.IncrementModelUsage("claude-sonnet", 5000)
	usage, err := s.GetModelUsage("claude-sonnet")
	require.NoError(t, err)
	assert.Equal(t, int64(15000), usage)
}

// An unknown model reads as zero usage, not an error.
func TestSQLiteStore_GetModelUsage_Good_Default(t *testing.T) {
	s := newTestSQLiteStore(t)
	usage, err := s.GetModelUsage("unknown-model")
	require.NoError(t, err)
	assert.Equal(t, int64(0), usage)
}
// --- Persistence: close and reopen ---

// Verifies that allowances, usage, model quotas, and model usage all
// survive closing the database and reopening it at the same path — the
// core property that distinguishes SQLiteStore from MemoryStore.
func TestSQLiteStore_Persistence_Good(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "persist.db")
	// Phase 1: write data
	s1, err := NewSQLiteStore(dbPath)
	require.NoError(t, err)
	_ = s1.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		MaxJobDuration:  15 * time.Minute,
	})
	_ = s1.IncrementUsage("agent-1", 25000, 3)
	_ = s1.SetModelQuota(&ModelQuota{Model: "claude-opus-4-6", DailyTokenBudget: 500000})
	_ = s1.IncrementModelUsage("claude-opus-4-6", 42000)
	require.NoError(t, s1.Close())
	// Phase 2: reopen and verify
	s2, err := NewSQLiteStore(dbPath)
	require.NoError(t, err)
	defer func() { _ = s2.Close() }()
	a, err := s2.GetAllowance("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(100000), a.DailyTokenLimit)
	assert.Equal(t, 15*time.Minute, a.MaxJobDuration)
	u, err := s2.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(25000), u.TokensUsed)
	assert.Equal(t, 3, u.JobsStarted)
	assert.Equal(t, 3, u.ActiveJobs)
	q, err := s2.GetModelQuota("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, int64(500000), q.DailyTokenBudget)
	mu, err := s2.GetModelUsage("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, int64(42000), mu)
}
// --- Concurrent access ---

// Parallel increments must not lose updates: the final totals equal the
// sum of all goroutines' contributions (run with -race to catch data
// races in the store's mutex discipline).
func TestSQLiteStore_ConcurrentIncrementUsage_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	const goroutines = 10
	const tokensEach = 1000
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for range goroutines {
		go func() {
			defer wg.Done()
			err := s.IncrementUsage("agent-1", tokensEach, 1)
			assert.NoError(t, err)
		}()
	}
	wg.Wait()
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(goroutines*tokensEach), u.TokensUsed)
	assert.Equal(t, goroutines, u.JobsStarted)
	assert.Equal(t, goroutines, u.ActiveJobs)
}

// Same lost-update check for the per-model counter.
func TestSQLiteStore_ConcurrentModelUsage_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	const goroutines = 10
	const tokensEach int64 = 500
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for range goroutines {
		go func() {
			defer wg.Done()
			err := s.IncrementModelUsage("claude-opus-4-6", tokensEach)
			assert.NoError(t, err)
		}()
	}
	wg.Wait()
	usage, err := s.GetModelUsage("claude-opus-4-6")
	require.NoError(t, err)
	assert.Equal(t, goroutines*tokensEach, usage)
}

// Mixed concurrent operations: the interleaving is nondeterministic, so
// this only asserts the store stays internally consistent (no panics, no
// negative counters), not exact totals.
func TestSQLiteStore_ConcurrentMixed_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	_ = s.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 1000000,
		DailyJobLimit:   100,
		ConcurrentJobs:  50,
	})
	const goroutines = 10
	var wg sync.WaitGroup
	wg.Add(goroutines * 3)
	// Increment usage
	for range goroutines {
		go func() {
			defer wg.Done()
			_ = s.IncrementUsage("agent-1", 100, 1)
		}()
	}
	// Decrement active jobs
	for range goroutines {
		go func() {
			defer wg.Done()
			_ = s.DecrementActiveJobs("agent-1")
		}()
	}
	// Return tokens
	for range goroutines {
		go func() {
			defer wg.Done()
			_ = s.ReturnTokens("agent-1", 10)
		}()
	}
	wg.Wait()
	// Just verify no panics and data is consistent
	u, err := s.GetUsage("agent-1")
	require.NoError(t, err)
	assert.GreaterOrEqual(t, u.TokensUsed, int64(0))
	assert.GreaterOrEqual(t, u.ActiveJobs, 0)
}
// --- AllowanceService integration via SQLiteStore ---

// The service layer works unchanged over the SQLite backend: a fresh agent
// under its limits is allowed.
func TestSQLiteStore_AllowanceServiceCheck_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	svc := NewAllowanceService(s)
	_ = s.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		DailyJobLimit:   10,
		ConcurrentJobs:  2,
	})
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceOK, result.Status)
}

// Full job lifecycle through the service: start raises ActiveJobs,
// completion charges tokens and releases the slot.
func TestSQLiteStore_AllowanceServiceRecordUsage_Good(t *testing.T) {
	s := newTestSQLiteStore(t)
	svc := NewAllowanceService(s)
	_ = s.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
	})
	// Start job
	err := svc.RecordUsage(UsageReport{
		AgentID: "agent-1",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	require.NoError(t, err)
	// Complete job
	err = svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		JobID:     "job-1",
		Model:     "claude-sonnet",
		TokensIn:  1000,
		TokensOut: 500,
		Event:     QuotaEventJobCompleted,
	})
	require.NoError(t, err)
	u, _ := s.GetUsage("agent-1")
	assert.Equal(t, int64(1500), u.TokensUsed)
	assert.Equal(t, 0, u.ActiveJobs)
}

// --- Config-based factory ---

// "memory" selects the in-memory backend.
func TestNewAllowanceStoreFromConfig_Good_Memory(t *testing.T) {
	cfg := AllowanceConfig{StoreBackend: "memory"}
	s, err := NewAllowanceStoreFromConfig(cfg)
	require.NoError(t, err)
	_, ok := s.(*MemoryStore)
	assert.True(t, ok, "expected MemoryStore")
}

// An empty backend defaults to memory.
func TestNewAllowanceStoreFromConfig_Good_Default(t *testing.T) {
	cfg := AllowanceConfig{} // empty defaults to memory
	s, err := NewAllowanceStoreFromConfig(cfg)
	require.NoError(t, err)
	_, ok := s.(*MemoryStore)
	assert.True(t, ok, "expected MemoryStore for empty config")
}

// "sqlite" plus a path selects the SQLite backend.
func TestNewAllowanceStoreFromConfig_Good_SQLite(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "factory.db")
	cfg := AllowanceConfig{
		StoreBackend: "sqlite",
		StorePath:    dbPath,
	}
	s, err := NewAllowanceStoreFromConfig(cfg)
	require.NoError(t, err)
	ss, ok := s.(*SQLiteStore)
	assert.True(t, ok, "expected SQLiteStore")
	_ = ss.Close()
}

// Unknown backend names are rejected with a descriptive error.
func TestNewAllowanceStoreFromConfig_Bad_UnknownBackend(t *testing.T) {
	cfg := AllowanceConfig{StoreBackend: "cassandra"}
	_, err := NewAllowanceStoreFromConfig(cfg)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unsupported store backend")
}

// --- NewSQLiteStore error case ---

// Opening a database under a nonexistent directory must fail.
func TestNewSQLiteStore_Bad_InvalidPath(t *testing.T) {
	_, err := NewSQLiteStore("/nonexistent/deeply/nested/dir/test.db")
	require.Error(t, err)
}

View file

@ -1,407 +0,0 @@
package lifecycle
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- MemoryStore tests ---

// Round-trips an allowance through the in-memory store.
func TestMemoryStore_SetGetAllowance_Good(t *testing.T) {
	store := NewMemoryStore()
	a := &AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		DailyJobLimit:   10,
		ConcurrentJobs:  2,
		MaxJobDuration:  30 * time.Minute,
		ModelAllowlist:  []string{"claude-sonnet-4-5-20250929"},
	}
	err := store.SetAllowance(a)
	require.NoError(t, err)
	got, err := store.GetAllowance("agent-1")
	require.NoError(t, err)
	assert.Equal(t, a.AgentID, got.AgentID)
	assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit)
	assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit)
	assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs)
	assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist)
}

// Unknown agents must yield an error, not a zero allowance.
func TestMemoryStore_GetAllowance_Bad_NotFound(t *testing.T) {
	store := NewMemoryStore()
	_, err := store.GetAllowance("nonexistent")
	require.Error(t, err)
}

// A single increment updates all three counters together.
func TestMemoryStore_IncrementUsage_Good(t *testing.T) {
	store := NewMemoryStore()
	err := store.IncrementUsage("agent-1", 5000, 1)
	require.NoError(t, err)
	usage, err := store.GetUsage("agent-1")
	require.NoError(t, err)
	assert.Equal(t, int64(5000), usage.TokensUsed)
	assert.Equal(t, 1, usage.JobsStarted)
	assert.Equal(t, 1, usage.ActiveJobs)
}

// Decrement reduces ActiveJobs by exactly one.
func TestMemoryStore_DecrementActiveJobs_Good(t *testing.T) {
	store := NewMemoryStore()
	_ = store.IncrementUsage("agent-1", 0, 2)
	_ = store.DecrementActiveJobs("agent-1")
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, 1, usage.ActiveJobs)
}

// ActiveJobs must never go negative, even with no prior record.
func TestMemoryStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) {
	store := NewMemoryStore()
	_ = store.DecrementActiveJobs("agent-1") // no-op, no usage record
	_ = store.IncrementUsage("agent-1", 0, 0)
	_ = store.DecrementActiveJobs("agent-1") // should stay at 0
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, 0, usage.ActiveJobs)
}

// Returning tokens lowers the used counter.
func TestMemoryStore_ReturnTokens_Good(t *testing.T) {
	store := NewMemoryStore()
	_ = store.IncrementUsage("agent-1", 10000, 0)
	err := store.ReturnTokens("agent-1", 5000)
	require.NoError(t, err)
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, int64(5000), usage.TokensUsed)
}

// Over-returning floors TokensUsed at zero.
func TestMemoryStore_ReturnTokens_Good_FloorAtZero(t *testing.T) {
	store := NewMemoryStore()
	_ = store.IncrementUsage("agent-1", 1000, 0)
	_ = store.ReturnTokens("agent-1", 5000) // more than used
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, int64(0), usage.TokensUsed)
}

// Reset clears every counter.
func TestMemoryStore_ResetUsage_Good(t *testing.T) {
	store := NewMemoryStore()
	_ = store.IncrementUsage("agent-1", 50000, 5)
	err := store.ResetUsage("agent-1")
	require.NoError(t, err)
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, int64(0), usage.TokensUsed)
	assert.Equal(t, 0, usage.JobsStarted)
	assert.Equal(t, 0, usage.ActiveJobs)
}

// Model usage increments accumulate.
func TestMemoryStore_ModelUsage_Good(t *testing.T) {
	store := NewMemoryStore()
	_ = store.IncrementModelUsage("claude-sonnet", 10000)
	_ = store.IncrementModelUsage("claude-sonnet", 5000)
	usage, err := store.GetModelUsage("claude-sonnet")
	require.NoError(t, err)
	assert.Equal(t, int64(15000), usage)
}
// --- AllowanceService.Check tests ---

// A fresh agent below every limit is allowed with full remaining quota.
func TestAllowanceServiceCheck_Good(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
		DailyJobLimit:   10,
		ConcurrentJobs:  2,
	})
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceOK, result.Status)
	assert.Equal(t, int64(100000), result.RemainingTokens)
	assert.Equal(t, 10, result.RemainingJobs)
}

// At 85% consumption the agent is still allowed but flagged with a warning.
func TestAllowanceServiceCheck_Good_Warning(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
	})
	_ = store.IncrementUsage("agent-1", 85000, 0)
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceWarning, result.Status)
	assert.Equal(t, int64(15000), result.RemainingTokens)
}

// One token past the daily limit must be denied with the matching reason.
func TestAllowanceServiceCheck_Bad_TokenLimitExceeded(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "agent-1",
		DailyTokenLimit: 100000,
	})
	_ = store.IncrementUsage("agent-1", 100001, 0)
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	assert.Equal(t, AllowanceExceeded, result.Status)
	assert.Contains(t, result.Reason, "daily token limit")
}

// Reaching the daily job count blocks further jobs.
func TestAllowanceServiceCheck_Bad_JobLimitExceeded(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:       "agent-1",
		DailyJobLimit: 5,
	})
	_ = store.IncrementUsage("agent-1", 0, 5)
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	assert.Contains(t, result.Reason, "daily job limit")
}

// A saturated concurrency slot blocks new jobs.
func TestAllowanceServiceCheck_Bad_ConcurrentLimitReached(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "agent-1",
		ConcurrentJobs: 1,
	})
	_ = store.IncrementUsage("agent-1", 0, 1) // 1 active job
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	assert.Contains(t, result.Reason, "concurrent job limit")
}

// A model outside the agent's allowlist is rejected.
func TestAllowanceServiceCheck_Bad_ModelNotInAllowlist(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "agent-1",
		ModelAllowlist: []string{"claude-sonnet-4-5-20250929"},
	})
	result, err := svc.Check("agent-1", "claude-opus-4-6")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	assert.Contains(t, result.Reason, "model not in allowlist")
}

// A model on the allowlist is permitted.
func TestAllowanceServiceCheck_Good_ModelInAllowlist(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "agent-1",
		ModelAllowlist: []string{"claude-sonnet-4-5-20250929", "claude-haiku-4-5-20251001"},
	})
	result, err := svc.Check("agent-1", "claude-sonnet-4-5-20250929")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
}

// An empty model argument skips the allowlist check entirely.
func TestAllowanceServiceCheck_Good_EmptyModelSkipsCheck(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "agent-1",
		ModelAllowlist: []string{"claude-sonnet-4-5-20250929"},
	})
	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
}
// A model whose global daily budget is consumed must be denied for every
// agent, regardless of that agent's own remaining allowance.
func TestAllowanceServiceCheck_Bad_GlobalModelBudgetExceeded(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID: "agent-1",
	})
	// `_ =` explicitly discards the error, matching the convention used by
	// every other best-effort store call in this file (and keeping the
	// test errcheck-clean); the bare call previously dropped it silently.
	_ = store.SetModelQuota(&ModelQuota{
		Model:            "claude-opus-4-6",
		DailyTokenBudget: 500000,
	})
	_ = store.IncrementModelUsage("claude-opus-4-6", 500001)
	result, err := svc.Check("agent-1", "claude-opus-4-6")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
	assert.Contains(t, result.Reason, "global model token budget")
}
// Checking an agent with no configured allowance is an error.
func TestAllowanceServiceCheck_Bad_NoAllowance(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_, err := svc.Check("unknown-agent", "")
	require.Error(t, err)
}

// --- AllowanceService.RecordUsage tests ---

// A job-started event raises job counters without charging tokens.
func TestAllowanceServiceRecordUsage_Good_JobStarted(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	err := svc.RecordUsage(UsageReport{
		AgentID: "agent-1",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, 1, usage.JobsStarted)
	assert.Equal(t, 1, usage.ActiveJobs)
	assert.Equal(t, int64(0), usage.TokensUsed)
}

// Completion charges in+out tokens to both the agent and the model, and
// releases the active-job slot.
func TestAllowanceServiceRecordUsage_Good_JobCompleted(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	// Start a job first
	_ = svc.RecordUsage(UsageReport{
		AgentID: "agent-1",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		JobID:     "job-1",
		Model:     "claude-sonnet",
		TokensIn:  1000,
		TokensOut: 500,
		Event:     QuotaEventJobCompleted,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, int64(1500), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)
	modelUsage, _ := store.GetModelUsage("claude-sonnet")
	assert.Equal(t, int64(1500), modelUsage)
}

// A failed job refunds half the tokens it consumed.
func TestAllowanceServiceRecordUsage_Good_JobFailed_ReturnsHalf(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = svc.RecordUsage(UsageReport{
		AgentID: "agent-1",
		JobID:   "job-1",
		Event:   QuotaEventJobStarted,
	})
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		JobID:     "job-1",
		Model:     "claude-sonnet",
		TokensIn:  1000,
		TokensOut: 1000,
		Event:     QuotaEventJobFailed,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("agent-1")
	// 2000 tokens used, 1000 returned (50%) = 1000 net
	assert.Equal(t, int64(1000), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)
	// Model sees net usage (2000 - 1000 = 1000)
	modelUsage, _ := store.GetModelUsage("claude-sonnet")
	assert.Equal(t, int64(1000), modelUsage)
}

// A cancelled job refunds every token it reported.
func TestAllowanceServiceRecordUsage_Good_JobCancelled_ReturnsAll(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.IncrementUsage("agent-1", 5000, 1) // simulate pre-existing usage
	err := svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		JobID:     "job-1",
		TokensIn:  500,
		TokensOut: 500,
		Event:     QuotaEventJobCancelled,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("agent-1")
	// 5000 pre-existing - 1000 returned = 4000
	assert.Equal(t, int64(4000), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)
}

// --- AllowanceService.ResetAgent tests ---

// ResetAgent zeroes the agent's counters through the service layer.
func TestAllowanceServiceResetAgent_Good(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.IncrementUsage("agent-1", 50000, 5)
	err := svc.ResetAgent("agent-1")
	require.NoError(t, err)
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, int64(0), usage.TokensUsed)
	assert.Equal(t, 0, usage.JobsStarted)
}

// --- startOfDay helper test ---

// startOfDay truncates a timestamp to midnight of the same day.
func TestStartOfDay_Good(t *testing.T) {
	input := time.Date(2026, 2, 10, 15, 30, 45, 0, time.UTC)
	expected := time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC)
	assert.Equal(t, expected, startOfDay(input))
}

// --- AllowanceStatus tests ---

// Pins the wire values of the status constants.
func TestAllowanceStatus_Good_Values(t *testing.T) {
	assert.Equal(t, AllowanceStatus("ok"), AllowanceOK)
	assert.Equal(t, AllowanceStatus("warning"), AllowanceWarning)
	assert.Equal(t, AllowanceStatus("exceeded"), AllowanceExceeded)
}

View file

@ -1,215 +0,0 @@
package lifecycle
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"forge.lthn.ai/core/go-log"
)
// MemoryType represents the classification of a brain memory.
type MemoryType string

// Memory classifications used when storing and filtering memories.
const (
	MemoryFact      MemoryType = "fact"
	MemoryDecision  MemoryType = "decision"
	MemoryPattern   MemoryType = "pattern"
	MemoryContext   MemoryType = "context"
	MemoryProcedure MemoryType = "procedure"
)

// Memory represents a single memory entry from the OpenBrain API.
// Timestamps are carried as strings in whatever format the API returns —
// presumably RFC 3339; confirm against the server before parsing.
type Memory struct {
	ID           string   `json:"id"`
	AgentID      string   `json:"agent_id,omitempty"`
	Type         string   `json:"type"`
	Content      string   `json:"content"`
	Tags         []string `json:"tags,omitempty"`
	Project      string   `json:"project,omitempty"`
	Confidence   float64  `json:"confidence,omitempty"`
	SupersedesID string   `json:"supersedes_id,omitempty"`
	ExpiresAt    string   `json:"expires_at,omitempty"`
	CreatedAt    string   `json:"created_at,omitempty"`
	UpdatedAt    string   `json:"updated_at,omitempty"`
}

// RememberRequest is the payload for storing a new memory.
// Content and Type are required by Client.Remember; all other fields are
// optional and omitted from the JSON body when empty.
type RememberRequest struct {
	Content      string   `json:"content"`
	Type         string   `json:"type"`
	Project      string   `json:"project,omitempty"`
	AgentID      string   `json:"agent_id,omitempty"`
	Tags         []string `json:"tags,omitempty"`
	Confidence   float64  `json:"confidence,omitempty"`
	SupersedesID string   `json:"supersedes_id,omitempty"`
	Source       string   `json:"source,omitempty"`
}

// RememberResponse is returned after storing a memory.
type RememberResponse struct {
	ID        string `json:"id"`
	Type      string `json:"type"`
	Project   string `json:"project"`
	CreatedAt string `json:"created_at"`
}

// RecallRequest is the payload for semantic search.
// Query is required by Client.Recall; the remaining fields narrow the
// search. MinConfidence is a pointer so that an explicit 0 can be
// distinguished from "not set".
type RecallRequest struct {
	Query         string   `json:"query"`
	TopK          int      `json:"top_k,omitempty"`
	Project       string   `json:"project,omitempty"`
	Type          string   `json:"type,omitempty"`
	AgentID       string   `json:"agent_id,omitempty"`
	MinConfidence *float64 `json:"min_confidence,omitempty"`
}

// RecallResponse is returned from a semantic search.
// Scores is keyed per memory — presumably by memory ID; confirm against
// the API before relying on it.
type RecallResponse struct {
	Memories []Memory           `json:"memories"`
	Scores   map[string]float64 `json:"scores"`
}
// Remember stores a memory via POST /v1/brain/remember.
//
// Content and Type are mandatory; every other field of req is optional.
// On success the decoded RememberResponse is returned; validation,
// transport, and API failures are wrapped with the operation name.
func (c *Client) Remember(ctx context.Context, req RememberRequest) (*RememberResponse, error) {
	const op = "agentic.Client.Remember"
	switch {
	case req.Content == "":
		return nil, log.E(op, "content is required", nil)
	case req.Type == "":
		return nil, log.E(op, "type is required", nil)
	}
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, log.E(op, "failed to marshal request", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+"/v1/brain/remember", bytes.NewReader(payload))
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(httpReq)
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var out RememberResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &out, nil
}
// Recall performs semantic search via POST /v1/brain/recall.
// Query is mandatory. Returns matching memories with their scores.
func (c *Client) Recall(ctx context.Context, req RecallRequest) (*RecallResponse, error) {
	const op = "agentic.Client.Recall"

	if req.Query == "" {
		return nil, log.E(op, "query is required", nil)
	}

	payload, err := json.Marshal(req)
	if err != nil {
		return nil, log.E(op, "failed to marshal request", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+"/v1/brain/recall", bytes.NewReader(payload))
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(httpReq)
	httpReq.Header.Set("Content-Type", "application/json")

	res, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = res.Body.Close() }()

	if err := c.checkResponse(res); err != nil {
		return nil, log.E(op, "API error", err)
	}

	var out RecallResponse
	if err := json.NewDecoder(res.Body).Decode(&out); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &out, nil
}
// Forget removes a memory via DELETE /v1/brain/forget/{id}.
// The id is path-escaped before being placed in the URL.
func (c *Client) Forget(ctx context.Context, id string) error {
	const op = "agentic.Client.Forget"

	if id == "" {
		return log.E(op, "memory ID is required", nil)
	}

	target := fmt.Sprintf("%s/v1/brain/forget/%s", c.BaseURL, url.PathEscape(id))
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, target, nil)
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(httpReq)

	res, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = res.Body.Close() }()

	// A 2xx status is all we need; the body is not decoded.
	if err := c.checkResponse(res); err != nil {
		return log.E(op, "API error", err)
	}
	return nil
}
// EnsureCollection ensures the Qdrant collection exists via
// POST /v1/brain/collections. The request has no body; success is any
// 2xx response.
func (c *Client) EnsureCollection(ctx context.Context) error {
	const op = "agentic.Client.EnsureCollection"

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+"/v1/brain/collections", nil)
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(httpReq)

	res, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = res.Body.Close() }()

	if err := c.checkResponse(res); err != nil {
		return log.E(op, "API error", err)
	}
	return nil
}

View file

@ -1,234 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestClient_Remember_Good verifies the happy path: Remember sends an
// authenticated JSON POST to /v1/brain/remember and decodes the 201 response.
func TestClient_Remember_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Pin the wire format produced by Client.Remember.
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/v1/brain/remember", r.URL.Path)
		assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		var req RememberRequest
		err := json.NewDecoder(r.Body).Decode(&req)
		require.NoError(t, err)
		assert.Equal(t, "Go uses structural typing", req.Content)
		assert.Equal(t, "fact", req.Type)
		assert.Equal(t, "go-agentic", req.Project)
		assert.Equal(t, []string{"go", "typing"}, req.Tags)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		_ = json.NewEncoder(w).Encode(RememberResponse{
			ID:        "mem-abc-123",
			Type:      "fact",
			Project:   "go-agentic",
			CreatedAt: "2026-03-03T12:00:00+00:00",
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.Remember(context.Background(), RememberRequest{
		Content: "Go uses structural typing",
		Type:    "fact",
		Project: "go-agentic",
		Tags:    []string{"go", "typing"},
	})
	require.NoError(t, err)
	assert.Equal(t, "mem-abc-123", result.ID)
	assert.Equal(t, "fact", result.Type)
	assert.Equal(t, "go-agentic", result.Project)
}

// TestClient_Remember_Bad_EmptyContent checks that missing content is
// rejected client-side, before any HTTP request is made.
func TestClient_Remember_Bad_EmptyContent(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	result, err := client.Remember(context.Background(), RememberRequest{
		Type: "fact",
	})
	assert.Error(t, err)
	assert.Nil(t, result)
	assert.Contains(t, err.Error(), "content is required")
}

// TestClient_Remember_Bad_EmptyType checks that a missing type is likewise
// rejected client-side.
func TestClient_Remember_Bad_EmptyType(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	result, err := client.Remember(context.Background(), RememberRequest{
		Content: "something",
	})
	assert.Error(t, err)
	assert.Nil(t, result)
	assert.Contains(t, err.Error(), "type is required")
}

// TestClient_Remember_Bad_ServerError verifies that a non-2xx response is
// surfaced as an error carrying the server's APIError message.
func TestClient_Remember_Bad_ServerError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnprocessableEntity)
		_ = json.NewEncoder(w).Encode(APIError{Message: "validation failed"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.Remember(context.Background(), RememberRequest{
		Content: "test",
		Type:    "fact",
	})
	assert.Error(t, err)
	assert.Nil(t, result)
	assert.Contains(t, err.Error(), "validation failed")
}
// TestClient_Recall_Good verifies the happy path: Recall POSTs the query
// to /v1/brain/recall and decodes memories plus their scores.
func TestClient_Recall_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/v1/brain/recall", r.URL.Path)
		var req RecallRequest
		err := json.NewDecoder(r.Body).Decode(&req)
		require.NoError(t, err)
		assert.Equal(t, "how does typing work in Go", req.Query)
		assert.Equal(t, 5, req.TopK)
		assert.Equal(t, "go-agentic", req.Project)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(RecallResponse{
			Memories: []Memory{
				{
					ID:         "mem-abc-123",
					Type:       "fact",
					Content:    "Go uses structural typing",
					Project:    "go-agentic",
					Confidence: 0.95,
				},
			},
			Scores: map[string]float64{
				"mem-abc-123": 0.87,
			},
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.Recall(context.Background(), RecallRequest{
		Query:   "how does typing work in Go",
		TopK:    5,
		Project: "go-agentic",
	})
	require.NoError(t, err)
	assert.Len(t, result.Memories, 1)
	assert.Equal(t, "mem-abc-123", result.Memories[0].ID)
	assert.Equal(t, "Go uses structural typing", result.Memories[0].Content)
	// Float score compared with a tolerance rather than exact equality.
	assert.InDelta(t, 0.87, result.Scores["mem-abc-123"], 0.001)
}

// TestClient_Recall_Good_EmptyResults confirms a well-formed empty result
// set decodes into empty (not nil-error) collections.
func TestClient_Recall_Good_EmptyResults(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(RecallResponse{
			Memories: []Memory{},
			Scores:   map[string]float64{},
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.Recall(context.Background(), RecallRequest{
		Query: "something obscure",
	})
	require.NoError(t, err)
	assert.Empty(t, result.Memories)
	assert.Empty(t, result.Scores)
}

// TestClient_Recall_Bad_EmptyQuery checks client-side validation of the
// mandatory query field.
func TestClient_Recall_Bad_EmptyQuery(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	result, err := client.Recall(context.Background(), RecallRequest{})
	assert.Error(t, err)
	assert.Nil(t, result)
	assert.Contains(t, err.Error(), "query is required")
}
// TestClient_Forget_Good verifies Forget issues an authenticated DELETE to
// /v1/brain/forget/{id}.
func TestClient_Forget_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodDelete, r.Method)
		assert.Equal(t, "/v1/brain/forget/mem-abc-123", r.URL.Path)
		assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]bool{"deleted": true})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.Forget(context.Background(), "mem-abc-123")
	assert.NoError(t, err)
}

// TestClient_Forget_Bad_EmptyID checks that an empty ID is rejected
// client-side.
func TestClient_Forget_Bad_EmptyID(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	err := client.Forget(context.Background(), "")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "memory ID is required")
}

// TestClient_Forget_Bad_NotFound verifies a 404 surfaces the server's
// APIError message.
func TestClient_Forget_Bad_NotFound(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		_ = json.NewEncoder(w).Encode(APIError{Message: "memory not found"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.Forget(context.Background(), "nonexistent")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "memory not found")
}

// TestClient_EnsureCollection_Good verifies EnsureCollection POSTs to
// /v1/brain/collections and treats 2xx as success.
func TestClient_EnsureCollection_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/v1/brain/collections", r.URL.Path)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.EnsureCollection(context.Background())
	assert.NoError(t, err)
}

// TestClient_EnsureCollection_Bad_ServerError verifies a 500 surfaces the
// server's APIError message.
func TestClient_EnsureCollection_Bad_ServerError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_ = json.NewEncoder(w).Encode(APIError{Message: "collection setup failed"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.EnsureCollection(context.Background())
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "collection setup failed")
}

View file

@ -1,359 +0,0 @@
package lifecycle
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"forge.lthn.ai/core/go-log"
)
// Client is the API client for the core-agentic service.
// Construct it with NewClient or NewClientFromConfig; the zero value is
// not usable (HTTPClient would be nil).
type Client struct {
	// BaseURL is the base URL of the API server (no trailing slash).
	BaseURL string
	// Token is the authentication token, sent as "Bearer <token>".
	Token string
	// HTTPClient is the HTTP client used for requests.
	HTTPClient *http.Client
	// AgentID is the identifier for this agent when claiming tasks;
	// when set, ClaimTask includes it in the request body.
	AgentID string
}
// NewClient creates a new agentic API client with the given base URL and
// token. Any trailing slash on baseURL is stripped so endpoint paths can
// be concatenated safely. The HTTP client uses a 30-second timeout.
func NewClient(baseURL, token string) *Client {
	c := &Client{
		BaseURL:    strings.TrimSuffix(baseURL, "/"),
		Token:      token,
		HTTPClient: &http.Client{Timeout: 30 * time.Second},
	}
	return c
}
// NewClientFromConfig creates a new client from a Config struct,
// copying the configured agent identifier onto the client.
func NewClientFromConfig(cfg *Config) *Client {
	c := NewClient(cfg.BaseURL, cfg.Token)
	c.AgentID = cfg.AgentID
	return c
}
// ListTasks retrieves a list of tasks matching the given options.
// Zero-valued filter fields are omitted from the query string.
func (c *Client) ListTasks(ctx context.Context, opts ListOptions) ([]Task, error) {
	const op = "agentic.Client.ListTasks"

	// Translate the non-zero filter fields into query parameters.
	q := url.Values{}
	setIf := func(key, val string) {
		if val != "" {
			q.Set(key, val)
		}
	}
	setIf("status", string(opts.Status))
	setIf("priority", string(opts.Priority))
	setIf("project", opts.Project)
	setIf("claimed_by", opts.ClaimedBy)
	if opts.Limit > 0 {
		q.Set("limit", strconv.Itoa(opts.Limit))
	}
	if len(opts.Labels) > 0 {
		q.Set("labels", strings.Join(opts.Labels, ","))
	}

	endpoint := c.BaseURL + "/api/tasks"
	if encoded := q.Encode(); encoded != "" {
		endpoint += "?" + encoded
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}

	var tasks []Task
	if err := json.NewDecoder(resp.Body).Decode(&tasks); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return tasks, nil
}
// GetTask retrieves a single task by its ID. The id is path-escaped
// before being placed in the URL; an empty id fails fast client-side.
func (c *Client) GetTask(ctx context.Context, id string) (*Task, error) {
	const op = "agentic.Client.GetTask"

	if id == "" {
		return nil, log.E(op, "task ID is required", nil)
	}

	target := fmt.Sprintf("%s/api/tasks/%s", c.BaseURL, url.PathEscape(id))
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}

	task := new(Task)
	if err := json.NewDecoder(resp.Body).Decode(task); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return task, nil
}
// ClaimTask claims a task for the current agent via
// POST /api/tasks/{id}/claim.
//
// When AgentID is set, it is sent as a JSON body ({"agent_id": ...}) so
// the server can record who claimed the task. The response may be either
// a ClaimResponse wrapper or a bare Task; both shapes are accepted.
func (c *Client) ClaimTask(ctx context.Context, id string) (*Task, error) {
	const op = "agentic.Client.ClaimTask"
	if id == "" {
		return nil, log.E(op, "task ID is required", nil)
	}
	endpoint := fmt.Sprintf("%s/api/tasks/%s/claim", c.BaseURL, url.PathEscape(id))
	// Include agent ID in the claim request if available.
	var body io.Reader
	if c.AgentID != "" {
		// Handle the marshal error explicitly rather than discarding it.
		data, err := json.Marshal(map[string]string{"agent_id": c.AgentID})
		if err != nil {
			return nil, log.E(op, "failed to marshal request", err)
		}
		body = bytes.NewReader(data)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	// Read body once to allow multiple decode attempts.
	bodyData, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, log.E(op, "failed to read response", err)
	}
	// Try decoding as ClaimResponse first (wrapped shape).
	var result ClaimResponse
	if err := json.Unmarshal(bodyData, &result); err == nil && result.Task != nil {
		return result.Task, nil
	}
	// Fall back to decoding as a bare Task for simpler API responses.
	var task Task
	if err := json.Unmarshal(bodyData, &task); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &task, nil
}
// UpdateTask updates a task with new status, progress, or notes via
// PATCH /api/tasks/{id}. Success is any 2xx response; the body is ignored.
func (c *Client) UpdateTask(ctx context.Context, id string, update TaskUpdate) error {
	const op = "agentic.Client.UpdateTask"

	if id == "" {
		return log.E(op, "task ID is required", nil)
	}

	payload, err := json.Marshal(update)
	if err != nil {
		return log.E(op, "failed to marshal update", err)
	}

	target := fmt.Sprintf("%s/api/tasks/%s", c.BaseURL, url.PathEscape(id))
	req, err := http.NewRequestWithContext(ctx, http.MethodPatch, target, bytes.NewReader(payload))
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if err := c.checkResponse(resp); err != nil {
		return log.E(op, "API error", err)
	}
	return nil
}
// CompleteTask marks a task as completed with the given result via
// POST /api/tasks/{id}/complete.
func (c *Client) CompleteTask(ctx context.Context, id string, result TaskResult) error {
	const op = "agentic.Client.CompleteTask"

	if id == "" {
		return log.E(op, "task ID is required", nil)
	}

	payload, err := json.Marshal(result)
	if err != nil {
		return log.E(op, "failed to marshal result", err)
	}

	target := fmt.Sprintf("%s/api/tasks/%s/complete", c.BaseURL, url.PathEscape(id))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, target, bytes.NewReader(payload))
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if err := c.checkResponse(resp); err != nil {
		return log.E(op, "API error", err)
	}
	return nil
}
// setHeaders adds the headers common to every API request: bearer auth,
// JSON accept, and the client's user agent.
func (c *Client) setHeaders(req *http.Request) {
	h := req.Header
	h.Set("Authorization", "Bearer "+c.Token)
	h.Set("Accept", "application/json")
	h.Set("User-Agent", "core-agentic-client/1.0")
}
// checkResponse converts a non-2xx response into an *APIError.
// It returns nil for any 2xx status. For errors it tries to decode the
// body as a server-provided APIError; otherwise it synthesises a generic
// one carrying the raw body as Details.
func (c *Client) checkResponse(resp *http.Response) error {
	code := resp.StatusCode
	if code >= 200 && code < 300 {
		return nil
	}

	// Best effort: an unreadable body still yields a useful error.
	body, _ := io.ReadAll(resp.Body)

	var apiErr APIError
	if json.Unmarshal(body, &apiErr) == nil && apiErr.Message != "" {
		apiErr.Code = code
		return &apiErr
	}

	return &APIError{
		Code:    code,
		Message: fmt.Sprintf("HTTP %d: %s", code, http.StatusText(code)),
		Details: string(body),
	}
}
// CreateTask creates a new task via POST /api/tasks and returns the
// server's (possibly augmented) copy of the task.
func (c *Client) CreateTask(ctx context.Context, task Task) (*Task, error) {
	const op = "agentic.Client.CreateTask"

	payload, err := json.Marshal(task)
	if err != nil {
		return nil, log.E(op, "failed to marshal task", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+"/api/tasks", bytes.NewReader(payload))
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}

	created := new(Task)
	if err := json.NewDecoder(resp.Body).Decode(created); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return created, nil
}
// Ping tests the connection to the API server via GET /v1/health.
//
// Unlike the previous implementation (which accepted any status below 400
// and discarded the response detail), this routes the response through
// checkResponse so only 2xx counts as healthy and the server's APIError
// message is preserved — consistent with every other method on Client.
func (c *Client) Ping(ctx context.Context) error {
	const op = "agentic.Client.Ping"
	endpoint := c.BaseURL + "/v1/health"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return log.E(op, "API error", err)
	}
	return nil
}

View file

@ -1,356 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test fixtures

// testTask is the canonical single-task fixture used across the client
// tests: a pending, high-priority feature task in the "core" project.
var testTask = Task{
	ID:          "task-123",
	Title:       "Implement feature X",
	Description: "Add the new feature X to the system",
	Priority:    PriorityHigh,
	Status:      StatusPending,
	Labels:      []string{"feature", "backend"},
	Files:       []string{"pkg/feature/feature.go"},
	CreatedAt:   time.Now().Add(-24 * time.Hour),
	Project:     "core",
}

// testTasks pairs testTask with a second, critical-priority bug task for
// list-endpoint tests.
var testTasks = []Task{
	testTask,
	{
		ID:          "task-456",
		Title:       "Fix bug Y",
		Description: "Fix the bug in component Y",
		Priority:    PriorityCritical,
		Status:      StatusPending,
		Labels:      []string{"bug", "urgent"},
		CreatedAt:   time.Now().Add(-2 * time.Hour),
		Project:     "core",
	},
}
// TestNewClient_Good verifies constructor defaults: fields are stored and
// an HTTP client is provided.
func TestNewClient_Good(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	assert.Equal(t, "https://api.example.com", client.BaseURL)
	assert.Equal(t, "test-token", client.Token)
	assert.NotNil(t, client.HTTPClient)
}

// TestNewClient_Good_TrailingSlash verifies a trailing slash on the base
// URL is stripped so endpoint concatenation stays clean.
func TestNewClient_Good_TrailingSlash(t *testing.T) {
	client := NewClient("https://api.example.com/", "test-token")
	assert.Equal(t, "https://api.example.com", client.BaseURL)
}

// TestNewClientFromConfig_Good verifies the config constructor copies the
// base URL, token, and agent ID.
func TestNewClientFromConfig_Good(t *testing.T) {
	cfg := &Config{
		BaseURL: "https://api.example.com",
		Token:   "config-token",
		AgentID: "agent-001",
	}
	client := NewClientFromConfig(cfg)
	assert.Equal(t, "https://api.example.com", client.BaseURL)
	assert.Equal(t, "config-token", client.Token)
	assert.Equal(t, "agent-001", client.AgentID)
}
// TestClient_ListTasks_Good verifies the happy path: an authenticated GET
// to /api/tasks decodes the task list.
func TestClient_ListTasks_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodGet, r.Method)
		assert.Equal(t, "/api/tasks", r.URL.Path)
		assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(testTasks)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	tasks, err := client.ListTasks(context.Background(), ListOptions{})
	require.NoError(t, err)
	assert.Len(t, tasks, 2)
	assert.Equal(t, "task-123", tasks[0].ID)
	assert.Equal(t, "task-456", tasks[1].ID)
}

// TestClient_ListTasks_Good_WithFilters pins the exact query-parameter
// names and encodings produced by each ListOptions field.
func TestClient_ListTasks_Good_WithFilters(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()
		assert.Equal(t, "pending", query.Get("status"))
		assert.Equal(t, "high", query.Get("priority"))
		assert.Equal(t, "core", query.Get("project"))
		assert.Equal(t, "10", query.Get("limit"))
		// Labels are comma-joined into a single parameter.
		assert.Equal(t, "bug,urgent", query.Get("labels"))
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode([]Task{testTask})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	opts := ListOptions{
		Status:   StatusPending,
		Priority: PriorityHigh,
		Project:  "core",
		Limit:    10,
		Labels:   []string{"bug", "urgent"},
	}
	tasks, err := client.ListTasks(context.Background(), opts)
	require.NoError(t, err)
	assert.Len(t, tasks, 1)
}

// TestClient_ListTasks_Bad_ServerError verifies a 500 surfaces the
// server's APIError message and returns no tasks.
func TestClient_ListTasks_Bad_ServerError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_ = json.NewEncoder(w).Encode(APIError{Message: "internal error"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	tasks, err := client.ListTasks(context.Background(), ListOptions{})
	assert.Error(t, err)
	assert.Nil(t, tasks)
	assert.Contains(t, err.Error(), "internal error")
}
// TestClient_GetTask_Good verifies a single task is fetched from
// /api/tasks/{id} and decoded.
func TestClient_GetTask_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodGet, r.Method)
		assert.Equal(t, "/api/tasks/task-123", r.URL.Path)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(testTask)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task, err := client.GetTask(context.Background(), "task-123")
	require.NoError(t, err)
	assert.Equal(t, "task-123", task.ID)
	assert.Equal(t, "Implement feature X", task.Title)
	assert.Equal(t, PriorityHigh, task.Priority)
}

// TestClient_GetTask_Bad_EmptyID checks client-side validation of the ID.
func TestClient_GetTask_Bad_EmptyID(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	task, err := client.GetTask(context.Background(), "")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "task ID is required")
}

// TestClient_GetTask_Bad_NotFound verifies a 404 surfaces the server's
// APIError message.
func TestClient_GetTask_Bad_NotFound(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		_ = json.NewEncoder(w).Encode(APIError{Message: "task not found"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task, err := client.GetTask(context.Background(), "nonexistent")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "task not found")
}
// TestClient_ClaimTask_Good verifies claiming via POST
// /api/tasks/{id}/claim with the wrapped ClaimResponse shape.
func TestClient_ClaimTask_Good(t *testing.T) {
	claimedTask := testTask
	claimedTask.Status = StatusInProgress
	claimedTask.ClaimedBy = "agent-001"
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/api/tasks/task-123/claim", r.URL.Path)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimedTask})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	client.AgentID = "agent-001"
	task, err := client.ClaimTask(context.Background(), "task-123")
	require.NoError(t, err)
	assert.Equal(t, StatusInProgress, task.Status)
	assert.Equal(t, "agent-001", task.ClaimedBy)
}

// TestClient_ClaimTask_Good_SimpleResponse verifies the fallback decode
// path when the server returns a bare Task instead of a ClaimResponse.
func TestClient_ClaimTask_Good_SimpleResponse(t *testing.T) {
	// Some APIs might return just the task without wrapping
	claimedTask := testTask
	claimedTask.Status = StatusInProgress
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(claimedTask)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task, err := client.ClaimTask(context.Background(), "task-123")
	require.NoError(t, err)
	assert.Equal(t, "task-123", task.ID)
}

// TestClient_ClaimTask_Bad_EmptyID checks client-side validation of the ID.
func TestClient_ClaimTask_Bad_EmptyID(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	task, err := client.ClaimTask(context.Background(), "")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "task ID is required")
}

// TestClient_ClaimTask_Bad_AlreadyClaimed verifies a 409 conflict surfaces
// the server's APIError message.
func TestClient_ClaimTask_Bad_AlreadyClaimed(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusConflict)
		_ = json.NewEncoder(w).Encode(APIError{Message: "task already claimed"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task, err := client.ClaimTask(context.Background(), "task-123")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "task already claimed")
}
// TestClient_UpdateTask_Good verifies updates are sent as a JSON PATCH to
// /api/tasks/{id} carrying status and progress.
func TestClient_UpdateTask_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPatch, r.Method)
		assert.Equal(t, "/api/tasks/task-123", r.URL.Path)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		var update TaskUpdate
		err := json.NewDecoder(r.Body).Decode(&update)
		require.NoError(t, err)
		assert.Equal(t, StatusInProgress, update.Status)
		assert.Equal(t, 50, update.Progress)
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.UpdateTask(context.Background(), "task-123", TaskUpdate{
		Status:   StatusInProgress,
		Progress: 50,
		Notes:    "Making progress",
	})
	assert.NoError(t, err)
}

// TestClient_UpdateTask_Bad_EmptyID checks client-side validation of the ID.
func TestClient_UpdateTask_Bad_EmptyID(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	err := client.UpdateTask(context.Background(), "", TaskUpdate{})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "task ID is required")
}

// TestClient_CompleteTask_Good verifies completion is a JSON POST to
// /api/tasks/{id}/complete carrying the TaskResult.
func TestClient_CompleteTask_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/api/tasks/task-123/complete", r.URL.Path)
		var result TaskResult
		err := json.NewDecoder(r.Body).Decode(&result)
		require.NoError(t, err)
		assert.True(t, result.Success)
		assert.Equal(t, "Feature implemented", result.Output)
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.CompleteTask(context.Background(), "task-123", TaskResult{
		Success:   true,
		Output:    "Feature implemented",
		Artifacts: []string{"pkg/feature/feature.go"},
	})
	assert.NoError(t, err)
}

// TestClient_CompleteTask_Bad_EmptyID checks client-side validation of the ID.
func TestClient_CompleteTask_Bad_EmptyID(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	err := client.CompleteTask(context.Background(), "", TaskResult{})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "task ID is required")
}
// TestClient_Ping_Good verifies Ping hits /v1/health and accepts a 200.
func TestClient_Ping_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/v1/health", r.URL.Path)
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.Ping(context.Background())
	assert.NoError(t, err)
}

// TestClient_Ping_Bad_ServerDown verifies a transport-level failure (the
// URL is unreachable/invalid) is wrapped as "request failed".
func TestClient_Ping_Bad_ServerDown(t *testing.T) {
	client := NewClient("http://localhost:99999", "test-token")
	client.HTTPClient.Timeout = 100 * time.Millisecond
	err := client.Ping(context.Background())
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "request failed")
}

// TestAPIError_Error_Good pins APIError's message format with and without
// the Details suffix.
func TestAPIError_Error_Good(t *testing.T) {
	err := &APIError{
		Code:    404,
		Message: "task not found",
	}
	assert.Equal(t, "task not found", err.Error())
	err.Details = "task-123 does not exist"
	assert.Equal(t, "task not found: task-123 does not exist", err.Error())
}

// TestTaskStatus_Good pins the wire values of the status constants.
func TestTaskStatus_Good(t *testing.T) {
	assert.Equal(t, TaskStatus("pending"), StatusPending)
	assert.Equal(t, TaskStatus("in_progress"), StatusInProgress)
	assert.Equal(t, TaskStatus("completed"), StatusCompleted)
	assert.Equal(t, TaskStatus("blocked"), StatusBlocked)
}

// TestTaskPriority_Good pins the wire values of the priority constants.
func TestTaskPriority_Good(t *testing.T) {
	assert.Equal(t, TaskPriority("critical"), PriorityCritical)
	assert.Equal(t, TaskPriority("high"), PriorityHigh)
	assert.Equal(t, TaskPriority("medium"), PriorityMedium)
	assert.Equal(t, TaskPriority("low"), PriorityLow)
}

View file

@ -1,338 +0,0 @@
// Package agentic provides AI collaboration features for task management.
package lifecycle
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"forge.lthn.ai/core/go-log"
)
// PROptions contains options for creating a pull request.
// Empty Title/Body fall back to values derived from the task (see
// CreatePR); other zero values mean "omit the corresponding gh flag".
type PROptions struct {
	// Title is the PR title.
	Title string `json:"title"`
	// Body is the PR description.
	Body string `json:"body"`
	// Draft marks the PR as a draft.
	Draft bool `json:"draft"`
	// Labels are labels to add to the PR.
	Labels []string `json:"labels"`
	// Base is the base branch (defaults to main).
	Base string `json:"base"`
}
// AutoCommit creates a git commit with a task reference.
// The commit message follows the format:
//
//	feat(scope): description
//
//	Task: #123
//	Co-Authored-By: Claude <noreply@anthropic.com>
//
// All pending changes in dir are staged ("git add -A") before committing.
func AutoCommit(ctx context.Context, task *Task, dir string, message string) error {
	const op = "agentic.AutoCommit"

	switch {
	case task == nil:
		return log.E(op, "task is required", nil)
	case message == "":
		return log.E(op, "commit message is required", nil)
	}

	commitMsg := buildCommitMessage(task, message)

	// Stage everything, then commit with the assembled message.
	if _, err := runGitCommandCtx(ctx, dir, "add", "-A"); err != nil {
		return log.E(op, "failed to stage changes", err)
	}
	if _, err := runGitCommandCtx(ctx, dir, "commit", "-m", commitMsg); err != nil {
		return log.E(op, "failed to create commit", err)
	}
	return nil
}
// buildCommitMessage formats a commit message with task reference:
// the caller's message, a blank line, the "Task: #<id>" trailer, and the
// Claude co-author line.
func buildCommitMessage(task *Task, message string) string {
	return fmt.Sprintf("%s\n\nTask: #%s\nCo-Authored-By: Claude <noreply@anthropic.com>\n", message, task.ID)
}
// CreatePR creates a pull request using the gh CLI and returns the PR URL
// printed by "gh pr create". Empty Title/Body options fall back to the
// task's title and a generated body.
func CreatePR(ctx context.Context, task *Task, dir string, opts PROptions) (string, error) {
	const op = "agentic.CreatePR"

	if task == nil {
		return "", log.E(op, "task is required", nil)
	}

	// Fall back to task-derived title/body when not supplied.
	title, body := opts.Title, opts.Body
	if title == "" {
		title = task.Title
	}
	if body == "" {
		body = buildPRBody(task)
	}

	// Assemble the gh invocation, appending optional flags.
	args := []string{"pr", "create", "--title", title, "--body", body}
	if opts.Draft {
		args = append(args, "--draft")
	}
	if opts.Base != "" {
		args = append(args, "--base", opts.Base)
	}
	for _, label := range opts.Labels {
		args = append(args, "--label", label)
	}

	out, err := runCommandCtx(ctx, dir, "gh", args...)
	if err != nil {
		return "", log.E(op, "failed to create PR", err)
	}
	// gh prints the new PR URL on stdout.
	return strings.TrimSpace(out), nil
}
// buildPRBody renders the default PR description from the task: a summary
// section, a task-reference section (ID, priority, optional labels), and an
// AI-assistance footer.
func buildPRBody(task *Task) string {
	var b strings.Builder
	fmt.Fprintf(&b, "## Summary\n\n%s\n\n", task.Description)
	b.WriteString("## Task Reference\n\n")
	fmt.Fprintf(&b, "- Task ID: #%s\n", task.ID)
	fmt.Fprintf(&b, "- Priority: %s\n", string(task.Priority))
	// The labels line is omitted entirely when the task has none.
	if len(task.Labels) > 0 {
		fmt.Fprintf(&b, "- Labels: %s\n", strings.Join(task.Labels, ", "))
	}
	b.WriteString("\n---\n")
	b.WriteString("Generated with AI assistance\n")
	return b.String()
}
// SyncStatus pushes a task status update back to the agentic service.
// Both client and task must be non-nil.
func SyncStatus(ctx context.Context, client *Client, task *Task, update TaskUpdate) error {
	const op = "agentic.SyncStatus"
	switch {
	case client == nil:
		return log.E(op, "client is required", nil)
	case task == nil:
		return log.E(op, "task is required", nil)
	}
	return client.UpdateTask(ctx, task.ID, update)
}
// CommitAndSync commits changes and syncs task status.
//
// It first commits via AutoCommit. When a non-nil client is supplied, it then
// reports StatusInProgress with the given progress percentage and a
// "Committed: <message>" note; a nil client skips the sync step entirely.
//
// A sync failure after a successful commit IS returned as an error (the
// commit itself remains in place) — callers and the package tests rely on
// that behaviour.
func CommitAndSync(ctx context.Context, client *Client, task *Task, dir string, message string, progress int) error {
	const op = "agentic.CommitAndSync"
	// Create commit
	if err := AutoCommit(ctx, task, dir, message); err != nil {
		return log.E(op, "failed to commit", err)
	}
	// Sync status if client provided
	if client != nil {
		update := TaskUpdate{
			Status:   StatusInProgress,
			Progress: progress,
			Notes:    "Committed: " + message,
		}
		if err := SyncStatus(ctx, client, task, update); err != nil {
			// Commit succeeded; surface the sync failure to the caller.
			return log.E(op, "commit succeeded but sync failed", err)
		}
	}
	return nil
}
// PushChanges pushes the current branch's committed work to the remote.
func PushChanges(ctx context.Context, dir string) error {
	const op = "agentic.PushChanges"
	if _, err := runGitCommandCtx(ctx, dir, "push"); err != nil {
		return log.E(op, "failed to push changes", err)
	}
	return nil
}
// CreateBranch creates and checks out a task-specific branch in dir.
// The branch name is derived from the task's labels, ID, and title
// (see generateBranchName) and is returned on success.
func CreateBranch(ctx context.Context, task *Task, dir string) (string, error) {
	const op = "agentic.CreateBranch"
	if task == nil {
		return "", log.E(op, "task is required", nil)
	}
	name := generateBranchName(task)
	if _, err := runGitCommandCtx(ctx, dir, "checkout", "-b", name); err != nil {
		return "", log.E(op, "failed to create branch", err)
	}
	return name, nil
}
// generateBranchName derives a branch name of the form
//
//	<prefix>/<task-id>-<slug>
//
// where prefix is chosen from the task's labels (the last recognised label
// wins; default "feat") and slug is the lower-cased title reduced to
// [a-z0-9-], capped at 40 characters.
func generateBranchName(task *Task) string {
	// Pick the conventional-commit prefix; a later matching label overrides
	// an earlier one.
	prefix := "feat"
	for _, label := range task.Labels {
		switch strings.ToLower(label) {
		case "bug", "bugfix", "fix":
			prefix = "fix"
		case "docs", "documentation":
			prefix = "docs"
		case "refactor":
			prefix = "refactor"
		case "test", "tests":
			prefix = "test"
		case "chore":
			prefix = "chore"
		}
	}
	// Slugify: keep lowercase alphanumerics, fold separators to '-', drop
	// everything else.
	slug := strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
			return r
		case r == ' ', r == '-', r == '_':
			return '-'
		default:
			return -1
		}
	}, strings.ToLower(task.Title))
	// Collapse runs of dashes, then strip them from both ends.
	for strings.Contains(slug, "--") {
		slug = strings.ReplaceAll(slug, "--", "-")
	}
	slug = strings.Trim(slug, "-")
	// Cap the slug length (ASCII-only at this point, so a byte cut is safe),
	// avoiding a dangling dash after the cut.
	if len(slug) > 40 {
		slug = strings.TrimRight(slug[:40], "-")
	}
	return prefix + "/" + task.ID + "-" + slug
}
// runGitCommandCtx runs a git subcommand in dir with the given context and
// returns its stdout. It is a thin wrapper over runCommandCtx.
func runGitCommandCtx(ctx context.Context, dir string, args ...string) (string, error) {
	return runCommandCtx(ctx, dir, "git", args...)
}
// runCommandCtx executes command with args in dir, honouring ctx for
// cancellation, and returns the captured stdout.
//
// On failure, stderr (when non-empty) is attached to the returned error so
// callers see the tool's own diagnostic; otherwise the raw exec error is
// returned unchanged.
func runCommandCtx(ctx context.Context, dir string, command string, args ...string) (string, error) {
	// Operation label follows the package-wide "agentic.<fn>" convention
	// used by every other function in this file.
	const op = "agentic.runCommandCtx"
	cmd := exec.CommandContext(ctx, command, args...)
	cmd.Dir = dir
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		if stderr.Len() > 0 {
			return "", log.E(op, stderr.String(), err)
		}
		return "", err
	}
	return stdout.String(), nil
}
// GetCurrentBranch reports the name of the branch currently checked out
// in dir.
func GetCurrentBranch(ctx context.Context, dir string) (string, error) {
	const op = "agentic.GetCurrentBranch"
	out, err := runGitCommandCtx(ctx, dir, "rev-parse", "--abbrev-ref", "HEAD")
	if err != nil {
		return "", log.E(op, "failed to get current branch", err)
	}
	// git prints the branch name followed by a newline.
	return strings.TrimSpace(out), nil
}
// HasUncommittedChanges reports whether dir's working tree has any staged,
// unstaged, or untracked changes (non-empty `git status --porcelain`).
func HasUncommittedChanges(ctx context.Context, dir string) (bool, error) {
	const op = "agentic.HasUncommittedChanges"
	out, err := runGitCommandCtx(ctx, dir, "status", "--porcelain")
	if err != nil {
		return false, log.E(op, "failed to get git status", err)
	}
	dirty := strings.TrimSpace(out) != ""
	return dirty, nil
}
// GetDiff returns the textual diff for dir: the staged (index) diff when
// staged is true, otherwise the unstaged working-tree diff.
func GetDiff(ctx context.Context, dir string, staged bool) (string, error) {
	const op = "agentic.GetDiff"
	args := []string{"diff"}
	if staged {
		args = append(args, "--staged")
	}
	out, err := runGitCommandCtx(ctx, dir, args...)
	if err != nil {
		return "", log.E(op, "failed to get diff", err)
	}
	return out, nil
}

View file

@ -1,474 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// initGitRepo creates a throwaway git repository with a configured identity
// and one initial commit, returning its path. t.TempDir handles cleanup.
func initGitRepo(t *testing.T) string {
	t.Helper()
	ctx := context.Background()
	dir := t.TempDir()
	// Small closure to keep the git plumbing readable.
	run := func(args ...string) {
		t.Helper()
		_, err := runCommandCtx(ctx, dir, "git", args...)
		require.NoError(t, err)
	}
	run("init")
	// A committer identity is required for the initial commit.
	run("config", "user.email", "test@example.com")
	run("config", "user.name", "Test User")
	// Seed one file and commit so HEAD exists.
	readmePath := filepath.Join(dir, "README.md")
	require.NoError(t, os.WriteFile(readmePath, []byte("# Test Repo\n"), 0644))
	run("add", "-A")
	run("commit", "-m", "initial commit")
	return dir
}
// --- runCommandCtx / runGitCommandCtx tests ---

// Verifies stdout is captured for a successful command.
func TestRunCommandCtx_Good(t *testing.T) {
	output, err := runCommandCtx(context.Background(), "/tmp", "echo", "hello world")
	require.NoError(t, err)
	assert.Contains(t, output, "hello world")
}

// A command that does not exist must surface an error.
func TestRunCommandCtx_Bad_NonexistentCommand(t *testing.T) {
	_, err := runCommandCtx(context.Background(), "/tmp", "nonexistent-command-xyz")
	assert.Error(t, err)
}

// A command that exits non-zero must surface an error.
func TestRunCommandCtx_Bad_CommandFails(t *testing.T) {
	_, err := runCommandCtx(context.Background(), "/tmp", "false")
	assert.Error(t, err)
}

func TestRunCommandCtx_Bad_StderrIncluded(t *testing.T) {
	// git status in a non-git directory should produce stderr.
	dir := t.TempDir()
	_, err := runCommandCtx(context.Background(), dir, "git", "status")
	assert.Error(t, err)
}

// Happy path for the git-specific wrapper.
func TestRunGitCommandCtx_Good(t *testing.T) {
	dir := initGitRepo(t)
	output, err := runGitCommandCtx(context.Background(), dir, "log", "--oneline", "-1")
	require.NoError(t, err)
	assert.Contains(t, output, "initial commit")
}
// --- GetCurrentBranch tests ---

func TestGetCurrentBranch_Good(t *testing.T) {
	dir := initGitRepo(t)
	branch, err := GetCurrentBranch(context.Background(), dir)
	require.NoError(t, err)
	// Depending on git config, default branch could be master or main.
	assert.True(t, branch == "main" || branch == "master",
		"expected main or master, got %q", branch)
}

// Outside a repository the lookup must fail with the wrapped error message.
func TestGetCurrentBranch_Bad_NotAGitRepo(t *testing.T) {
	dir := t.TempDir()
	_, err := GetCurrentBranch(context.Background(), dir)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "failed to get current branch")
}
// --- HasUncommittedChanges tests ---

// A freshly-committed repo reports no pending changes.
func TestHasUncommittedChanges_Good_Clean(t *testing.T) {
	dir := initGitRepo(t)
	hasChanges, err := HasUncommittedChanges(context.Background(), dir)
	require.NoError(t, err)
	assert.False(t, hasChanges, "fresh repo with initial commit should be clean")
}

// An untracked file counts as an uncommitted change.
func TestHasUncommittedChanges_Good_WithChanges(t *testing.T) {
	dir := initGitRepo(t)
	// Create a new file.
	err := os.WriteFile(filepath.Join(dir, "new-file.txt"), []byte("content"), 0644)
	require.NoError(t, err)
	hasChanges, err := HasUncommittedChanges(context.Background(), dir)
	require.NoError(t, err)
	assert.True(t, hasChanges, "should detect untracked file")
}

// A modification to a tracked file counts as an uncommitted change.
func TestHasUncommittedChanges_Good_WithModifiedFile(t *testing.T) {
	dir := initGitRepo(t)
	// Modify the existing README.
	err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# Updated\n"), 0644)
	require.NoError(t, err)
	hasChanges, err := HasUncommittedChanges(context.Background(), dir)
	require.NoError(t, err)
	assert.True(t, hasChanges, "should detect modified file")
}

// Outside a repository the status check must fail.
func TestHasUncommittedChanges_Bad_NotAGitRepo(t *testing.T) {
	dir := t.TempDir()
	_, err := HasUncommittedChanges(context.Background(), dir)
	assert.Error(t, err)
}
// --- GetDiff tests ---

// Unstaged modifications must appear in the working-tree diff.
func TestGetDiff_Good_Unstaged(t *testing.T) {
	dir := initGitRepo(t)
	// Modify a tracked file.
	err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# Modified\n"), 0644)
	require.NoError(t, err)
	diff, err := GetDiff(context.Background(), dir, false)
	require.NoError(t, err)
	assert.Contains(t, diff, "Modified", "diff should show the change")
}

// Staged modifications must appear when staged=true.
func TestGetDiff_Good_Staged(t *testing.T) {
	dir := initGitRepo(t)
	// Modify and stage a file.
	err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# Staged change\n"), 0644)
	require.NoError(t, err)
	_, err = runCommandCtx(context.Background(), dir, "git", "add", "README.md")
	require.NoError(t, err)
	diff, err := GetDiff(context.Background(), dir, true)
	require.NoError(t, err)
	assert.Contains(t, diff, "Staged change", "staged diff should show the change")
}

// A clean tree produces an empty diff, not an error.
func TestGetDiff_Good_NoDiff(t *testing.T) {
	dir := initGitRepo(t)
	diff, err := GetDiff(context.Background(), dir, false)
	require.NoError(t, err)
	assert.Empty(t, diff, "clean repo should have no diff")
}

func TestGetDiff_Bad_NotAGitRepo(t *testing.T) {
	dir := t.TempDir()
	_, err := GetDiff(context.Background(), dir, false)
	assert.Error(t, err)
}
// --- AutoCommit tests (with real git) ---

// Happy path: AutoCommit stages the new file and the resulting commit
// carries both the task reference and the co-author trailer.
func TestAutoCommit_Good(t *testing.T) {
	dir := initGitRepo(t)
	// Create a file to commit.
	err := os.WriteFile(filepath.Join(dir, "feature.go"), []byte("package main\n"), 0644)
	require.NoError(t, err)
	task := &Task{ID: "T-100", Title: "Add feature"}
	err = AutoCommit(context.Background(), task, dir, "feat: add feature module")
	require.NoError(t, err)
	// Verify commit was created.
	output, err := runGitCommandCtx(context.Background(), dir, "log", "--oneline", "-1")
	require.NoError(t, err)
	assert.Contains(t, output, "feat: add feature module")
	// Verify task reference in full message.
	fullLog, err := runGitCommandCtx(context.Background(), dir, "log", "-1", "--pretty=format:%B")
	require.NoError(t, err)
	assert.Contains(t, fullLog, "Task: #T-100")
	assert.Contains(t, fullLog, "Co-Authored-By: Claude <noreply@anthropic.com>")
}

// git itself rejects an empty commit; AutoCommit must surface that failure.
func TestAutoCommit_Bad_NoChangesToCommit(t *testing.T) {
	dir := initGitRepo(t)
	// No changes to commit.
	task := &Task{ID: "T-200", Title: "No changes"}
	err := AutoCommit(context.Background(), task, dir, "feat: nothing")
	assert.Error(t, err, "should fail when there is nothing to commit")
}
// --- CreateBranch tests (with real git) ---

// An unrecognised "enhancement" label keeps the default "feat" prefix, and
// the repo is left checked out on the new branch.
func TestCreateBranch_Good(t *testing.T) {
	dir := initGitRepo(t)
	task := &Task{
		ID:     "BR-42",
		Title:  "Implement new feature",
		Labels: []string{"enhancement"},
	}
	branchName, err := CreateBranch(context.Background(), task, dir)
	require.NoError(t, err)
	assert.Equal(t, "feat/BR-42-implement-new-feature", branchName)
	// Verify we're on the new branch.
	currentBranch, err := GetCurrentBranch(context.Background(), dir)
	require.NoError(t, err)
	assert.Equal(t, branchName, currentBranch)
}

// A "bug" label selects the "fix" branch prefix.
func TestCreateBranch_Good_BugLabel(t *testing.T) {
	dir := initGitRepo(t)
	task := &Task{
		ID:     "BR-43",
		Title:  "Fix login bug",
		Labels: []string{"bug"},
	}
	branchName, err := CreateBranch(context.Background(), task, dir)
	require.NoError(t, err)
	assert.Equal(t, "fix/BR-43-fix-login-bug", branchName)
}
// --- PushChanges test ---

// Without a configured remote, `git push` fails and the error is wrapped.
func TestPushChanges_Bad_NoRemote(t *testing.T) {
	dir := initGitRepo(t)
	// No remote configured, push should fail.
	err := PushChanges(context.Background(), dir)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "failed to push changes")
}

// --- CommitAndSync tests ---

// A nil client commits locally and skips the status sync entirely.
func TestCommitAndSync_Good_WithoutClient(t *testing.T) {
	dir := initGitRepo(t)
	// Create a file to commit.
	err := os.WriteFile(filepath.Join(dir, "sync.go"), []byte("package sync\n"), 0644)
	require.NoError(t, err)
	task := &Task{ID: "CS-1", Title: "Sync test"}
	// nil client: should commit but skip sync.
	err = CommitAndSync(context.Background(), nil, task, dir, "feat: sync test", 50)
	require.NoError(t, err)
	// Verify commit.
	output, err := runGitCommandCtx(context.Background(), dir, "log", "--oneline", "-1")
	require.NoError(t, err)
	assert.Contains(t, output, "feat: sync test")
}

// With a client, the commit is followed by a PATCH carrying the progress
// update; the fake server records what was sent.
func TestCommitAndSync_Good_WithClient(t *testing.T) {
	dir := initGitRepo(t)
	// Create a file to commit.
	err := os.WriteFile(filepath.Join(dir, "synced.go"), []byte("package synced\n"), 0644)
	require.NoError(t, err)
	var receivedUpdate TaskUpdate
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPatch {
			_ = json.NewDecoder(r.Body).Decode(&receivedUpdate)
			w.WriteHeader(http.StatusOK)
		}
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task := &Task{ID: "CS-2", Title: "Sync with client"}
	err = CommitAndSync(context.Background(), client, task, dir, "feat: synced", 75)
	require.NoError(t, err)
	// Verify the update was sent.
	assert.Equal(t, StatusInProgress, receivedUpdate.Status)
	assert.Equal(t, 75, receivedUpdate.Progress)
	assert.Contains(t, receivedUpdate.Notes, "feat: synced")
}

// A failed commit aborts the flow before any sync happens.
func TestCommitAndSync_Bad_CommitFails(t *testing.T) {
	dir := initGitRepo(t)
	// No changes to commit.
	task := &Task{ID: "CS-3", Title: "Will fail"}
	err := CommitAndSync(context.Background(), nil, task, dir, "feat: no changes", 50)
	assert.Error(t, err, "should fail when commit fails")
}

// A sync failure after a successful commit is surfaced as an error.
func TestCommitAndSync_Bad_SyncFails(t *testing.T) {
	dir := initGitRepo(t)
	// Create a file to commit.
	err := os.WriteFile(filepath.Join(dir, "fail-sync.go"), []byte("package failsync\n"), 0644)
	require.NoError(t, err)
	// Server returns an error.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_ = json.NewEncoder(w).Encode(APIError{Message: "sync failed"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task := &Task{ID: "CS-4", Title: "Sync fails"}
	err = CommitAndSync(context.Background(), client, task, dir, "feat: sync-fail", 50)
	assert.Error(t, err, "should report sync failure")
	assert.Contains(t, err.Error(), "sync failed")
}
// --- SyncStatus with working client ---

// SyncStatus forwards the update to the service; the fake server records it.
func TestSyncStatus_Good(t *testing.T) {
	var receivedUpdate TaskUpdate
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewDecoder(r.Body).Decode(&receivedUpdate)
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task := &Task{ID: "sync-1", Title: "Sync test"}
	err := SyncStatus(context.Background(), client, task, TaskUpdate{
		Status:   StatusCompleted,
		Progress: 100,
		Notes:    "All done",
	})
	require.NoError(t, err)
	assert.Equal(t, StatusCompleted, receivedUpdate.Status)
	assert.Equal(t, 100, receivedUpdate.Progress)
}

// --- CreatePR with default title/body ---

func TestCreatePR_Good_DefaultTitleFromTask(t *testing.T) {
	// CreatePR requires gh CLI which may not be available.
	// Test the option building logic by checking that the title
	// defaults to the task title.
	task := &Task{
		ID:          "PR-1",
		Title:       "Add authentication",
		Description: "OAuth2 login",
		Priority:    PriorityHigh,
	}
	opts := PROptions{}
	// Verify the defaulting logic that would be used.
	title := opts.Title
	if title == "" {
		title = task.Title
	}
	assert.Equal(t, "Add authentication", title)
	body := opts.Body
	if body == "" {
		body = buildPRBody(task)
	}
	assert.Contains(t, body, "OAuth2 login")
}

// Sanity-check that explicitly-set PROptions round-trip unchanged.
func TestCreatePR_Good_CustomOptions(t *testing.T) {
	opts := PROptions{
		Title:  "Custom title",
		Body:   "Custom body",
		Draft:  true,
		Labels: []string{"enhancement", "v2"},
		Base:   "develop",
	}
	assert.Equal(t, "Custom title", opts.Title)
	assert.True(t, opts.Draft)
	assert.Equal(t, "develop", opts.Base)
	assert.Len(t, opts.Labels, 2)
}
// --- Client checkResponse edge cases ---

// A non-JSON error body falls back to the HTTP status text in the error.
func TestClient_CheckResponse_Good_GenericError(t *testing.T) {
	// Test checkResponse with a non-JSON error body.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadGateway)
		_, _ = w.Write([]byte("plain text error"))
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	_, err := client.GetTask(context.Background(), "test-task")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "Bad Gateway")
}

// An empty error body likewise reports the status text.
func TestClient_CheckResponse_Good_EmptyBody(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusForbidden)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	_, err := client.GetTask(context.Background(), "test-task")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "Forbidden")
}

// --- Client Ping edge case ---

// Ping surfaces non-2xx responses with the status code in the error text.
func TestClient_Ping_Bad_ServerReturns4xx(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	err := client.Ping(context.Background())
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "status 401")
}

// --- Client ClaimTask without AgentID ---

// Claiming with an empty AgentID still succeeds; no request body is sent.
func TestClient_ClaimTask_Good_NoAgentID(t *testing.T) {
	claimedTask := Task{
		ID:     "task-no-agent",
		Status: StatusInProgress,
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify no body sent when AgentID is empty.
		assert.Equal(t, http.MethodPost, r.Method)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(claimedTask)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	// Explicitly leave AgentID empty.
	client.AgentID = ""
	task, err := client.ClaimTask(context.Background(), "task-no-agent")
	require.NoError(t, err)
	assert.Equal(t, "task-no-agent", task.ID)
}

View file

@ -1,199 +0,0 @@
package lifecycle
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
// TestBuildCommitMessage checks that the rendered commit message carries the
// caller's text, the task-reference line, and the co-author trailer.
func TestBuildCommitMessage(t *testing.T) {
	task := &Task{
		ID:    "ABC123",
		Title: "Test Task",
	}
	message := buildCommitMessage(task, "add new feature")
	assert.Contains(t, message, "add new feature")
	assert.Contains(t, message, "Task: #ABC123")
	assert.Contains(t, message, "Co-Authored-By: Claude <noreply@anthropic.com>")
}
// TestBuildPRBody checks every section of the generated PR body, including
// the optional labels line.
func TestBuildPRBody(t *testing.T) {
	task := &Task{
		ID:          "PR-456",
		Title:       "Add authentication",
		Description: "Implement user authentication with OAuth2",
		Priority:    PriorityHigh,
		Labels:      []string{"enhancement", "security"},
	}
	body := buildPRBody(task)
	assert.Contains(t, body, "## Summary")
	assert.Contains(t, body, "Implement user authentication with OAuth2")
	assert.Contains(t, body, "## Task Reference")
	assert.Contains(t, body, "Task ID: #PR-456")
	assert.Contains(t, body, "Priority: high")
	assert.Contains(t, body, "Labels: enhancement, security")
	assert.Contains(t, body, "Generated with AI assistance")
}

// Without labels, the "Labels:" line must be omitted entirely.
func TestBuildPRBody_NoLabels(t *testing.T) {
	task := &Task{
		ID:          "PR-789",
		Title:       "Fix bug",
		Description: "Fix the login bug",
		Priority:    PriorityMedium,
		Labels:      nil,
	}
	body := buildPRBody(task)
	assert.Contains(t, body, "## Summary")
	assert.Contains(t, body, "Fix the login bug")
	assert.NotContains(t, body, "Labels:")
}
// TestGenerateBranchName table-tests the branch naming scheme: label-derived
// prefix, slugified title, truncation, and special-character stripping.
func TestGenerateBranchName(t *testing.T) {
	tests := []struct {
		name     string
		task     *Task
		expected string
	}{
		{
			name: "feature task",
			task: &Task{
				ID:     "123",
				Title:  "Add user authentication",
				Labels: []string{"enhancement"},
			},
			expected: "feat/123-add-user-authentication",
		},
		{
			name: "bug fix task",
			task: &Task{
				ID:     "456",
				Title:  "Fix login error",
				Labels: []string{"bug"},
			},
			expected: "fix/456-fix-login-error",
		},
		{
			name: "docs task",
			task: &Task{
				ID:     "789",
				Title:  "Update README",
				Labels: []string{"documentation"},
			},
			expected: "docs/789-update-readme",
		},
		{
			name: "refactor task",
			task: &Task{
				ID:     "101",
				Title:  "Refactor auth module",
				Labels: []string{"refactor"},
			},
			expected: "refactor/101-refactor-auth-module",
		},
		{
			name: "test task",
			task: &Task{
				ID:     "202",
				Title:  "Add unit tests",
				Labels: []string{"test"},
			},
			expected: "test/202-add-unit-tests",
		},
		{
			name: "chore task",
			task: &Task{
				ID:     "303",
				Title:  "Update dependencies",
				Labels: []string{"chore"},
			},
			expected: "chore/303-update-dependencies",
		},
		{
			// Slug is capped at 40 characters with no trailing dash.
			name: "long title truncated",
			task: &Task{
				ID:     "404",
				Title:  "This is a very long title that should be truncated to fit the branch name limit",
				Labels: nil,
			},
			expected: "feat/404-this-is-a-very-long-title-that-should-be",
		},
		{
			// Punctuation is dropped; separators fold to single dashes.
			name: "special characters removed",
			task: &Task{
				ID:     "505",
				Title:  "Fix: user's auth (OAuth2) [important]",
				Labels: nil,
			},
			expected: "feat/505-fix-users-auth-oauth2-important",
		},
		{
			name: "no labels defaults to feat",
			task: &Task{
				ID:     "606",
				Title:  "New feature",
				Labels: nil,
			},
			expected: "feat/606-new-feature",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := generateBranchName(tt.task)
			assert.Equal(t, tt.expected, result)
		})
	}
}
// The validation-failure tests below exercise nil/empty argument handling.
// context.Background is used for consistency with the rest of the package's
// tests (the previous context.TODO carried no meaning here).

// A nil task must be rejected before any git activity.
func TestAutoCommit_Bad_NilTask(t *testing.T) {
	err := AutoCommit(context.Background(), nil, ".", "test message")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "task is required")
}

// An empty commit message must be rejected.
func TestAutoCommit_Bad_EmptyMessage(t *testing.T) {
	task := &Task{ID: "123", Title: "Test"}
	err := AutoCommit(context.Background(), task, ".", "")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "commit message is required")
}

// SyncStatus requires a client.
func TestSyncStatus_Bad_NilClient(t *testing.T) {
	task := &Task{ID: "123", Title: "Test"}
	update := TaskUpdate{Status: StatusInProgress}
	err := SyncStatus(context.Background(), nil, task, update)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "client is required")
}

// SyncStatus requires a task.
func TestSyncStatus_Bad_NilTask(t *testing.T) {
	client := &Client{BaseURL: "http://test"}
	update := TaskUpdate{Status: StatusInProgress}
	err := SyncStatus(context.Background(), client, nil, update)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "task is required")
}

// CreateBranch requires a task and returns an empty branch name on failure.
func TestCreateBranch_Bad_NilTask(t *testing.T) {
	branch, err := CreateBranch(context.Background(), nil, ".")
	assert.Error(t, err)
	assert.Empty(t, branch)
	assert.Contains(t, err.Error(), "task is required")
}

// CreatePR requires a task and returns an empty URL on failure.
func TestCreatePR_Bad_NilTask(t *testing.T) {
	url, err := CreatePR(context.Background(), nil, ".", PROptions{})
	assert.Error(t, err)
	assert.Empty(t, url)
	assert.Contains(t, err.Error(), "task is required")
}

View file

@ -1,293 +0,0 @@
package lifecycle
import (
"os"
"path/filepath"
"strings"
"forge.lthn.ai/core/go-io"
"forge.lthn.ai/core/go-log"
"gopkg.in/yaml.v3"
)
// Config holds the configuration for connecting to the core-agentic service.
type Config struct {
	// BaseURL is the URL of the core-agentic API server.
	BaseURL string `yaml:"base_url" json:"base_url"`
	// Token is the authentication token for API requests.
	Token string `yaml:"token" json:"token"`
	// DefaultProject is the project to use when none is specified.
	DefaultProject string `yaml:"default_project" json:"default_project"`
	// AgentID is the identifier for this agent (optional, used for claiming tasks).
	AgentID string `yaml:"agent_id" json:"agent_id"`
}

// configFileName is the name of the YAML config file under ~/.core.
const configFileName = "agentic.yaml"

// envFileName is the name of the environment file probed before the YAML
// config (see LoadConfig).
const envFileName = ".env"

// DefaultBaseURL is the default API endpoint if none is configured.
// Set AGENTIC_BASE_URL to override:
//   - Lab: https://api.lthn.sh
//   - Prod: https://api.lthn.ai
const DefaultBaseURL = "https://api.lthn.sh"
// LoadConfig loads the agentic configuration.
//
// Resolution order:
//  1. a .env file in dir (when dir is non-empty) or in the current working
//     directory (when dir is empty) — if it yields a token, the YAML lookup
//     is skipped;
//  2. ~/.core/agentic.yaml;
//  3. environment variables, which always override file values:
//     AGENTIC_BASE_URL, AGENTIC_TOKEN, AGENTIC_PROJECT, AGENTIC_AGENT_ID.
//
// A configuration that ends up without a token is an error.
func LoadConfig(dir string) (*Config, error) {
	cfg := &Config{
		BaseURL: DefaultBaseURL,
	}
	// Resolve which directory to probe for a .env file. When dir is empty
	// and the working directory cannot be determined, the probe is skipped
	// (matching the previous behaviour).
	envDir := dir
	if envDir == "" {
		if cwd, err := os.Getwd(); err == nil {
			envDir = cwd
		}
	}
	// A .env file that yields a token short-circuits the YAML lookup.
	if envDir != "" && tryEnvFile(filepath.Join(envDir, envFileName), cfg) {
		return cfg, nil
	}
	// Fall back to ~/.core/agentic.yaml; a missing file is not an error.
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, log.E("agentic.LoadConfig", "failed to get home directory", err)
	}
	configPath := filepath.Join(homeDir, ".core", configFileName)
	if err := loadYAMLConfig(configPath, cfg); err != nil && !os.IsNotExist(err) {
		return nil, log.E("agentic.LoadConfig", "failed to load config", err)
	}
	// Environment variables always win.
	applyEnvOverrides(cfg)
	// Validate configuration.
	if cfg.Token == "" {
		return nil, log.E("agentic.LoadConfig", "no authentication token configured", nil)
	}
	return cfg, nil
}

// tryEnvFile loads the .env file at path into cfg and applies environment
// overrides on success. It reports true when the file was readable and a
// token is now set, i.e. the caller can stop resolving.
func tryEnvFile(path string, cfg *Config) bool {
	if err := loadEnvFile(path, cfg); err != nil {
		return false
	}
	applyEnvOverrides(cfg)
	return cfg.Token != ""
}
// loadEnvFile reads a dotenv-style file at path and copies any AGENTIC_*
// settings it recognises into cfg. Blank lines, #-comments, lines without
// '=', and unknown keys are ignored. Values may be wrapped in single or
// double quotes, which are stripped.
func loadEnvFile(path string, cfg *Config) error {
	content, err := io.Local.Read(path)
	if err != nil {
		return err
	}
	for line := range strings.SplitSeq(content, "\n") {
		line = strings.TrimSpace(line)
		// Skip empty lines and comments.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// Parse KEY=value; strings.Cut is the idiomatic "split once".
		key, value, ok := strings.Cut(line, "=")
		if !ok {
			continue
		}
		key = strings.TrimSpace(key)
		// Remove surrounding quotes if present.
		value = strings.Trim(strings.TrimSpace(value), `"'`)
		switch key {
		case "AGENTIC_BASE_URL":
			cfg.BaseURL = value
		case "AGENTIC_TOKEN":
			cfg.Token = value
		case "AGENTIC_PROJECT":
			cfg.DefaultProject = value
		case "AGENTIC_AGENT_ID":
			cfg.AgentID = value
		}
	}
	return nil
}
// loadYAMLConfig reads the YAML file at path and unmarshals it into cfg.
// Fields absent from the file keep their current values in cfg; the read
// error (including not-exist) is returned unwrapped for the caller to
// classify.
func loadYAMLConfig(path string, cfg *Config) error {
	content, err := io.Local.Read(path)
	if err != nil {
		return err
	}
	return yaml.Unmarshal([]byte(content), cfg)
}
// applyEnvOverrides overwrites cfg fields with any AGENTIC_* environment
// variables that are set to a non-empty value.
func applyEnvOverrides(cfg *Config) {
	overrides := []struct {
		env string
		dst *string
	}{
		{"AGENTIC_BASE_URL", &cfg.BaseURL},
		{"AGENTIC_TOKEN", &cfg.Token},
		{"AGENTIC_PROJECT", &cfg.DefaultProject},
		{"AGENTIC_AGENT_ID", &cfg.AgentID},
	}
	for _, o := range overrides {
		if v := os.Getenv(o.env); v != "" {
			*o.dst = v
		}
	}
}
// SaveConfig persists cfg as YAML to ~/.core/agentic.yaml, creating the
// ~/.core directory if it does not exist.
func SaveConfig(cfg *Config) error {
	const op = "agentic.SaveConfig"
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return log.E(op, "failed to get home directory", err)
	}
	configDir := filepath.Join(homeDir, ".core")
	if err := io.Local.EnsureDir(configDir); err != nil {
		return log.E(op, "failed to create config directory", err)
	}
	data, err := yaml.Marshal(cfg)
	if err != nil {
		return log.E(op, "failed to marshal config", err)
	}
	target := filepath.Join(configDir, configFileName)
	if err := io.Local.Write(target, string(data)); err != nil {
		return log.E(op, "failed to write config file", err)
	}
	return nil
}
// ConfigPath returns the absolute path of the YAML config file,
// ~/.core/agentic.yaml.
func ConfigPath() (string, error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return "", log.E("agentic.ConfigPath", "failed to get home directory", err)
	}
	path := filepath.Join(homeDir, ".core", configFileName)
	return path, nil
}
// AllowanceConfig controls allowance store backend selection.
type AllowanceConfig struct {
	// StoreBackend is the storage backend: "memory", "sqlite", or "redis". Default: "memory".
	StoreBackend string `yaml:"store_backend" json:"store_backend"`
	// StorePath is the file path for the SQLite database.
	// Default: ~/.config/agentic/allowance.db (only used when StoreBackend is "sqlite").
	StorePath string `yaml:"store_path" json:"store_path"`
	// RedisAddr is the host:port for the Redis server (only used when StoreBackend is "redis").
	RedisAddr string `yaml:"redis_addr" json:"redis_addr"`
}

// DefaultAllowanceStorePath returns the default SQLite path for allowance
// data, ~/.config/agentic/allowance.db. It fails only when the user's home
// directory cannot be resolved.
func DefaultAllowanceStorePath() (string, error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return "", log.E("agentic.DefaultAllowanceStorePath", "failed to get home directory", err)
	}
	return filepath.Join(homeDir, ".config", "agentic", "allowance.db"), nil
}
// NewAllowanceStoreFromConfig builds the AllowanceStore selected by cfg:
// "memory" (or empty) yields an in-memory store, "sqlite" a SQLite-backed
// store (defaulting the path via DefaultAllowanceStorePath), and "redis" a
// Redis-backed store. Any other backend produces a 400 APIError.
func NewAllowanceStoreFromConfig(cfg AllowanceConfig) (AllowanceStore, error) {
	backend := cfg.StoreBackend
	if backend == "" {
		backend = "memory"
	}
	switch backend {
	case "memory":
		return NewMemoryStore(), nil
	case "sqlite":
		dbPath := cfg.StorePath
		if dbPath == "" {
			defaultPath, err := DefaultAllowanceStorePath()
			if err != nil {
				return nil, err
			}
			dbPath = defaultPath
		}
		return NewSQLiteStore(dbPath)
	case "redis":
		return NewRedisStore(cfg.RedisAddr)
	}
	return nil, &APIError{
		Code:    400,
		Message: "unsupported store backend: " + cfg.StoreBackend,
	}
}
// RegistryConfig controls agent registry backend selection.
type RegistryConfig struct {
	// RegistryBackend is the storage backend: "memory", "sqlite", or "redis". Default: "memory".
	RegistryBackend string `yaml:"registry_backend" json:"registry_backend"`
	// RegistryPath is the file path for the SQLite database.
	// Default: ~/.config/agentic/registry.db (only used when RegistryBackend is "sqlite").
	RegistryPath string `yaml:"registry_path" json:"registry_path"`
	// RegistryRedisAddr is the host:port for the Redis server (only used when RegistryBackend is "redis").
	RegistryRedisAddr string `yaml:"registry_redis_addr" json:"registry_redis_addr"`
}

// DefaultRegistryPath returns the default SQLite path for registry data,
// ~/.config/agentic/registry.db. It fails only when the user's home
// directory cannot be resolved.
func DefaultRegistryPath() (string, error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return "", log.E("agentic.DefaultRegistryPath", "failed to get home directory", err)
	}
	return filepath.Join(homeDir, ".config", "agentic", "registry.db"), nil
}
// NewAgentRegistryFromConfig builds the AgentRegistry selected by cfg:
// "memory" (or empty) yields an in-memory registry, "sqlite" a SQLite-backed
// registry (defaulting the path via DefaultRegistryPath), and "redis" a
// Redis-backed registry. Any other backend produces a 400 APIError.
func NewAgentRegistryFromConfig(cfg RegistryConfig) (AgentRegistry, error) {
	backend := cfg.RegistryBackend
	if backend == "" {
		backend = "memory"
	}
	switch backend {
	case "memory":
		return NewMemoryRegistry(), nil
	case "sqlite":
		dbPath := cfg.RegistryPath
		if dbPath == "" {
			defaultPath, err := DefaultRegistryPath()
			if err != nil {
				return nil, err
			}
			dbPath = defaultPath
		}
		return NewSQLiteRegistry(dbPath)
	case "redis":
		return NewRedisRegistry(cfg.RegistryRedisAddr)
	}
	return nil, &APIError{
		Code:    400,
		Message: "unsupported registry backend: " + cfg.RegistryBackend,
	}
}

View file

@ -1,455 +0,0 @@
package lifecycle
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLoadConfig_Good_FromEnvFile verifies that all four AGENTIC_* keys are
// read from a .env file in the supplied directory.
func TestLoadConfig_Good_FromEnvFile(t *testing.T) {
	// Create temp directory with .env file
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	envContent := `
AGENTIC_BASE_URL=https://test.api.com
AGENTIC_TOKEN=test-token-123
AGENTIC_PROJECT=my-project
AGENTIC_AGENT_ID=agent-001
`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "https://test.api.com", cfg.BaseURL)
	assert.Equal(t, "test-token-123", cfg.Token)
	assert.Equal(t, "my-project", cfg.DefaultProject)
	assert.Equal(t, "agent-001", cfg.AgentID)
}
// TestLoadConfig_Good_FromEnvVars verifies that process environment variables
// take precedence over values read from the .env file.
func TestLoadConfig_Good_FromEnvVars(t *testing.T) {
	// Create temp directory with .env file (partial config)
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	envContent := `
AGENTIC_TOKEN=env-file-token
`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	// Set environment variables that should override
	_ = os.Setenv("AGENTIC_BASE_URL", "https://env-override.com")
	_ = os.Setenv("AGENTIC_TOKEN", "env-override-token")
	defer func() {
		_ = os.Unsetenv("AGENTIC_BASE_URL")
		_ = os.Unsetenv("AGENTIC_TOKEN")
	}()
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "https://env-override.com", cfg.BaseURL)
	assert.Equal(t, "env-override-token", cfg.Token)
}
// TestLoadConfig_Bad_NoToken verifies that LoadConfig fails with a clear
// error when no token can be found in any source.
func TestLoadConfig_Bad_NoToken(t *testing.T) {
	// Create temp directory without config
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	// Create empty .env
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(""), 0644)
	require.NoError(t, err)
	// Ensure no env vars are set
	_ = os.Unsetenv("AGENTIC_TOKEN")
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	_, err = LoadConfig(tmpDir)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "no authentication token")
}
// TestLoadConfig_Good_EnvFileWithQuotes verifies that both double and single
// quotes around .env values are stripped.
func TestLoadConfig_Good_EnvFileWithQuotes(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	// Test with quoted values
	envContent := `
AGENTIC_TOKEN="quoted-token"
AGENTIC_BASE_URL='single-quoted-url'
`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "quoted-token", cfg.Token)
	assert.Equal(t, "single-quoted-url", cfg.BaseURL)
}
// TestLoadConfig_Good_EnvFileWithComments verifies that # comment lines in
// the .env file are ignored while surrounding keys are still parsed.
func TestLoadConfig_Good_EnvFileWithComments(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	envContent := `
# This is a comment
AGENTIC_TOKEN=token-with-comments
# Another comment
AGENTIC_PROJECT=commented-project
`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "token-with-comments", cfg.Token)
	assert.Equal(t, "commented-project", cfg.DefaultProject)
}
// TestSaveConfig_Good verifies that SaveConfig writes ~/.core/agentic.yaml
// under a temporary HOME and that the written file contains the values.
func TestSaveConfig_Good(t *testing.T) {
	// Create temp home directory
	tmpHome, err := os.MkdirTemp("", "agentic-home")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpHome) }()
	// Override HOME for the test
	originalHome := os.Getenv("HOME")
	_ = os.Setenv("HOME", tmpHome)
	defer func() { _ = os.Setenv("HOME", originalHome) }()
	cfg := &Config{
		BaseURL:        "https://saved.api.com",
		Token:          "saved-token",
		DefaultProject: "saved-project",
		AgentID:        "saved-agent",
	}
	err = SaveConfig(cfg)
	require.NoError(t, err)
	// Verify file was created
	configPath := filepath.Join(tmpHome, ".core", "agentic.yaml")
	_, err = os.Stat(configPath)
	assert.NoError(t, err)
	// Read back the config
	data, err := os.ReadFile(configPath)
	require.NoError(t, err)
	assert.Contains(t, string(data), "saved.api.com")
	assert.Contains(t, string(data), "saved-token")
}
// TestConfigPath_Good verifies the computed config path ends in
// .core/agentic.yaml.
func TestConfigPath_Good(t *testing.T) {
	path, err := ConfigPath()
	require.NoError(t, err)
	assert.Contains(t, path, ".core")
	assert.Contains(t, path, "agentic.yaml")
}
// TestLoadConfig_Good_DefaultBaseURL verifies that DefaultBaseURL is used
// when only a token is configured.
func TestLoadConfig_Good_DefaultBaseURL(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	// Only provide token, should use default base URL
	envContent := `
AGENTIC_TOKEN=test-token
`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	// Clear any env overrides
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, DefaultBaseURL, cfg.BaseURL)
}
// TestLoadConfig_Good_FromYAMLFallback verifies that, with no .env token and
// no env vars set, LoadConfig reads ~/.core/agentic.yaml.
func TestLoadConfig_Good_FromYAMLFallback(t *testing.T) {
	// Set up a temp home with ~/.core/agentic.yaml
	tmpHome, err := os.MkdirTemp("", "agentic-home")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpHome) }()
	originalHome := os.Getenv("HOME")
	_ = os.Setenv("HOME", tmpHome)
	defer func() { _ = os.Setenv("HOME", originalHome) }()
	// Clear all env vars so we fall through to YAML.
	_ = os.Unsetenv("AGENTIC_TOKEN")
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	_ = os.Unsetenv("AGENTIC_PROJECT")
	_ = os.Unsetenv("AGENTIC_AGENT_ID")
	// Create ~/.core/agentic.yaml
	configDir := filepath.Join(tmpHome, ".core")
	err = os.MkdirAll(configDir, 0755)
	require.NoError(t, err)
	yamlContent := `base_url: https://yaml.api.com
token: yaml-token
default_project: yaml-project
agent_id: yaml-agent
`
	err = os.WriteFile(filepath.Join(configDir, "agentic.yaml"), []byte(yamlContent), 0644)
	require.NoError(t, err)
	// Load from a dir with no .env to force YAML fallback.
	tmpDir, err := os.MkdirTemp("", "agentic-noenv")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "https://yaml.api.com", cfg.BaseURL)
	assert.Equal(t, "yaml-token", cfg.Token)
	assert.Equal(t, "yaml-project", cfg.DefaultProject)
	assert.Equal(t, "yaml-agent", cfg.AgentID)
}
// TestLoadConfig_Good_EnvOverridesYAML verifies that environment variables
// win over values loaded from the YAML fallback.
func TestLoadConfig_Good_EnvOverridesYAML(t *testing.T) {
	// Set up a temp home with ~/.core/agentic.yaml
	tmpHome, err := os.MkdirTemp("", "agentic-home")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpHome) }()
	originalHome := os.Getenv("HOME")
	_ = os.Setenv("HOME", tmpHome)
	defer func() { _ = os.Setenv("HOME", originalHome) }()
	// Create ~/.core/agentic.yaml
	configDir := filepath.Join(tmpHome, ".core")
	err = os.MkdirAll(configDir, 0755)
	require.NoError(t, err)
	yamlContent := `base_url: https://yaml.api.com
token: yaml-token
`
	err = os.WriteFile(filepath.Join(configDir, "agentic.yaml"), []byte(yamlContent), 0644)
	require.NoError(t, err)
	// Set env overrides for project and agent.
	_ = os.Setenv("AGENTIC_PROJECT", "env-project")
	_ = os.Setenv("AGENTIC_AGENT_ID", "env-agent")
	defer func() {
		_ = os.Unsetenv("AGENTIC_TOKEN")
		_ = os.Unsetenv("AGENTIC_BASE_URL")
		_ = os.Unsetenv("AGENTIC_PROJECT")
		_ = os.Unsetenv("AGENTIC_AGENT_ID")
	}()
	tmpDir, err := os.MkdirTemp("", "agentic-noenv")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "env-project", cfg.DefaultProject, "env var should override YAML")
	assert.Equal(t, "env-agent", cfg.AgentID, "env var should override YAML")
}
// TestLoadConfig_Good_EnvFileWithTokenNoOverride verifies that a .env file
// containing a token short-circuits the YAML fallback.
func TestLoadConfig_Good_EnvFileWithTokenNoOverride(t *testing.T) {
	// Test that .env with a token returns immediately without
	// falling through to YAML.
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	envContent := `AGENTIC_TOKEN=env-file-only`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	_ = os.Unsetenv("AGENTIC_TOKEN")
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	_ = os.Unsetenv("AGENTIC_PROJECT")
	_ = os.Unsetenv("AGENTIC_AGENT_ID")
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "env-file-only", cfg.Token)
	assert.Equal(t, DefaultBaseURL, cfg.BaseURL)
}
// TestLoadConfig_Good_EnvFileWithMalformedLines verifies that .env lines
// without an '=' are skipped while valid entries around them still parse.
func TestLoadConfig_Good_EnvFileWithMalformedLines(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	// Lines without = sign should be skipped.
	envContent := `
AGENTIC_TOKEN=valid-token
MALFORMED_LINE_NO_EQUALS
ANOTHER_BAD_LINE
AGENTIC_PROJECT=valid-project
`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "valid-token", cfg.Token)
	assert.Equal(t, "valid-project", cfg.DefaultProject)
}
// TestApplyEnvOverrides_Good_AllVars verifies all four AGENTIC_* env vars are
// applied onto an empty Config.
func TestApplyEnvOverrides_Good_AllVars(t *testing.T) {
	_ = os.Setenv("AGENTIC_BASE_URL", "https://override-url.com")
	_ = os.Setenv("AGENTIC_TOKEN", "override-token")
	_ = os.Setenv("AGENTIC_PROJECT", "override-project")
	_ = os.Setenv("AGENTIC_AGENT_ID", "override-agent")
	defer func() {
		_ = os.Unsetenv("AGENTIC_BASE_URL")
		_ = os.Unsetenv("AGENTIC_TOKEN")
		_ = os.Unsetenv("AGENTIC_PROJECT")
		_ = os.Unsetenv("AGENTIC_AGENT_ID")
	}()
	cfg := &Config{}
	applyEnvOverrides(cfg)
	assert.Equal(t, "https://override-url.com", cfg.BaseURL)
	assert.Equal(t, "override-token", cfg.Token)
	assert.Equal(t, "override-project", cfg.DefaultProject)
	assert.Equal(t, "override-agent", cfg.AgentID)
}
// TestApplyEnvOverrides_Good_NoVarsSet verifies existing Config values are
// left untouched when no env vars are present.
func TestApplyEnvOverrides_Good_NoVarsSet(t *testing.T) {
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	_ = os.Unsetenv("AGENTIC_TOKEN")
	_ = os.Unsetenv("AGENTIC_PROJECT")
	_ = os.Unsetenv("AGENTIC_AGENT_ID")
	cfg := &Config{
		BaseURL: "original-url",
		Token:   "original-token",
	}
	applyEnvOverrides(cfg)
	assert.Equal(t, "original-url", cfg.BaseURL, "should not change without env var")
	assert.Equal(t, "original-token", cfg.Token, "should not change without env var")
}
// TestSaveConfig_Good_RoundTrip verifies that a Config saved with SaveConfig
// is read back identically via the YAML fallback path.
func TestSaveConfig_Good_RoundTrip(t *testing.T) {
	tmpHome, err := os.MkdirTemp("", "agentic-home")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpHome) }()
	originalHome := os.Getenv("HOME")
	_ = os.Setenv("HOME", tmpHome)
	defer func() { _ = os.Setenv("HOME", originalHome) }()
	// Clear env vars so LoadConfig falls through to YAML.
	_ = os.Unsetenv("AGENTIC_TOKEN")
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	_ = os.Unsetenv("AGENTIC_PROJECT")
	_ = os.Unsetenv("AGENTIC_AGENT_ID")
	original := &Config{
		BaseURL:        "https://roundtrip.api.com",
		Token:          "roundtrip-token",
		DefaultProject: "roundtrip-project",
		AgentID:        "roundtrip-agent",
	}
	err = SaveConfig(original)
	require.NoError(t, err)
	// Load it back by pointing to a dir with no .env.
	tmpDir, err := os.MkdirTemp("", "agentic-noenv")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	loaded, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, original.BaseURL, loaded.BaseURL)
	assert.Equal(t, original.Token, loaded.Token)
	assert.Equal(t, original.DefaultProject, loaded.DefaultProject)
	assert.Equal(t, original.AgentID, loaded.AgentID)
}
// TestLoadConfig_Good_EmptyDirUsesCurrentDir verifies that an empty dir
// argument makes LoadConfig read .env from the current working directory.
func TestLoadConfig_Good_EmptyDirUsesCurrentDir(t *testing.T) {
	// Create a temp directory with .env and chdir into it.
	tmpDir, err := os.MkdirTemp("", "agentic-cwd")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	envContent := `AGENTIC_TOKEN=cwd-token
AGENTIC_BASE_URL=https://cwd.api.com
`
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644)
	require.NoError(t, err)
	// Clear env vars.
	_ = os.Unsetenv("AGENTIC_TOKEN")
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	_ = os.Unsetenv("AGENTIC_PROJECT")
	_ = os.Unsetenv("AGENTIC_AGENT_ID")
	// Save and restore cwd.
	originalCwd, err := os.Getwd()
	require.NoError(t, err)
	defer func() { _ = os.Chdir(originalCwd) }()
	err = os.Chdir(tmpDir)
	require.NoError(t, err)
	cfg, err := LoadConfig("")
	require.NoError(t, err)
	assert.Equal(t, "cwd-token", cfg.Token)
	assert.Equal(t, "https://cwd.api.com", cfg.BaseURL)
}
// TestLoadConfig_Good_EnvFileNoToken_FallsToYAML verifies that a token-less
// .env file still triggers the YAML fallback for the missing token.
func TestLoadConfig_Good_EnvFileNoToken_FallsToYAML(t *testing.T) {
	tmpHome, err := os.MkdirTemp("", "agentic-home")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpHome) }()
	originalHome := os.Getenv("HOME")
	_ = os.Setenv("HOME", tmpHome)
	defer func() { _ = os.Setenv("HOME", originalHome) }()
	_ = os.Unsetenv("AGENTIC_TOKEN")
	_ = os.Unsetenv("AGENTIC_BASE_URL")
	_ = os.Unsetenv("AGENTIC_PROJECT")
	_ = os.Unsetenv("AGENTIC_AGENT_ID")
	// Create .env without a token.
	tmpDir, err := os.MkdirTemp("", "agentic-test")
	require.NoError(t, err)
	defer func() { _ = os.RemoveAll(tmpDir) }()
	err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte("AGENTIC_PROJECT=env-proj\n"), 0644)
	require.NoError(t, err)
	// Create YAML with token.
	configDir := filepath.Join(tmpHome, ".core")
	err = os.MkdirAll(configDir, 0755)
	require.NoError(t, err)
	yamlContent := `token: yaml-fallback-token
`
	err = os.WriteFile(filepath.Join(configDir, "agentic.yaml"), []byte(yamlContent), 0644)
	require.NoError(t, err)
	cfg, err := LoadConfig(tmpDir)
	require.NoError(t, err)
	assert.Equal(t, "yaml-fallback-token", cfg.Token)
}

View file

@ -1,335 +0,0 @@
// Package agentic provides AI collaboration features for task management.
package lifecycle
import (
"bytes"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"forge.lthn.ai/core/go-io"
"forge.lthn.ai/core/go-log"
)
// FileContent represents the content of a file for AI context.
// It pairs a repository-relative path with the file's text and the
// language detected from its extension.
type FileContent struct {
	// Path is the relative path to the file.
	Path string `json:"path"`
	// Content is the file content (possibly truncated for large files).
	Content string `json:"content"`
	// Language is the detected programming language.
	Language string `json:"language"`
}
// TaskContext contains gathered context for AI collaboration.
// Optional fields (Files, GitStatus, RecentCommits, RelatedCode) may be
// empty when the corresponding gathering step failed or produced nothing.
type TaskContext struct {
	// Task is the task being worked on.
	Task *Task `json:"task"`
	// Files is a list of relevant file contents.
	Files []FileContent `json:"files"`
	// GitStatus is the current git status output.
	GitStatus string `json:"git_status"`
	// RecentCommits is the recent commit log.
	RecentCommits string `json:"recent_commits"`
	// RelatedCode contains code snippets related to the task.
	RelatedCode []FileContent `json:"related_code"`
}
// BuildTaskContext gathers context for AI collaboration on a task.
// It collects the files the task references, the current git status and
// recent history, plus code related to the task's keywords. An empty dir
// defaults to the current working directory. Failures in the optional
// gathering steps are tolerated and leave those fields empty.
func BuildTaskContext(task *Task, dir string) (*TaskContext, error) {
	const op = "agentic.BuildTaskContext"
	if task == nil {
		return nil, log.E(op, "task is required", nil)
	}
	if dir == "" {
		wd, err := os.Getwd()
		if err != nil {
			return nil, log.E(op, "failed to get working directory", err)
		}
		dir = wd
	}
	tc := &TaskContext{Task: task}
	// Files mentioned on the task; errors are non-fatal.
	if files, err := GatherRelatedFiles(task, dir); err == nil {
		tc.Files = files
	}
	// Git state is best-effort.
	tc.GitStatus, _ = runGitCommand(dir, "status", "--porcelain")
	tc.RecentCommits, _ = runGitCommand(dir, "log", "--oneline", "-10")
	// Keyword-based code search; errors are non-fatal.
	if related, err := findRelatedCode(task, dir); err == nil {
		tc.RelatedCode = related
	}
	return tc, nil
}
// GatherRelatedFiles reads the files explicitly listed on the task,
// resolving each relative path against dir. Files that cannot be read
// are silently skipped.
func GatherRelatedFiles(task *Task, dir string) ([]FileContent, error) {
	const op = "agentic.GatherRelatedFiles"
	if task == nil {
		return nil, log.E(op, "task is required", nil)
	}
	var out []FileContent
	for _, rel := range task.Files {
		data, err := io.Local.Read(filepath.Join(dir, rel))
		if err != nil {
			// Skip files that don't exist
			continue
		}
		out = append(out, FileContent{
			Path:     rel,
			Content:  data,
			Language: detectLanguage(rel),
		})
	}
	return out, nil
}
// findRelatedCode searches the repository for code related to the task by
// running git grep over keywords extracted from the title and description.
// At most 10 files are returned, each truncated to 5000 bytes, and only
// Go/TypeScript/JavaScript/Python sources are considered.
func findRelatedCode(task *Task, dir string) ([]FileContent, error) {
	const op = "agentic.findRelatedCode"
	if task == nil {
		return nil, log.E(op, "task is required", nil)
	}
	keywords := extractKeywords(task.Title + " " + task.Description)
	if len(keywords) == 0 {
		return nil, nil
	}
	const maxFiles = 10
	var results []FileContent
	seen := make(map[string]bool)
	for _, kw := range keywords {
		if len(kw) < 3 {
			continue
		}
		out, err := runGitCommand(dir, "grep", "-l", "-i", kw, "--", "*.go", "*.ts", "*.js", "*.py")
		if err != nil {
			// No matches for this keyword (or git failed); try the next one.
			continue
		}
		for _, match := range strings.Split(out, "\n") {
			match = strings.TrimSpace(match)
			if match == "" || seen[match] {
				continue
			}
			seen[match] = true
			if len(results) >= maxFiles {
				break
			}
			content, err := io.Local.Read(filepath.Join(dir, match))
			if err != nil {
				continue
			}
			// Truncate large files to keep the context bounded.
			if len(content) > 5000 {
				content = content[:5000] + "\n... (truncated)"
			}
			results = append(results, FileContent{
				Path:     match,
				Content:  content,
				Language: detectLanguage(match),
			})
		}
		if len(results) >= maxFiles {
			break
		}
	}
	return results, nil
}
// extractKeywords extracts meaningful words from text for searching.
//
// The text is lowercased and split on runs of non-alphanumeric characters;
// common English stop words and words shorter than three characters are
// dropped, and at most the first five surviving words are returned in
// order of appearance. Returns nil when nothing survives filtering.
func extractKeywords(text string) []string {
	// Split on non-alphanumeric runs. strings.FieldsFunc avoids compiling
	// a regexp on every call and already drops empty fields.
	words := strings.FieldsFunc(strings.ToLower(text), func(r rune) bool {
		return !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9')
	})
	// Filter stop words and short words
	stopWords := map[string]bool{
		"the": true, "a": true, "an": true, "and": true, "or": true, "but": true,
		"in": true, "on": true, "at": true, "to": true, "for": true, "of": true,
		"with": true, "by": true, "from": true, "is": true, "are": true, "was": true,
		"be": true, "been": true, "being": true, "have": true, "has": true, "had": true,
		"do": true, "does": true, "did": true, "will": true, "would": true, "could": true,
		"should": true, "may": true, "might": true, "must": true, "shall": true,
		"this": true, "that": true, "these": true, "those": true, "it": true,
		"add": true, "create": true, "update": true, "fix": true, "remove": true,
		"implement": true, "new": true, "file": true, "code": true,
	}
	var keywords []string
	for _, word := range words {
		if len(word) >= 3 && !stopWords[word] {
			keywords = append(keywords, word)
			// Limit to the first 5 keywords.
			if len(keywords) == 5 {
				break
			}
		}
	}
	return keywords
}
// detectLanguage detects the programming language from a file extension.
// The comparison is case-insensitive; unknown or missing extensions map
// to "text".
func detectLanguage(path string) string {
	// A switch avoids allocating the extension map on every call.
	switch strings.ToLower(filepath.Ext(path)) {
	case ".go":
		return "go"
	case ".ts", ".tsx":
		return "typescript"
	case ".js", ".jsx":
		return "javascript"
	case ".py":
		return "python"
	case ".rs":
		return "rust"
	case ".java":
		return "java"
	case ".kt":
		return "kotlin"
	case ".swift":
		return "swift"
	case ".c", ".h":
		return "c"
	case ".cpp", ".hpp":
		return "cpp"
	case ".rb":
		return "ruby"
	case ".php":
		return "php"
	case ".cs":
		return "csharp"
	case ".fs":
		return "fsharp"
	case ".scala":
		return "scala"
	case ".sh", ".bash":
		return "bash"
	case ".zsh":
		return "zsh"
	case ".yaml", ".yml":
		return "yaml"
	case ".json":
		return "json"
	case ".xml":
		return "xml"
	case ".html":
		return "html"
	case ".css":
		return "css"
	case ".scss":
		return "scss"
	case ".sql":
		return "sql"
	case ".md":
		return "markdown"
	default:
		return "text"
	}
}
// runGitCommand runs a git command and returns the output.
func runGitCommand(dir string, args ...string) (string, error) {
cmd := exec.Command("git", args...)
cmd.Dir = dir
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", err
}
return stdout.String(), nil
}
// FormatContext formats the TaskContext for AI consumption.
// It renders the task metadata, task files, git state, and related code
// as a markdown document; empty sections are omitted.
func (tc *TaskContext) FormatContext() string {
	var b strings.Builder
	b.WriteString("# Task Context\n\n")
	// Task info
	b.WriteString("## Task\n")
	b.WriteString("ID: " + tc.Task.ID + "\n")
	b.WriteString("Title: " + tc.Task.Title + "\n")
	b.WriteString("Priority: " + string(tc.Task.Priority) + "\n")
	b.WriteString("Status: " + string(tc.Task.Status) + "\n")
	b.WriteString("\n### Description\n")
	b.WriteString(tc.Task.Description + "\n\n")
	// Shared renderer for the two file-list sections.
	writeFiles := func(heading string, files []FileContent) {
		if len(files) == 0 {
			return
		}
		b.WriteString(heading + "\n")
		for _, f := range files {
			b.WriteString("### " + f.Path + " (" + f.Language + ")\n")
			b.WriteString("```" + f.Language + "\n")
			b.WriteString(f.Content)
			b.WriteString("\n```\n\n")
		}
	}
	writeFiles("## Task Files", tc.Files)
	if tc.GitStatus != "" {
		b.WriteString("## Git Status\n")
		b.WriteString("```\n")
		b.WriteString(tc.GitStatus)
		b.WriteString("\n```\n\n")
	}
	if tc.RecentCommits != "" {
		b.WriteString("## Recent Commits\n")
		b.WriteString("```\n")
		b.WriteString(tc.RecentCommits)
		b.WriteString("\n```\n\n")
	}
	writeFiles("## Related Code", tc.RelatedCode)
	return b.String()
}

View file

@ -1,248 +0,0 @@
package lifecycle
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// initGitRepoWithCode creates a git repo with searchable Go code.
// It builds on the sibling initGitRepo helper, writes three Go files whose
// contents contain known keywords (authentication, token), and commits them
// so that git grep can find them.
func initGitRepoWithCode(t *testing.T) string {
	t.Helper()
	dir := initGitRepo(t)
	// Create Go files with known content for git grep.
	files := map[string]string{
		"auth.go": `package main
// Authenticate validates user credentials.
func Authenticate(user, pass string) bool {
	return user != "" && pass != ""
}
`,
		"handler.go": `package main
// HandleRequest processes HTTP requests for authentication.
func HandleRequest() {
	// authentication logic here
}
`,
		"util.go": `package main
// TokenValidator checks JWT tokens.
func TokenValidator(token string) bool {
	return len(token) > 0
}
`,
	}
	for name, content := range files {
		err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0644)
		require.NoError(t, err)
	}
	// Stage and commit all files so git grep can find them.
	_, err := runCommandCtx(context.Background(), dir, "git", "add", "-A")
	require.NoError(t, err)
	_, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "add code files")
	require.NoError(t, err)
	return dir
}
// TestFindRelatedCode_Good_MatchesKeywords verifies that git grep finds the
// fixture files via keywords from the task title/description.
func TestFindRelatedCode_Good_MatchesKeywords(t *testing.T) {
	dir := initGitRepoWithCode(t)
	task := &Task{
		ID:          "code-1",
		Title:       "Fix authentication handler",
		Description: "The authentication handler needs refactoring",
	}
	files, err := findRelatedCode(task, dir)
	require.NoError(t, err)
	assert.NotEmpty(t, files, "should find files matching 'authentication' keyword")
	// Verify language detection.
	for _, f := range files {
		assert.Equal(t, "go", f.Language)
	}
}
// TestFindRelatedCode_Good_NoKeywords verifies that a task made of stop
// words produces no error (and possibly no results).
func TestFindRelatedCode_Good_NoKeywords(t *testing.T) {
	dir := initGitRepoWithCode(t)
	task := &Task{
		ID:          "code-2",
		Title:       "do it", // too short, all stop words
		Description: "fix the bug in the code",
	}
	files, err := findRelatedCode(task, dir)
	require.NoError(t, err)
	// Keywords extracted from "do it fix the bug in the code" -- most are stop words.
	// Only "bug" is 3+ chars and not a stop word, but may not match any files.
	// Result can be nil or empty -- both are acceptable.
	_ = files
}
// TestFindRelatedCode_Bad_NilTask verifies the nil-task guard.
func TestFindRelatedCode_Bad_NilTask(t *testing.T) {
	files, err := findRelatedCode(nil, ".")
	assert.Error(t, err)
	assert.Nil(t, files)
}
// TestFindRelatedCode_Good_LimitsTo10Files verifies the 10-file cap when
// more than 10 files match a keyword.
func TestFindRelatedCode_Good_LimitsTo10Files(t *testing.T) {
	dir := initGitRepoWithCode(t)
	// Create 15 files all containing the keyword "validation".
	for i := range 15 {
		name := filepath.Join(dir, "validation_"+string(rune('a'+i))+".go")
		content := "package main\n// validation logic\nfunc Validate" + string(rune('A'+i)) + "() {}\n"
		err := os.WriteFile(name, []byte(content), 0644)
		require.NoError(t, err)
	}
	_, err := runCommandCtx(context.Background(), dir, "git", "add", "-A")
	require.NoError(t, err)
	_, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "add validation files")
	require.NoError(t, err)
	task := &Task{
		ID:          "code-3",
		Title:       "validation refactoring",
		Description: "Refactor all validation logic",
	}
	files, err := findRelatedCode(task, dir)
	require.NoError(t, err)
	assert.LessOrEqual(t, len(files), 10, "should limit to 10 related files")
}
// TestFindRelatedCode_Good_TruncatesLargeFiles verifies the 5000-byte
// truncation marker on oversized matched files.
func TestFindRelatedCode_Good_TruncatesLargeFiles(t *testing.T) {
	dir := initGitRepoWithCode(t)
	// Create a file larger than 5000 chars containing a searchable keyword.
	largeContent := "package main\n// largecontent\n"
	for len(largeContent) < 6000 {
		largeContent += "// This is filler content for testing truncation purposes.\n"
	}
	err := os.WriteFile(filepath.Join(dir, "large.go"), []byte(largeContent), 0644)
	require.NoError(t, err)
	_, err = runCommandCtx(context.Background(), dir, "git", "add", "-A")
	require.NoError(t, err)
	_, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "add large file")
	require.NoError(t, err)
	task := &Task{
		ID:          "code-4",
		Title:       "largecontent analysis",
		Description: "Analyse the largecontent module",
	}
	files, err := findRelatedCode(task, dir)
	require.NoError(t, err)
	for _, f := range files {
		if f.Path == "large.go" {
			assert.True(t, len(f.Content) <= 5020, "content should be truncated")
			assert.Contains(t, f.Content, "... (truncated)")
			return
		}
	}
	// If large.go wasn't found by git grep, that's acceptable.
}
// TestBuildTaskContext_Good_WithGitRepo verifies end-to-end context building
// inside a real git repository: task files, commits, and related code.
func TestBuildTaskContext_Good_WithGitRepo(t *testing.T) {
	dir := initGitRepoWithCode(t)
	task := &Task{
		ID:          "ctx-1",
		Title:       "Test context building with authentication",
		Description: "Build context in a git repo with searchable code",
		Priority:    PriorityMedium,
		Status:      StatusPending,
		Files:       []string{"auth.go"},
		CreatedAt:   time.Now(),
	}
	ctx, err := BuildTaskContext(task, dir)
	require.NoError(t, err)
	assert.NotNil(t, ctx)
	assert.Equal(t, task, ctx.Task)
	// Should have gathered the auth.go file.
	assert.Len(t, ctx.Files, 1)
	assert.Equal(t, "auth.go", ctx.Files[0].Path)
	assert.Contains(t, ctx.Files[0].Content, "Authenticate")
	// Should have recent commits.
	assert.NotEmpty(t, ctx.RecentCommits)
	// Should have found related code.
	assert.NotEmpty(t, ctx.RelatedCode, "should find code related to 'authentication'")
}
// TestBuildTaskContext_Good_EmptyDir verifies that an empty dir argument is
// accepted (it defaults to the current working directory).
func TestBuildTaskContext_Good_EmptyDir(t *testing.T) {
	task := &Task{
		ID:          "ctx-2",
		Title:       "Test with empty dir",
		Description: "Testing",
		Priority:    PriorityLow,
		Status:      StatusPending,
		CreatedAt:   time.Now(),
	}
	// Empty dir defaults to cwd -- BuildTaskContext handles errors gracefully.
	ctx, err := BuildTaskContext(task, "")
	require.NoError(t, err)
	assert.NotNil(t, ctx)
}
// TestFormatContext_Good_EmptySections verifies that optional sections are
// omitted from the markdown when their data is empty.
func TestFormatContext_Good_EmptySections(t *testing.T) {
	task := &Task{
		ID:          "fmt-1",
		Title:       "Minimal task",
		Description: "No files, no git",
		Priority:    PriorityLow,
		Status:      StatusPending,
	}
	ctx := &TaskContext{
		Task:          task,
		Files:         nil,
		GitStatus:     "",
		RecentCommits: "",
		RelatedCode:   nil,
	}
	formatted := ctx.FormatContext()
	assert.Contains(t, formatted, "# Task Context")
	assert.Contains(t, formatted, "fmt-1")
	assert.NotContains(t, formatted, "## Task Files")
	assert.NotContains(t, formatted, "## Git Status")
	assert.NotContains(t, formatted, "## Recent Commits")
	assert.NotContains(t, formatted, "## Related Code")
}
// TestRunGitCommand_Good verifies stdout capture from a git command.
func TestRunGitCommand_Good(t *testing.T) {
	dir := initGitRepo(t)
	output, err := runGitCommand(dir, "log", "--oneline", "-1")
	require.NoError(t, err)
	assert.Contains(t, output, "initial commit")
}
// TestRunGitCommand_Bad_NotAGitRepo verifies that git failures surface as
// errors.
func TestRunGitCommand_Bad_NotAGitRepo(t *testing.T) {
	dir := t.TempDir()
	_, err := runGitCommand(dir, "status")
	assert.Error(t, err)
}

View file

@ -1,214 +0,0 @@
package lifecycle
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestBuildTaskContext_Good verifies basic context building outside a git
// repo: the explicitly listed file is gathered and its language detected.
func TestBuildTaskContext_Good(t *testing.T) {
	// Create a temp directory with some files
	tmpDir := t.TempDir()
	// Create a test file
	testFile := filepath.Join(tmpDir, "main.go")
	err := os.WriteFile(testFile, []byte("package main\n\nfunc main() {}\n"), 0644)
	require.NoError(t, err)
	task := &Task{
		ID:          "test-123",
		Title:       "Test Task",
		Description: "A test task description",
		Priority:    PriorityMedium,
		Status:      StatusPending,
		Files:       []string{"main.go"},
		CreatedAt:   time.Now(),
	}
	ctx, err := BuildTaskContext(task, tmpDir)
	require.NoError(t, err)
	assert.NotNil(t, ctx)
	assert.Equal(t, task, ctx.Task)
	assert.Len(t, ctx.Files, 1)
	assert.Equal(t, "main.go", ctx.Files[0].Path)
	assert.Equal(t, "go", ctx.Files[0].Language)
}
// TestBuildTaskContext_Bad_NilTask verifies the nil-task guard.
func TestBuildTaskContext_Bad_NilTask(t *testing.T) {
	ctx, err := BuildTaskContext(nil, ".")
	assert.Error(t, err)
	assert.Nil(t, ctx)
	assert.Contains(t, err.Error(), "task is required")
}
// TestGatherRelatedFiles_Good verifies that listed files are read and their
// languages detected, while unlisted files are ignored.
func TestGatherRelatedFiles_Good(t *testing.T) {
	tmpDir := t.TempDir()
	// Create test files
	files := map[string]string{
		"app.go":    "package app\n\nfunc Run() {}\n",
		"config.ts": "export const config = {};\n",
		"README.md": "# Project\n",
	}
	for name, content := range files {
		path := filepath.Join(tmpDir, name)
		err := os.WriteFile(path, []byte(content), 0644)
		require.NoError(t, err)
	}
	task := &Task{
		ID:    "task-1",
		Title: "Test",
		Files: []string{"app.go", "config.ts"},
	}
	gathered, err := GatherRelatedFiles(task, tmpDir)
	require.NoError(t, err)
	assert.Len(t, gathered, 2)
	// Check languages detected correctly
	foundGo := false
	foundTS := false
	for _, f := range gathered {
		if f.Path == "app.go" {
			foundGo = true
			assert.Equal(t, "go", f.Language)
		}
		if f.Path == "config.ts" {
			foundTS = true
			assert.Equal(t, "typescript", f.Language)
		}
	}
	assert.True(t, foundGo, "should find app.go")
	assert.True(t, foundTS, "should find config.ts")
}
// TestGatherRelatedFiles_Bad_NilTask verifies the nil-task guard.
func TestGatherRelatedFiles_Bad_NilTask(t *testing.T) {
	files, err := GatherRelatedFiles(nil, ".")
	assert.Error(t, err)
	assert.Nil(t, files)
}
// TestGatherRelatedFiles_Good_MissingFiles verifies that nonexistent listed
// files are skipped without error.
func TestGatherRelatedFiles_Good_MissingFiles(t *testing.T) {
	tmpDir := t.TempDir()
	task := &Task{
		ID:    "task-1",
		Title: "Test",
		Files: []string{"nonexistent.go", "also-missing.ts"},
	}
	// Should not error, just return empty list
	gathered, err := GatherRelatedFiles(task, tmpDir)
	require.NoError(t, err)
	assert.Empty(t, gathered)
}
// TestDetectLanguage table-tests the extension-to-language mapping,
// including the "text" fallback for unknown and empty paths.
func TestDetectLanguage(t *testing.T) {
	tests := []struct {
		path     string
		expected string
	}{
		{"main.go", "go"},
		{"app.ts", "typescript"},
		{"app.tsx", "typescript"},
		{"script.js", "javascript"},
		{"script.jsx", "javascript"},
		{"main.py", "python"},
		{"lib.rs", "rust"},
		{"App.java", "java"},
		{"config.yaml", "yaml"},
		{"config.yml", "yaml"},
		{"data.json", "json"},
		{"index.html", "html"},
		{"styles.css", "css"},
		{"styles.scss", "scss"},
		{"query.sql", "sql"},
		{"README.md", "markdown"},
		{"unknown.xyz", "text"},
		{"", "text"},
	}
	for _, tt := range tests {
		t.Run(tt.path, func(t *testing.T) {
			result := detectLanguage(tt.path)
			assert.Equal(t, tt.expected, result)
		})
	}
}
// TestExtractKeywords checks keyword counts across representative inputs;
// exact keywords are not asserted, only lower/upper bounds.
func TestExtractKeywords(t *testing.T) {
	tests := []struct {
		name     string
		text     string
		expected int // minimum number of keywords expected
	}{
		{
			name:     "simple title",
			text:     "Add user authentication feature",
			expected: 2,
		},
		{
			name:     "with stop words",
			text:     "The quick brown fox jumps over the lazy dog",
			expected: 3,
		},
		{
			name:     "technical text",
			text:     "Implement OAuth2 authentication with JWT tokens",
			expected: 3,
		},
		{
			name:     "empty",
			text:     "",
			expected: 0,
		},
		{
			name:     "only stop words",
			text:     "the a an and or but in on at",
			expected: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			keywords := extractKeywords(tt.text)
			assert.GreaterOrEqual(t, len(keywords), tt.expected)
			// Keywords should not exceed 5
			assert.LessOrEqual(t, len(keywords), 5)
		})
	}
}
// TestTaskContext_FormatContext verifies that every populated section
// appears in the rendered markdown.
func TestTaskContext_FormatContext(t *testing.T) {
	task := &Task{
		ID:          "test-456",
		Title:       "Test Formatting",
		Description: "This is a test description",
		Priority:    PriorityHigh,
		Status:      StatusInProgress,
	}
	ctx := &TaskContext{
		Task:          task,
		Files:         []FileContent{{Path: "main.go", Content: "package main", Language: "go"}},
		GitStatus:     " M main.go",
		RecentCommits: "abc123 Initial commit",
		RelatedCode:   []FileContent{{Path: "util.go", Content: "package util", Language: "go"}},
	}
	formatted := ctx.FormatContext()
	assert.Contains(t, formatted, "# Task Context")
	assert.Contains(t, formatted, "test-456")
	assert.Contains(t, formatted, "Test Formatting")
	assert.Contains(t, formatted, "## Task Files")
	assert.Contains(t, formatted, "## Git Status")
	assert.Contains(t, formatted, "## Recent Commits")
	assert.Contains(t, formatted, "## Related Code")
}

View file

@ -1,757 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"forge.lthn.ai/core/go/pkg/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ============================================================================
// service.go — NewService, OnStartup, handleTask, doCommit, doPrompt
// ============================================================================
func TestNewService_Good(t *testing.T) {
	c, err := core.New()
	require.NoError(t, err)

	opts := DefaultServiceOptions()
	built, err := NewService(opts)(c)
	require.NoError(t, err)
	require.NotNil(t, built)

	svc, ok := built.(*Service)
	require.True(t, ok, "factory should return *Service")
	// Options must round-trip through the constructed service.
	assert.Equal(t, opts.DefaultTools, svc.Opts().DefaultTools)
	assert.Equal(t, opts.AllowEdit, svc.Opts().AllowEdit)
}
func TestNewService_Good_CustomOpts(t *testing.T) {
	c, err := core.New()
	require.NoError(t, err)

	custom := ServiceOptions{
		DefaultTools: []string{"Bash", "Read", "Write", "Edit"},
		AllowEdit:    true,
	}
	built, err := NewService(custom)(c)
	require.NoError(t, err)

	// Custom options must be preserved verbatim on the returned service.
	svc := built.(*Service)
	assert.True(t, svc.Opts().AllowEdit)
	assert.Len(t, svc.Opts().DefaultTools, 4)
}
func TestOnStartup_Good(t *testing.T) {
	c, err := core.New()
	require.NoError(t, err)

	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, DefaultServiceOptions()),
	}
	// Startup performs no external work and must succeed.
	assert.NoError(t, svc.OnStartup(context.Background()))
}
// mockClaude installs a fake "claude" binary that always exits 1 and puts
// its directory first on PATH. t.Setenv restores PATH when the test ends.
func mockClaude(t *testing.T) {
	t.Helper()
	dir := t.TempDir()
	bin := filepath.Join(dir, "claude")
	require.NoError(t, os.WriteFile(bin, []byte("#!/bin/sh\nexit 1\n"), 0755))
	t.Setenv("PATH", dir+":"+os.Getenv("PATH"))
}
func TestHandleTask_Good_TaskCommit(t *testing.T) {
	mockClaude(t)
	c, err := core.New()
	require.NoError(t, err)

	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, DefaultServiceOptions()),
	}
	// A TaskCommit is a recognised task type; the fake claude binary exits 1.
	res, handled, err := svc.handleTask(c, TaskCommit{
		Path:    t.TempDir(),
		Name:    "test",
		CanEdit: false,
	})
	assert.Nil(t, res)
	assert.True(t, handled, "TaskCommit should be handled")
	assert.Error(t, err, "mock claude should exit 1")
}
func TestHandleTask_Good_TaskCommitCanEdit(t *testing.T) {
	mockClaude(t)
	c, err := core.New()
	require.NoError(t, err)

	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, DefaultServiceOptions()),
	}
	// CanEdit=true exercises the editable variant; handling semantics match
	// the read-only case.
	res, handled, err := svc.handleTask(c, TaskCommit{
		Path:    t.TempDir(),
		Name:    "test-edit",
		CanEdit: true,
	})
	assert.Nil(t, res)
	assert.True(t, handled)
	assert.Error(t, err, "mock claude should exit 1")
}
func TestHandleTask_Good_TaskPrompt(t *testing.T) {
	mockClaude(t)
	c, err := core.New()
	require.NoError(t, err)

	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, DefaultServiceOptions()),
	}
	// A TaskPrompt is recognised; the fake claude binary fails with exit 1.
	res, handled, err := svc.handleTask(c, TaskPrompt{
		Prompt:  "test prompt",
		WorkDir: t.TempDir(),
	})
	assert.Nil(t, res)
	assert.True(t, handled, "TaskPrompt should be handled")
	assert.Error(t, err, "mock claude should exit 1")
}
func TestHandleTask_Good_TaskPromptWithTaskID(t *testing.T) {
	mockClaude(t)
	c, err := core.New()
	require.NoError(t, err)

	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, DefaultServiceOptions()),
	}
	// Setting the unexported taskID must not change handling semantics.
	res, handled, err := svc.handleTask(c, TaskPrompt{
		Prompt:  "test prompt",
		WorkDir: t.TempDir(),
		taskID:  "task-123",
	})
	assert.Nil(t, res)
	assert.True(t, handled)
	assert.Error(t, err, "mock claude should exit 1")
}
func TestHandleTask_Good_TaskPromptWithCustomTools(t *testing.T) {
	mockClaude(t)
	c, err := core.New()
	require.NoError(t, err)

	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, DefaultServiceOptions()),
	}
	// An explicit AllowedTools list overrides the service defaults.
	res, handled, err := svc.handleTask(c, TaskPrompt{
		Prompt:       "test prompt",
		WorkDir:      t.TempDir(),
		AllowedTools: []string{"Bash", "Read"},
	})
	assert.Nil(t, res)
	assert.True(t, handled)
	assert.Error(t, err, "mock claude should exit 1")
}
func TestHandleTask_Good_TaskPromptEmptyDefaultTools(t *testing.T) {
	mockClaude(t)
	c, err := core.New()
	require.NoError(t, err)

	// Service configured with no default tools at all.
	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, ServiceOptions{DefaultTools: nil}),
	}
	res, handled, err := svc.handleTask(c, TaskPrompt{
		Prompt:  "test prompt",
		WorkDir: t.TempDir(),
	})
	assert.Nil(t, res)
	assert.True(t, handled)
	assert.Error(t, err, "mock claude should exit 1")
}
func TestHandleTask_Good_UnknownTask(t *testing.T) {
	c, err := core.New()
	require.NoError(t, err)

	svc := &Service{
		ServiceRuntime: core.NewServiceRuntime(c, DefaultServiceOptions()),
	}
	// A plain string is not a task type the service knows about, so it must
	// be ignored without error.
	res, handled, err := svc.handleTask(c, "unknown-task")
	assert.Nil(t, res)
	assert.False(t, handled, "unknown task should not be handled")
	assert.NoError(t, err)
}
// ============================================================================
// completion.go — CreatePR full coverage
// ============================================================================
// TestCreatePR_Good_WithGhMock verifies CreatePR returns the URL printed by
// the "gh" CLI. Uses t.Setenv for PATH (auto-restored, parallel-safe) rather
// than a manual os.Setenv/defer pair, matching mockClaude's convention.
func TestCreatePR_Good_WithGhMock(t *testing.T) {
	dir := initGitRepo(t)
	// Install a fake "gh" that prints a fixed PR URL.
	ghDir := t.TempDir()
	mockScript := `#!/bin/sh
echo "https://github.com/owner/repo/pull/42"
`
	err := os.WriteFile(filepath.Join(ghDir, "gh"), []byte(mockScript), 0755)
	require.NoError(t, err)
	t.Setenv("PATH", ghDir+":"+os.Getenv("PATH"))

	task := &Task{
		ID:          "PR-10",
		Title:       "Test PR",
		Description: "Test PR description",
		Priority:    PriorityMedium,
	}
	prURL, err := CreatePR(context.Background(), task, dir, PROptions{})
	require.NoError(t, err)
	assert.Equal(t, "https://github.com/owner/repo/pull/42", prURL)
}
// TestCreatePR_Good_WithAllOptions exercises every PROptions field at once.
// PATH is managed with t.Setenv instead of manual os.Setenv/defer.
func TestCreatePR_Good_WithAllOptions(t *testing.T) {
	dir := initGitRepo(t)
	ghDir := t.TempDir()
	mockScript := `#!/bin/sh
echo "https://github.com/owner/repo/pull/99"
`
	err := os.WriteFile(filepath.Join(ghDir, "gh"), []byte(mockScript), 0755)
	require.NoError(t, err)
	t.Setenv("PATH", ghDir+":"+os.Getenv("PATH"))

	task := &Task{
		ID:          "PR-20",
		Title:       "Full options PR",
		Description: "All options test",
		Priority:    PriorityHigh,
		Labels:      []string{"enhancement", "v2"},
	}
	opts := PROptions{
		Title:  "Custom title",
		Body:   "Custom body",
		Draft:  true,
		Labels: []string{"enhancement", "v2"},
		Base:   "develop",
	}
	prURL, err := CreatePR(context.Background(), task, dir, opts)
	require.NoError(t, err)
	assert.Equal(t, "https://github.com/owner/repo/pull/99", prURL)
}
// TestCreatePR_Good_DefaultTitleAndBody checks that empty PROptions fall back
// to the task's title and description. PATH handling uses t.Setenv.
func TestCreatePR_Good_DefaultTitleAndBody(t *testing.T) {
	dir := initGitRepo(t)
	ghDir := t.TempDir()
	mockScript := `#!/bin/sh
echo "https://github.com/owner/repo/pull/55"
`
	err := os.WriteFile(filepath.Join(ghDir, "gh"), []byte(mockScript), 0755)
	require.NoError(t, err)
	t.Setenv("PATH", ghDir+":"+os.Getenv("PATH"))

	task := &Task{
		ID:          "PR-30",
		Title:       "Default title from task",
		Description: "Default body from task description",
		Priority:    PriorityCritical,
	}
	// Empty PROptions — title and body should default from the task.
	prURL, err := CreatePR(context.Background(), task, dir, PROptions{})
	require.NoError(t, err)
	assert.Contains(t, prURL, "pull/55")
}
// TestCreatePR_Bad_GhFails verifies a non-zero gh exit is surfaced as a
// wrapped error. PATH handling uses t.Setenv.
func TestCreatePR_Bad_GhFails(t *testing.T) {
	dir := initGitRepo(t)
	ghDir := t.TempDir()
	mockScript := `#!/bin/sh
echo "error: not logged in" >&2
exit 1
`
	err := os.WriteFile(filepath.Join(ghDir, "gh"), []byte(mockScript), 0755)
	require.NoError(t, err)
	t.Setenv("PATH", ghDir+":"+os.Getenv("PATH"))

	task := &Task{
		ID:    "PR-40",
		Title: "Failing PR",
	}
	prURL, err := CreatePR(context.Background(), task, dir, PROptions{})
	assert.Error(t, err)
	assert.Empty(t, prURL)
	assert.Contains(t, err.Error(), "failed to create PR")
}
// TestCreatePR_Bad_GhNotFound verifies a missing gh binary yields an error.
// PATH is replaced with an empty directory via t.Setenv (auto-restored).
func TestCreatePR_Bad_GhNotFound(t *testing.T) {
	dir := initGitRepo(t)
	// An empty directory on PATH guarantees "gh" cannot be found.
	t.Setenv("PATH", t.TempDir())

	task := &Task{
		ID:    "PR-50",
		Title: "No gh binary",
	}
	prURL, err := CreatePR(context.Background(), task, dir, PROptions{})
	assert.Error(t, err)
	assert.Empty(t, prURL)
}
// ============================================================================
// completion.go — PushChanges success path
// ============================================================================
func TestPushChanges_Good_WithRemote(t *testing.T) {
	ctx := context.Background()
	// A bare repository acts as the push target.
	remote := t.TempDir()
	_, err := runCommandCtx(ctx, remote, "git", "init", "--bare")
	require.NoError(t, err)

	work := initGitRepo(t)
	_, err = runCommandCtx(ctx, work, "git", "remote", "add", "origin", remote)
	require.NoError(t, err)

	// Establish upstream tracking with the initial commit.
	branch, err := GetCurrentBranch(ctx, work)
	require.NoError(t, err)
	_, err = runCommandCtx(ctx, work, "git", "push", "-u", "origin", branch)
	require.NoError(t, err)

	// Create one more commit for PushChanges to deliver.
	require.NoError(t, os.WriteFile(filepath.Join(work, "push-test.txt"), []byte("push me\n"), 0644))
	_, err = runCommandCtx(ctx, work, "git", "add", "-A")
	require.NoError(t, err)
	_, err = runCommandCtx(ctx, work, "git", "commit", "-m", "push test")
	require.NoError(t, err)

	assert.NoError(t, PushChanges(ctx, work))
}
// ============================================================================
// completion.go — CreateBranch in non-git dir
// ============================================================================
func TestCreateBranch_Bad_NotAGitRepo(t *testing.T) {
dir := t.TempDir()
task := &Task{
ID: "BR-99",
Title: "Not a repo",
}
branchName, err := CreateBranch(context.Background(), task, dir)
assert.Error(t, err)
assert.Empty(t, branchName)
assert.Contains(t, err.Error(), "failed to create branch")
}
// ============================================================================
// config.go — SaveConfig error paths, ConfigPath
// ============================================================================
// TestSaveConfig_Good_CreatesConfigDir verifies SaveConfig creates ~/.core
// when missing. HOME is redirected with t.Setenv (auto-restored), replacing
// the manual os.Setenv/defer pair.
func TestSaveConfig_Good_CreatesConfigDir(t *testing.T) {
	tmpHome := t.TempDir()
	t.Setenv("HOME", tmpHome)

	cfg := &Config{
		BaseURL: "https://test.example.com",
		Token:   "test-token-123",
	}
	require.NoError(t, SaveConfig(cfg))

	// The .core directory must now exist under the redirected HOME.
	info, err := os.Stat(filepath.Join(tmpHome, ".core"))
	require.NoError(t, err)
	assert.True(t, info.IsDir())
}
func TestSaveConfig_Good_OverwritesExisting(t *testing.T) {
tmpHome := t.TempDir()
originalHome := os.Getenv("HOME")
_ = os.Setenv("HOME", tmpHome)
defer func() { _ = os.Setenv("HOME", originalHome) }()
// Write first config.
cfg1 := &Config{Token: "first-token"}
err := SaveConfig(cfg1)
require.NoError(t, err)
// Overwrite with second config.
cfg2 := &Config{Token: "second-token"}
err = SaveConfig(cfg2)
require.NoError(t, err)
// Verify second config is saved.
data, err := os.ReadFile(filepath.Join(tmpHome, ".core", "agentic.yaml"))
require.NoError(t, err)
assert.Contains(t, string(data), "second-token")
assert.NotContains(t, string(data), "first-token")
}
func TestConfigPath_Good_ContainsExpectedComponents(t *testing.T) {
	path, err := ConfigPath()
	require.NoError(t, err)

	// The path must be absolute and end in .core/agentic.yaml.
	assert.True(t, filepath.IsAbs(path), "path should be absolute")
	assert.Equal(t, "agentic.yaml", filepath.Base(path))
	assert.Contains(t, filepath.Dir(path), ".core")
}
// ============================================================================
// allowance_service.go — ResetAgent error path
// ============================================================================
// resetErrorStore extends errorStore with a ResetUsage failure mode.
// It embeds MemoryStore so all other Store methods behave normally; only
// ResetUsage is overridden (below) to optionally return an error.
type resetErrorStore struct {
	*MemoryStore
	failReset bool // when true, ResetUsage returns a simulated error
}
// ResetUsage returns a simulated error when failReset is set; otherwise it
// delegates to the embedded MemoryStore.
func (e *resetErrorStore) ResetUsage(agentID string) error {
	if !e.failReset {
		return e.MemoryStore.ResetUsage(agentID)
	}
	return errors.New("simulated ResetUsage error")
}
func TestResetAgent_Bad_StoreError(t *testing.T) {
store := &resetErrorStore{
MemoryStore: NewMemoryStore(),
failReset: true,
}
svc := NewAllowanceService(store)
err := svc.ResetAgent("agent-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to reset usage")
}
func TestResetAgent_Good_Success(t *testing.T) {
	store := &resetErrorStore{
		MemoryStore: NewMemoryStore(),
		failReset:   false,
	}
	svc := NewAllowanceService(store)

	// Pre-populate usage, then reset and confirm counters return to zero.
	_ = store.IncrementUsage("agent-1", 5000, 3)
	require.NoError(t, svc.ResetAgent("agent-1"))

	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, int64(0), usage.TokensUsed)
	assert.Equal(t, 0, usage.JobsStarted)
}
// ============================================================================
// client.go — error paths for ListTasks, GetTask, ClaimTask, UpdateTask,
// CompleteTask, Ping
// ============================================================================
// TestClient_ListTasks_Bad_ConnectionRefused verifies a dial failure is
// wrapped as a "request failed" error. Port 1 is never listening.
func TestClient_ListTasks_Bad_ConnectionRefused(t *testing.T) {
	client := NewClient("http://127.0.0.1:1", "test-token")
	// time.Millisecond beats the hand-written "100 * 1000000" nanosecond math.
	client.HTTPClient.Timeout = 100 * time.Millisecond
	tasks, err := client.ListTasks(context.Background(), ListOptions{})
	assert.Error(t, err)
	assert.Nil(t, tasks)
	assert.Contains(t, err.Error(), "request failed")
}
func TestClient_ListTasks_Bad_InvalidJSON(t *testing.T) {
	// The server replies 200 OK but with a body that is not JSON.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("not valid json"))
	}))
	defer srv.Close()

	tasks, err := NewClient(srv.URL, "test-token").ListTasks(context.Background(), ListOptions{})
	assert.Error(t, err)
	assert.Nil(t, tasks)
	assert.Contains(t, err.Error(), "failed to decode response")
}
// TestClient_GetTask_Bad_ConnectionRefused verifies a dial failure is wrapped
// as a "request failed" error.
func TestClient_GetTask_Bad_ConnectionRefused(t *testing.T) {
	client := NewClient("http://127.0.0.1:1", "test-token")
	// Explicit duration unit instead of raw nanosecond arithmetic.
	client.HTTPClient.Timeout = 100 * time.Millisecond
	task, err := client.GetTask(context.Background(), "task-1")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "request failed")
}
func TestClient_GetTask_Bad_InvalidJSON(t *testing.T) {
	// 200 OK with truncated JSON must produce a decode error.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("{invalid"))
	}))
	defer srv.Close()

	task, err := NewClient(srv.URL, "test-token").GetTask(context.Background(), "task-1")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "failed to decode response")
}
// TestClient_ClaimTask_Bad_ConnectionRefused verifies a dial failure is
// wrapped as a "request failed" error.
func TestClient_ClaimTask_Bad_ConnectionRefused(t *testing.T) {
	client := NewClient("http://127.0.0.1:1", "test-token")
	// Explicit duration unit instead of raw nanosecond arithmetic.
	client.HTTPClient.Timeout = 100 * time.Millisecond
	task, err := client.ClaimTask(context.Background(), "task-1")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "request failed")
}
func TestClient_ClaimTask_Bad_InvalidJSON(t *testing.T) {
	// 200 OK with a garbage body must produce a decode error.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("completely broken json"))
	}))
	defer srv.Close()

	client := NewClient(srv.URL, "test-token")
	client.AgentID = "agent-1"
	task, err := client.ClaimTask(context.Background(), "task-1")
	assert.Error(t, err)
	assert.Nil(t, task)
	assert.Contains(t, err.Error(), "failed to decode response")
}
// TestClient_UpdateTask_Bad_ConnectionRefused verifies a dial failure is
// wrapped as a "request failed" error.
func TestClient_UpdateTask_Bad_ConnectionRefused(t *testing.T) {
	client := NewClient("http://127.0.0.1:1", "test-token")
	// Explicit duration unit instead of raw nanosecond arithmetic.
	client.HTTPClient.Timeout = 100 * time.Millisecond
	err := client.UpdateTask(context.Background(), "task-1", TaskUpdate{
		Status: StatusInProgress,
	})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "request failed")
}
func TestClient_UpdateTask_Bad_ServerError(t *testing.T) {
	// A 500 with an APIError body must surface the server's message.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_ = json.NewEncoder(w).Encode(APIError{Message: "server error"})
	}))
	defer srv.Close()

	err := NewClient(srv.URL, "test-token").UpdateTask(context.Background(), "task-1", TaskUpdate{})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "server error")
}
// TestClient_CompleteTask_Bad_ConnectionRefused verifies a dial failure is
// wrapped as a "request failed" error.
func TestClient_CompleteTask_Bad_ConnectionRefused(t *testing.T) {
	client := NewClient("http://127.0.0.1:1", "test-token")
	// Explicit duration unit instead of raw nanosecond arithmetic.
	client.HTTPClient.Timeout = 100 * time.Millisecond
	err := client.CompleteTask(context.Background(), "task-1", TaskResult{
		Success: true,
	})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "request failed")
}
func TestClient_CompleteTask_Bad_ServerError(t *testing.T) {
	// A 400 with an APIError body must surface the server's message.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		_ = json.NewEncoder(w).Encode(APIError{Message: "bad request"})
	}))
	defer srv.Close()

	err := NewClient(srv.URL, "test-token").CompleteTask(context.Background(), "task-1", TaskResult{})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "bad request")
}
func TestClient_Ping_Bad_ServerReturns5xx(t *testing.T) {
	// A bare 503 with no body must still be reported with its status code.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
	}))
	defer srv.Close()

	err := NewClient(srv.URL, "test-token").Ping(context.Background())
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "status 503")
}
// ============================================================================
// context.go — BuildTaskContext edge cases
// ============================================================================
func TestBuildTaskContext_Good_FilesGatherError(t *testing.T) {
	// The task references a file that does not exist in the work dir; context
	// building should degrade gracefully rather than fail.
	task := &Task{
		ID:    "ctx-err-1",
		Title: "Files error test",
		Files: []string{"nonexistent.go"},
	}
	ctx, err := BuildTaskContext(task, t.TempDir())
	require.NoError(t, err, "BuildTaskContext should not fail even if files are missing")
	assert.NotNil(t, ctx)
	assert.Empty(t, ctx.Files, "no files should be gathered")
}
// ============================================================================
// completion.go — generateBranchName edge cases
// ============================================================================
func TestGenerateBranchName_Good_TestsLabel(t *testing.T) {
task := &Task{
ID: "GEN-1",
Title: "Add tests for core",
Labels: []string{"tests"},
}
name := generateBranchName(task)
assert.Equal(t, "test/GEN-1-add-tests-for-core", name)
}
func TestGenerateBranchName_Good_EmptyTitle(t *testing.T) {
task := &Task{
ID: "GEN-2",
Title: "",
Labels: nil,
}
name := generateBranchName(task)
assert.Equal(t, "feat/GEN-2-", name)
}
func TestGenerateBranchName_Good_BugfixLabel(t *testing.T) {
task := &Task{
ID: "GEN-3",
Title: "Fix memory leak",
Labels: []string{"bugfix"},
}
name := generateBranchName(task)
assert.Equal(t, "fix/GEN-3-fix-memory-leak", name)
}
func TestGenerateBranchName_Good_DocsLabel(t *testing.T) {
task := &Task{
ID: "GEN-4",
Title: "Update docs",
Labels: []string{"docs"},
}
name := generateBranchName(task)
assert.Equal(t, "docs/GEN-4-update-docs", name)
}
func TestGenerateBranchName_Good_FixLabel(t *testing.T) {
task := &Task{
ID: "GEN-5",
Title: "Fix something",
Labels: []string{"fix"},
}
name := generateBranchName(task)
assert.Equal(t, "fix/GEN-5-fix-something", name)
}
// ============================================================================
// AutoCommit additional edge cases
// ============================================================================
func TestAutoCommit_Bad_NotAGitRepo(t *testing.T) {
dir := t.TempDir()
task := &Task{ID: "AC-1", Title: "Not a repo"}
err := AutoCommit(context.Background(), task, dir, "feat: test")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to stage changes")
}

View file

@ -1,259 +0,0 @@
package lifecycle
import (
"cmp"
"context"
"slices"
"time"
"forge.lthn.ai/core/go-log"
)
const (
	// DefaultMaxRetries is the default number of dispatch attempts before dead-lettering.
	// It is used when a Task does not set its own MaxRetries (see effectiveMaxRetries).
	DefaultMaxRetries = 3
	// baseBackoff is the base duration for exponential backoff between retries.
	// Retry n waits baseBackoff * 2^(n-1) (see backoffDuration): 5s, 10s, 20s, ...
	baseBackoff = 5 * time.Second
)
// Dispatcher orchestrates task dispatch by combining the agent registry,
// task router, allowance service, and API client.
type Dispatcher struct {
	registry  AgentRegistry     // source of candidate agents for routing
	router    TaskRouter        // selects the best agent for a given task
	allowance *AllowanceService // enforces per-agent quotas before dispatch
	client    *Client           // can be nil for tests; nil skips API claiming
	events    EventEmitter      // optional lifecycle event sink; nil disables emit
}
// NewDispatcher wires a Dispatcher from its collaborators. client may be
// nil, in which case API claim calls are skipped (useful in tests).
func NewDispatcher(registry AgentRegistry, router TaskRouter, allowance *AllowanceService, client *Client) *Dispatcher {
	return &Dispatcher{
		client:    client,
		registry:  registry,
		router:    router,
		allowance: allowance,
	}
}
// SetEventEmitter attaches an event emitter to the dispatcher for lifecycle notifications.
// Passing nil disables emission; emit checks for a nil emitter before publishing.
func (d *Dispatcher) SetEventEmitter(em EventEmitter) {
	d.events = em
}
// emit publishes event through the configured emitter, stamping a UTC
// timestamp when the caller left it unset. It is a no-op without an emitter,
// and the emitter's error is deliberately discarded: notifications are
// best-effort and must never affect dispatch outcomes.
func (d *Dispatcher) emit(ctx context.Context, event Event) {
	if d.events == nil {
		return
	}
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now().UTC()
	}
	_ = d.events.Emit(ctx, event)
}
// Dispatch assigns a task to the best available agent. It queries the registry
// for available agents, routes the task, checks the agent's allowance, claims
// the task via the API client (if present), and records usage. Returns the
// assigned agent ID.
//
// Events: routing failure emits EventDispatchFailedNoAgent; a quota rejection
// emits EventDispatchFailedQuota; a successful claim emits EventTaskClaimed
// (only when a client is configured); success emits EventTaskDispatched.
//
// NOTE(review): if RecordUsage fails after a successful claim, the task stays
// claimed on the server while this call returns an error — confirm whether a
// compensating release is needed.
func (d *Dispatcher) Dispatch(ctx context.Context, task *Task) (string, error) {
	const op = "Dispatcher.Dispatch"
	// 1. Get available agents from registry.
	agents := d.registry.List()
	// 2. Route task to best agent.
	agentID, err := d.router.Route(task, agents)
	if err != nil {
		d.emit(ctx, Event{
			Type:   EventDispatchFailedNoAgent,
			TaskID: task.ID,
		})
		return "", log.E(op, "routing failed", err)
	}
	// 3. Check allowance for the selected agent.
	// NOTE(review): the second argument appears to be a quota key; "" presumably
	// selects the default bucket — confirm against AllowanceService.Check.
	check, err := d.allowance.Check(agentID, "")
	if err != nil {
		return "", log.E(op, "allowance check failed", err)
	}
	if !check.Allowed {
		d.emit(ctx, Event{
			Type:    EventDispatchFailedQuota,
			TaskID:  task.ID,
			AgentID: agentID,
			Payload: check.Reason,
		})
		return "", log.E(op, "agent quota exceeded: "+check.Reason, nil)
	}
	// 4. Claim the task via the API client (if available). With a nil client
	// (tests) the claim step is skipped entirely.
	if d.client != nil {
		if _, err := d.client.ClaimTask(ctx, task.ID); err != nil {
			return "", log.E(op, "failed to claim task", err)
		}
		d.emit(ctx, Event{
			Type:    EventTaskClaimed,
			TaskID:  task.ID,
			AgentID: agentID,
		})
	}
	// 5. Record job start usage so the allowance service tracks this dispatch.
	if err := d.allowance.RecordUsage(UsageReport{
		AgentID:   agentID,
		JobID:     task.ID,
		Event:     QuotaEventJobStarted,
		Timestamp: time.Now().UTC(),
	}); err != nil {
		return "", log.E(op, "failed to record usage", err)
	}
	d.emit(ctx, Event{
		Type:    EventTaskDispatched,
		TaskID:  task.ID,
		AgentID: agentID,
	})
	return agentID, nil
}
// priorityRank maps a TaskPriority to a numeric rank for sorting.
// Lower values are dispatched first. Any unrecognised priority ranks last (4),
// after PriorityLow, so malformed tasks never jump the queue.
func priorityRank(p TaskPriority) int {
	switch p {
	case PriorityCritical:
		return 0
	case PriorityHigh:
		return 1
	case PriorityMedium:
		return 2
	case PriorityLow:
		return 3
	default:
		return 4
	}
}
// sortTasksByPriority orders tasks for dispatch: highest priority first
// (Critical before High before Medium before Low), with the older CreatedAt
// winning ties. The stable sort keeps fully-equal tasks in their original
// order, so repeated sorts are deterministic.
func sortTasksByPriority(tasks []Task) {
	slices.SortStableFunc(tasks, func(a, b Task) int {
		if c := cmp.Compare(priorityRank(a.Priority), priorityRank(b.Priority)); c != 0 {
			return c
		}
		return a.CreatedAt.Compare(b.CreatedAt)
	})
}
// backoffDuration computes the exponential backoff before retry number
// retryCount: 5s for the first retry, 10s for the second, 20s for the third,
// and so on. A non-positive count means no wait.
func backoffDuration(retryCount int) time.Duration {
	if retryCount <= 0 {
		return 0
	}
	// Shifting doubles the base once per prior retry, matching the old
	// multiply-in-a-loop formulation exactly (Duration is an int64).
	return baseBackoff << (retryCount - 1)
}
// shouldSkipRetry reports whether task is still inside its backoff window:
// it has been attempted before (RetryCount > 0 with a recorded LastAttempt)
// and the backoff since LastAttempt has not yet elapsed at time now.
func shouldSkipRetry(task *Task, now time.Time) bool {
	if task.RetryCount <= 0 || task.LastAttempt == nil {
		return false
	}
	retryAt := task.LastAttempt.Add(backoffDuration(task.RetryCount))
	return retryAt.After(now)
}
// effectiveMaxRetries returns task.MaxRetries when it is set (positive),
// otherwise the package-wide DefaultMaxRetries.
func effectiveMaxRetries(task *Task) int {
	if task.MaxRetries <= 0 {
		return DefaultMaxRetries
	}
	return task.MaxRetries
}
// DispatchLoop polls for pending tasks at the given interval and dispatches
// each one. Tasks are sorted by priority (Critical > High > Medium > Low) with
// oldest-first tie-breaking. Failed dispatches are retried with exponential
// backoff. Tasks exceeding their retry limit are dead-lettered with StatusFailed.
// It runs until the context is cancelled and returns ctx.Err().
//
// NOTE(review): RetryCount and LastAttempt are incremented on the local copy
// returned by ListTasks and are never sent back via UpdateTask (except the
// dead-letter status). Unless the server independently persists these fields,
// they reset on the next poll, making the backoff window and the dead-letter
// threshold ineffective — confirm the API echoes them back.
func (d *Dispatcher) DispatchLoop(ctx context.Context, interval time.Duration) error {
	const op = "Dispatcher.DispatchLoop"
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			// No client means nothing to poll; skip this tick.
			if d.client == nil {
				continue
			}
			tasks, err := d.client.ListTasks(ctx, ListOptions{Status: StatusPending})
			if err != nil {
				// Log but continue — transient API errors should not stop the loop.
				_ = log.E(op, "failed to list pending tasks", err)
				continue
			}
			// Sort by priority then by creation time.
			sortTasksByPriority(tasks)
			now := time.Now().UTC()
			for i := range tasks {
				// Bail out promptly if the context was cancelled mid-batch.
				if ctx.Err() != nil {
					return ctx.Err()
				}
				task := &tasks[i]
				// Check if backoff period has not elapsed for retried tasks.
				if shouldSkipRetry(task, now) {
					continue
				}
				if _, err := d.Dispatch(ctx, task); err != nil {
					// Increment retry count and record the attempt time.
					task.RetryCount++
					attemptTime := now
					task.LastAttempt = &attemptTime
					maxRetries := effectiveMaxRetries(task)
					if task.RetryCount >= maxRetries {
						// Dead-letter: mark as failed via the API.
						if updateErr := d.client.UpdateTask(ctx, task.ID, TaskUpdate{
							Status: StatusFailed,
							Notes:  "max retries exceeded",
						}); updateErr != nil {
							_ = log.E(op, "failed to dead-letter task "+task.ID, updateErr)
						}
						d.emit(ctx, Event{
							Type:    EventTaskDeadLettered,
							TaskID:  task.ID,
							Payload: "max retries exceeded",
						})
					} else {
						_ = log.E(op, "failed to dispatch task "+task.ID, err)
					}
				}
			}
		}
	}
}

View file

@ -1,578 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupDispatcher builds a Dispatcher backed by an in-memory registry,
// the default router, and an in-memory allowance store. client may be nil
// to skip API claiming.
func setupDispatcher(t *testing.T, client *Client) (*Dispatcher, *MemoryRegistry, *MemoryStore) {
	t.Helper()
	registry := NewMemoryRegistry()
	store := NewMemoryStore()
	d := NewDispatcher(registry, NewDefaultRouter(), NewAllowanceService(store), client)
	return d, registry, store
}
// registerAgent adds an available agent with the given capabilities and load
// ceiling, plus a generous default allowance in the store.
func registerAgent(t *testing.T, reg *MemoryRegistry, store *MemoryStore, id string, caps []string, maxLoad int) {
	t.Helper()
	agent := AgentInfo{
		ID:            id,
		Name:          id,
		Capabilities:  caps,
		Status:        AgentAvailable,
		LastHeartbeat: time.Now().UTC(),
		MaxLoad:       maxLoad,
	}
	_ = reg.Register(agent)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         id,
		DailyTokenLimit: 100000,
		DailyJobLimit:   50,
		ConcurrentJobs:  5,
	})
}
// --- Dispatch tests ---
func TestDispatcher_Dispatch_Good_NilClient(t *testing.T) {
	d, reg, store := setupDispatcher(t, nil)
	registerAgent(t, reg, store, "agent-1", []string{"go"}, 5)

	agentID, err := d.Dispatch(context.Background(), &Task{
		ID:       "task-1",
		Labels:   []string{"go"},
		Priority: PriorityMedium,
	})
	require.NoError(t, err)
	assert.Equal(t, "agent-1", agentID)

	// Dispatch must record one started, still-active job for the agent.
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, 1, usage.JobsStarted)
	assert.Equal(t, 1, usage.ActiveJobs)
}
func TestDispatcher_Dispatch_Good_WithHTTPClient(t *testing.T) {
	claimed := Task{ID: "task-1", Status: StatusInProgress, ClaimedBy: "agent-1"}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Only the claim endpoint is expected; anything else is a 404.
		if r.Method != http.MethodPost || r.URL.Path != "/api/tasks/task-1/claim" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimed})
	}))
	defer srv.Close()

	d, reg, store := setupDispatcher(t, NewClient(srv.URL, "test-token"))
	registerAgent(t, reg, store, "agent-1", nil, 5)

	agentID, err := d.Dispatch(context.Background(), &Task{ID: "task-1", Priority: PriorityHigh})
	require.NoError(t, err)
	assert.Equal(t, "agent-1", agentID)

	// The successful claim must be followed by a usage record.
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, 1, usage.JobsStarted)
}
func TestDispatcher_Dispatch_Good_PicksBestAgent(t *testing.T) {
	d, reg, store := setupDispatcher(t, nil)
	registerAgent(t, reg, store, "heavy", []string{"go"}, 5)
	registerAgent(t, reg, store, "light", []string{"go"}, 5)
	// Give "heavy" some load.
	// Re-registering the same ID overwrites the earlier record, setting
	// heavy's CurrentLoad to 4/5 while its allowance from registerAgent
	// remains in the store.
	_ = reg.Register(AgentInfo{
		ID:            "heavy",
		Name:          "heavy",
		Capabilities:  []string{"go"},
		Status:        AgentAvailable,
		LastHeartbeat: time.Now().UTC(),
		CurrentLoad:   4,
		MaxLoad:       5,
	})
	task := &Task{ID: "task-1", Labels: []string{"go"}, Priority: PriorityMedium}
	agentID, err := d.Dispatch(context.Background(), task)
	require.NoError(t, err)
	assert.Equal(t, "light", agentID) // light has score 1.0, heavy has 0.2
}
func TestDispatcher_Dispatch_Bad_NoAgents(t *testing.T) {
	// Empty registry: routing has no candidates and must fail.
	d, _, _ := setupDispatcher(t, nil)
	_, err := d.Dispatch(context.Background(), &Task{ID: "task-1", Priority: PriorityMedium})
	require.Error(t, err)
}
func TestDispatcher_Dispatch_Bad_AllowanceExceeded(t *testing.T) {
	d, reg, store := setupDispatcher(t, nil)
	registerAgent(t, reg, store, "agent-1", nil, 5)

	// Shrink the daily job limit to one and consume it up front.
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:       "agent-1",
		DailyJobLimit: 1,
	})
	_ = store.IncrementUsage("agent-1", 0, 1)

	_, err := d.Dispatch(context.Background(), &Task{ID: "task-1", Priority: PriorityMedium})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "quota exceeded")
}
func TestDispatcher_Dispatch_Bad_ClaimFails(t *testing.T) {
	// The API rejects every claim with 409 Conflict.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusConflict)
		_ = json.NewEncoder(w).Encode(APIError{Code: 409, Message: "already claimed"})
	}))
	defer srv.Close()

	d, reg, store := setupDispatcher(t, NewClient(srv.URL, "test-token"))
	registerAgent(t, reg, store, "agent-1", nil, 5)

	_, err := d.Dispatch(context.Background(), &Task{ID: "task-1", Priority: PriorityMedium})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "claim task")

	// A failed claim must not burn any allowance.
	usage, _ := store.GetUsage("agent-1")
	assert.Equal(t, 0, usage.JobsStarted)
}
// --- DispatchLoop tests ---
func TestDispatcher_DispatchLoop_Good_Cancellation(t *testing.T) {
	d, _, _ := setupDispatcher(t, nil)
	// A pre-cancelled context must make the loop return immediately.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	require.ErrorIs(t, d.DispatchLoop(ctx, 100*time.Millisecond), context.Canceled)
}
func TestDispatcher_DispatchLoop_Good_DispatchesPendingTasks(t *testing.T) {
	pending := []Task{
		{ID: "task-1", Status: StatusPending, Priority: PriorityMedium},
		{ID: "task-2", Status: StatusPending, Priority: PriorityHigh},
	}
	var (
		mu      sync.Mutex
		claimed = make(map[string]bool)
	)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/tasks":
			// List endpoint: serve only tasks that have not been claimed yet.
			w.Header().Set("Content-Type", "application/json")
			mu.Lock()
			var remaining []Task
			for _, task := range pending {
				if !claimed[task.ID] {
					remaining = append(remaining, task)
				}
			}
			mu.Unlock()
			_ = json.NewEncoder(w).Encode(remaining)
		case r.Method == http.MethodPost:
			// Claim endpoint: match /api/tasks/{id}/claim and mark it claimed.
			w.Header().Set("Content-Type", "application/json")
			for _, task := range pending {
				if r.URL.Path != "/api/tasks/"+task.ID+"/claim" {
					continue
				}
				mu.Lock()
				claimed[task.ID] = true
				mu.Unlock()
				resp := task
				resp.Status = StatusInProgress
				_ = json.NewEncoder(w).Encode(ClaimResponse{Task: &resp})
				return
			}
			w.WriteHeader(http.StatusNotFound)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()

	d, reg, store := setupDispatcher(t, NewClient(srv.URL, "test-token"))
	registerAgent(t, reg, store, "agent-1", nil, 10)

	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	require.ErrorIs(t, d.DispatchLoop(ctx, 50*time.Millisecond), context.DeadlineExceeded)

	// Both tasks should have been claimed within the window.
	mu.Lock()
	defer mu.Unlock()
	assert.True(t, claimed["task-1"])
	assert.True(t, claimed["task-2"])
}
// TestDispatcher_DispatchLoop_Good_NilClientSkipsTick confirms the loop
// tolerates a nil API client: ticks are skipped without panicking and the
// loop still honours its context deadline.
func TestDispatcher_DispatchLoop_Good_NilClientSkipsTick(t *testing.T) {
	dispatcher, _, _ := setupDispatcher(t, nil)
	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()
	require.ErrorIs(t, dispatcher.DispatchLoop(ctx, 50*time.Millisecond), context.DeadlineExceeded)
}
// --- Concurrent dispatch ---
// TestDispatcher_Dispatch_Good_Concurrent dispatches from ten goroutines at
// once against an unlimited allowance and checks the usage counter is exact.
func TestDispatcher_Dispatch_Good_Concurrent(t *testing.T) {
	dispatcher, registry, quotaStore := setupDispatcher(t, nil)
	registerAgent(t, registry, quotaStore, "agent-1", nil, 0)
	// registerAgent hardcodes ConcurrentJobs: 5 — lift the cap entirely
	// (0 means unlimited) so all ten dispatches can proceed.
	_ = quotaStore.SetAllowance(&AgentAllowance{
		AgentID:        "agent-1",
		DailyJobLimit:  100,
		ConcurrentJobs: 0,
	})

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			_, _ = dispatcher.Dispatch(context.Background(), &Task{
				ID:       "task-" + string(rune('a'+n)),
				Priority: PriorityMedium,
			})
		}(i)
	}
	wg.Wait()

	// Every dispatch must have been recorded exactly once.
	usage, _ := quotaStore.GetUsage("agent-1")
	assert.Equal(t, 10, usage.JobsStarted)
}
// --- Phase 7: Priority sorting tests ---
// TestSortTasksByPriority_Good checks ordering: higher-priority tasks come
// first, and within a priority band older tasks sort before newer ones.
func TestSortTasksByPriority_Good(t *testing.T) {
	base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	tasks := []Task{
		{ID: "low-old", Priority: PriorityLow, CreatedAt: base},
		{ID: "critical-new", Priority: PriorityCritical, CreatedAt: base.Add(2 * time.Hour)},
		{ID: "medium-old", Priority: PriorityMedium, CreatedAt: base},
		{ID: "high-old", Priority: PriorityHigh, CreatedAt: base},
		{ID: "critical-old", Priority: PriorityCritical, CreatedAt: base},
	}
	sortTasksByPriority(tasks)

	// Critical (oldest first), then high, medium, low.
	wantOrder := []string{"critical-old", "critical-new", "high-old", "medium-old", "low-old"}
	for i, want := range wantOrder {
		assert.Equal(t, want, tasks[i].ID)
	}
}
// --- Phase 7: Backoff duration tests ---
// TestBackoffDuration_Good verifies the exponential backoff schedule:
// no delay for the first attempt, then 5s doubling per retry.
func TestBackoffDuration_Good(t *testing.T) {
	cases := []struct {
		retries int
		want    time.Duration
	}{
		{0, 0},
		{1, 5 * time.Second},
		{2, 10 * time.Second},
		{3, 20 * time.Second},
		{4, 40 * time.Second},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, backoffDuration(tc.retries))
	}
}
// --- Phase 7: shouldSkipRetry tests ---
// TestShouldSkipRetry_Good verifies the backoff window: a retried task is
// skipped while the window is open and becomes eligible once it elapses.
func TestShouldSkipRetry_Good(t *testing.T) {
	now := time.Now().UTC()

	// 2s since the last attempt; retry 1 backs off 5s → still waiting.
	recent := now.Add(-2 * time.Second)
	task := &Task{ID: "task-1", RetryCount: 1, LastAttempt: &recent}
	assert.True(t, shouldSkipRetry(task, now))

	// 10s since the last attempt → the 5s backoff has elapsed.
	old := now.Add(-10 * time.Second)
	task.LastAttempt = &old
	assert.False(t, shouldSkipRetry(task, now))
}
// TestShouldSkipRetry_Bad_NoRetry covers the cases where backoff never
// applies: zero retries (with or without a last attempt recorded) and a
// retried task missing its last-attempt timestamp.
func TestShouldSkipRetry_Bad_NoRetry(t *testing.T) {
	now := time.Now().UTC()

	fresh := &Task{ID: "task-1", RetryCount: 0}
	assert.False(t, shouldSkipRetry(fresh, now))

	// Even a very recent attempt does not matter at RetryCount 0.
	recent := now.Add(-1 * time.Second)
	fresh.LastAttempt = &recent
	assert.False(t, shouldSkipRetry(fresh, now))

	// RetryCount > 0 but no timestamp → nothing to back off from.
	noTimestamp := &Task{ID: "task-2", RetryCount: 2, LastAttempt: nil}
	assert.False(t, shouldSkipRetry(noTimestamp, now))
}
// --- Phase 7: DispatchLoop priority order test ---
// TestDispatcher_DispatchLoop_Good_PriorityOrder verifies that the dispatch
// loop claims pending tasks in strict priority order (critical, high, medium,
// low), breaking ties within a band by CreatedAt, oldest first. The fake API
// records the order in which claim requests arrive.
func TestDispatcher_DispatchLoop_Good_PriorityOrder(t *testing.T) {
	base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	// Deliberately listed out of priority order so the sort is exercised.
	pendingTasks := []Task{
		{ID: "low-1", Status: StatusPending, Priority: PriorityLow, CreatedAt: base},
		{ID: "critical-1", Status: StatusPending, Priority: PriorityCritical, CreatedAt: base},
		{ID: "medium-1", Status: StatusPending, Priority: PriorityMedium, CreatedAt: base},
		{ID: "high-1", Status: StatusPending, Priority: PriorityHigh, CreatedAt: base},
		{ID: "critical-2", Status: StatusPending, Priority: PriorityCritical, CreatedAt: base.Add(time.Second)},
	}
	// mu guards claimOrder and claimedIDs, mutated from server goroutines.
	var mu sync.Mutex
	var claimOrder []string
	claimedIDs := make(map[string]bool)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/tasks":
			w.Header().Set("Content-Type", "application/json")
			mu.Lock()
			// Only tasks that have not yet been claimed are listed.
			var remaining []Task
			for _, tk := range pendingTasks {
				if !claimedIDs[tk.ID] {
					remaining = append(remaining, tk)
				}
			}
			mu.Unlock()
			_ = json.NewEncoder(w).Encode(remaining)
		case r.Method == http.MethodPost:
			w.Header().Set("Content-Type", "application/json")
			for _, tk := range pendingTasks {
				if r.URL.Path == "/api/tasks/"+tk.ID+"/claim" {
					mu.Lock()
					claimedIDs[tk.ID] = true
					// Record the claim sequence for the ordering assertions.
					claimOrder = append(claimOrder, tk.ID)
					mu.Unlock()
					claimed := tk
					claimed.Status = StatusInProgress
					_ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimed})
					return
				}
			}
			w.WriteHeader(http.StatusNotFound)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	d, reg, store := setupDispatcher(t, client)
	registerAgent(t, reg, store, "agent-1", nil, 10)
	// Five 50ms polling ticks easily fit inside the 500ms deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	err := d.DispatchLoop(ctx, 50*time.Millisecond)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	mu.Lock()
	defer mu.Unlock()
	// All 5 tasks should have been claimed.
	require.Len(t, claimOrder, 5)
	// Critical tasks first (oldest before newest), then high, medium, low.
	assert.Equal(t, "critical-1", claimOrder[0])
	assert.Equal(t, "critical-2", claimOrder[1])
	assert.Equal(t, "high-1", claimOrder[2])
	assert.Equal(t, "medium-1", claimOrder[3])
	assert.Equal(t, "low-1", claimOrder[4])
}
// --- Phase 7: DispatchLoop retry backoff test ---
// TestDispatcher_DispatchLoop_Good_RetryBackoff verifies that a task inside
// its retry backoff window is skipped by the loop while a fresh task on the
// same list is claimed normally.
func TestDispatcher_DispatchLoop_Good_RetryBackoff(t *testing.T) {
	// A task with a recent LastAttempt and RetryCount=1 should be skipped
	// because the backoff period (5s) has not elapsed.
	recentAttempt := time.Now().UTC()
	pendingTasks := []Task{
		{
			ID:          "retrying-task",
			Status:      StatusPending,
			Priority:    PriorityHigh,
			CreatedAt:   time.Now().UTC().Add(-time.Hour),
			RetryCount:  1,
			LastAttempt: &recentAttempt,
		},
		{
			ID:        "fresh-task",
			Status:    StatusPending,
			Priority:  PriorityLow,
			CreatedAt: time.Now().UTC(),
		},
	}
	// mu guards claimedIDs against concurrent handler invocations.
	var mu sync.Mutex
	claimedIDs := make(map[string]bool)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/tasks":
			w.Header().Set("Content-Type", "application/json")
			mu.Lock()
			// Only unclaimed tasks are listed.
			var remaining []Task
			for _, tk := range pendingTasks {
				if !claimedIDs[tk.ID] {
					remaining = append(remaining, tk)
				}
			}
			mu.Unlock()
			_ = json.NewEncoder(w).Encode(remaining)
		case r.Method == http.MethodPost:
			w.Header().Set("Content-Type", "application/json")
			for _, tk := range pendingTasks {
				if r.URL.Path == "/api/tasks/"+tk.ID+"/claim" {
					mu.Lock()
					claimedIDs[tk.ID] = true
					mu.Unlock()
					claimed := tk
					claimed.Status = StatusInProgress
					_ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimed})
					return
				}
			}
			w.WriteHeader(http.StatusNotFound)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	d, reg, store := setupDispatcher(t, client)
	registerAgent(t, reg, store, "agent-1", nil, 10)
	// Run the loop for a short period — not long enough for the 5s backoff to elapse.
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()
	err := d.DispatchLoop(ctx, 50*time.Millisecond)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	mu.Lock()
	defer mu.Unlock()
	// The fresh task should have been claimed.
	assert.True(t, claimedIDs["fresh-task"])
	// The retrying task should NOT have been claimed because backoff has not elapsed.
	assert.False(t, claimedIDs["retrying-task"])
}
// --- Phase 7: DispatchLoop dead-letter test ---
// TestDispatcher_DispatchLoop_Good_DeadLetter verifies that a task whose
// claim keeps failing is dead-lettered via PATCH /api/tasks/{id} once its
// MaxRetries budget is exhausted, with the notes "max retries exceeded".
func TestDispatcher_DispatchLoop_Good_DeadLetter(t *testing.T) {
	// A task with RetryCount at MaxRetries-1 that fails dispatch should be dead-lettered.
	pendingTasks := []Task{
		{
			ID:         "doomed-task",
			Status:     StatusPending,
			Priority:   PriorityHigh,
			CreatedAt:  time.Now().UTC().Add(-time.Hour),
			MaxRetries: 1, // Will fail after 1 attempt.
			RetryCount: 0,
		},
	}
	// mu guards deadLettered/deadLetterNotes, written by the HTTP handler.
	var mu sync.Mutex
	var deadLettered bool
	var deadLetterNotes string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodGet && r.URL.Path == "/api/tasks":
			w.Header().Set("Content-Type", "application/json")
			mu.Lock()
			done := deadLettered
			mu.Unlock()
			if done {
				// Return empty list once dead-lettered.
				_ = json.NewEncoder(w).Encode([]Task{})
			} else {
				_ = json.NewEncoder(w).Encode(pendingTasks)
			}
		case r.Method == http.MethodPost && r.URL.Path == "/api/tasks/doomed-task/claim":
			// Claim always fails to trigger retry logic.
			w.WriteHeader(http.StatusInternalServerError)
			_ = json.NewEncoder(w).Encode(APIError{Code: 500, Message: "server error"})
		case r.Method == http.MethodPatch && r.URL.Path == "/api/tasks/doomed-task":
			// This is the UpdateTask call for dead-lettering.
			var update TaskUpdate
			_ = json.NewDecoder(r.Body).Decode(&update)
			mu.Lock()
			deadLettered = true
			deadLetterNotes = update.Notes
			mu.Unlock()
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	d, reg, store := setupDispatcher(t, client)
	registerAgent(t, reg, store, "agent-1", nil, 10)
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	err := d.DispatchLoop(ctx, 50*time.Millisecond)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	mu.Lock()
	defer mu.Unlock()
	assert.True(t, deadLettered, "task should have been dead-lettered")
	assert.Equal(t, "max retries exceeded", deadLetterNotes)
}

View file

@ -1,19 +0,0 @@
package lifecycle
import (
"embed"
"strings"
)
// promptsFS holds the markdown prompt templates embedded at build time from
// the prompts/ directory.
//go:embed prompts/*.md
var promptsFS embed.FS
// Prompt returns the content of an embedded prompt file, with surrounding
// whitespace trimmed. Name should be given without the .md extension
// (e.g., "commit"); an unknown name yields the empty string.
func Prompt(name string) string {
	path := "prompts/" + name + ".md"
	raw, err := promptsFS.ReadFile(path)
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(raw))
}

View file

@ -1,26 +0,0 @@
package lifecycle
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestPrompt_Good_CommitExists asserts that the embedded commit prompt is
// present and mentions its subject.
func TestPrompt_Good_CommitExists(t *testing.T) {
	commitPrompt := Prompt("commit")
	assert.NotEmpty(t, commitPrompt, "commit prompt should exist")
	assert.Contains(t, commitPrompt, "Commit")
}
// TestPrompt_Bad_NonexistentReturnsEmpty asserts that an unknown prompt name
// resolves to the empty string rather than an error or panic.
func TestPrompt_Bad_NonexistentReturnsEmpty(t *testing.T) {
	missing := Prompt("nonexistent-prompt-that-does-not-exist")
	assert.Empty(t, missing, "nonexistent prompt should return empty string")
}
// TestPrompt_Good_ContentIsTrimmed verifies that Prompt strips surrounding
// whitespace from the embedded template.
//
// Fix: the original indexed content[0:1] unconditionally, which panics if
// the prompt is missing or empty; we now guard with NotEmpty and bail out.
// The boolean checks also cover tabs and carriage returns.
func TestPrompt_Good_ContentIsTrimmed(t *testing.T) {
	content := Prompt("commit")
	assert.NotEmpty(t, content, "commit prompt should exist")
	if content == "" {
		return // avoid panicking on the index expressions below
	}
	first := content[0]
	last := content[len(content)-1]
	assert.False(t, first == ' ' || first == '\n' || first == '\t' || first == '\r',
		"should not start with whitespace")
	assert.False(t, last == ' ' || last == '\n' || last == '\t' || last == '\r',
		"should not end with whitespace")
}

View file

@ -1,114 +0,0 @@
package lifecycle
import (
"context"
"sync"
"time"
)
// EventType identifies the kind of lifecycle event.
type EventType string
const (
// EventTaskDispatched is emitted when a task is successfully routed and claimed.
EventTaskDispatched EventType = "task_dispatched"
// EventTaskClaimed is emitted when a task claim succeeds via the API client.
EventTaskClaimed EventType = "task_claimed"
// EventDispatchFailedNoAgent is emitted when no eligible agent is available.
EventDispatchFailedNoAgent EventType = "dispatch_failed_no_agent"
// EventDispatchFailedQuota is emitted when an agent's quota is exceeded.
EventDispatchFailedQuota EventType = "dispatch_failed_quota"
// EventTaskDeadLettered is emitted when a task exceeds its retry limit.
EventTaskDeadLettered EventType = "task_dead_lettered"
// EventQuotaWarning is emitted when an agent reaches 80%+ quota usage.
EventQuotaWarning EventType = "quota_warning"
// EventQuotaExceeded is emitted when an agent exceeds their quota.
EventQuotaExceeded EventType = "quota_exceeded"
// EventUsageRecorded is emitted when usage is recorded for an agent.
EventUsageRecorded EventType = "usage_recorded"
)
// Event represents a lifecycle event in the agentic system.
type Event struct {
// Type identifies what happened.
Type EventType `json:"type"`
// TaskID is the task involved, if any.
TaskID string `json:"task_id,omitempty"`
// AgentID is the agent involved, if any.
AgentID string `json:"agent_id,omitempty"`
// Timestamp is when the event occurred.
Timestamp time.Time `json:"timestamp"`
// Payload carries additional event-specific data.
Payload any `json:"payload,omitempty"`
}
// EventEmitter is the interface for publishing lifecycle events.
type EventEmitter interface {
// Emit publishes an event. Implementations should be non-blocking.
Emit(ctx context.Context, event Event) error
}
// ChannelEmitter is an in-process EventEmitter backed by a buffered channel.
// Events are dropped (not blocked) when the buffer is full.
type ChannelEmitter struct {
ch chan Event
}
// NewChannelEmitter creates a ChannelEmitter with the given buffer size.
func NewChannelEmitter(bufSize int) *ChannelEmitter {
if bufSize < 1 {
bufSize = 64
}
return &ChannelEmitter{ch: make(chan Event, bufSize)}
}
// Emit sends an event to the channel. If the buffer is full, the event is
// dropped silently to avoid blocking the dispatch path.
func (e *ChannelEmitter) Emit(_ context.Context, event Event) error {
select {
case e.ch <- event:
default:
// Buffer full — drop the event rather than blocking.
}
return nil
}
// Events returns the underlying channel for consumers to read from.
func (e *ChannelEmitter) Events() <-chan Event {
return e.ch
}
// Close closes the underlying channel, signalling consumers to stop reading.
func (e *ChannelEmitter) Close() {
close(e.ch)
}
// MultiEmitter fans out events to multiple emitters. Emission continues even
// if one emitter fails — errors are collected but not returned.
type MultiEmitter struct {
mu sync.RWMutex
emitters []EventEmitter
}
// NewMultiEmitter creates a MultiEmitter that fans out to the given emitters.
func NewMultiEmitter(emitters ...EventEmitter) *MultiEmitter {
return &MultiEmitter{emitters: emitters}
}
// Emit sends the event to all registered emitters. Non-blocking: each emitter
// is called in sequence but ChannelEmitter.Emit is itself non-blocking.
func (m *MultiEmitter) Emit(ctx context.Context, event Event) error {
m.mu.RLock()
defer m.mu.RUnlock()
for _, em := range m.emitters {
_ = em.Emit(ctx, event)
}
return nil
}
// Add appends an emitter to the fan-out list.
func (m *MultiEmitter) Add(emitter EventEmitter) {
m.mu.Lock()
defer m.mu.Unlock()
m.emitters = append(m.emitters, emitter)
}

View file

@ -1,283 +0,0 @@
package lifecycle
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Dispatcher event emission tests ---
// TestDispatcher_EmitsTaskDispatched checks that a successful dispatch
// publishes a task_dispatched event carrying the task and agent IDs.
func TestDispatcher_EmitsTaskDispatched(t *testing.T) {
	emitter := NewChannelEmitter(10)
	d, reg, store := setupDispatcher(t, nil)
	d.SetEventEmitter(emitter)
	registerAgent(t, reg, store, "agent-1", []string{"go"}, 5)

	agentID, err := d.Dispatch(context.Background(), &Task{ID: "t1", Labels: []string{"go"}})
	require.NoError(t, err)
	assert.Equal(t, "agent-1", agentID)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventTaskDispatched, events[0].Type)
	assert.Equal(t, "t1", events[0].TaskID)
	assert.Equal(t, "agent-1", events[0].AgentID)
}
// TestDispatcher_EmitsDispatchFailedNoAgent checks that dispatching with no
// registered agents publishes a dispatch_failed_no_agent event.
func TestDispatcher_EmitsDispatchFailedNoAgent(t *testing.T) {
	emitter := NewChannelEmitter(10)
	d, _, _ := setupDispatcher(t, nil)
	d.SetEventEmitter(emitter)

	// No agents registered — dispatch must fail.
	_, err := d.Dispatch(context.Background(), &Task{ID: "t2", Labels: []string{"go"}})
	require.Error(t, err)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventDispatchFailedNoAgent, events[0].Type)
	assert.Equal(t, "t2", events[0].TaskID)
}
// TestDispatcher_EmitsDispatchFailedQuota checks that dispatching to an agent
// whose daily allowance is spent publishes a dispatch_failed_quota event.
func TestDispatcher_EmitsDispatchFailedQuota(t *testing.T) {
	emitter := NewChannelEmitter(10)
	d, reg, store := setupDispatcher(t, nil)
	d.SetEventEmitter(emitter)

	// Register directly so the allowance is controlled here: one daily job,
	// already consumed below.
	_ = reg.Register(AgentInfo{
		ID: "agent-q", Name: "agent-q", Capabilities: []string{"go"},
		Status: AgentAvailable, LastHeartbeat: time.Now().UTC(), MaxLoad: 5,
	})
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "agent-q",
		DailyJobLimit:  1,
		ConcurrentJobs: 5,
	})
	_ = store.IncrementUsage("agent-q", 0, 1)

	_, err := d.Dispatch(context.Background(), &Task{ID: "t3", Labels: []string{"go"}})
	require.Error(t, err)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventDispatchFailedQuota, events[0].Type)
	assert.Equal(t, "t3", events[0].TaskID)
	assert.Equal(t, "agent-q", events[0].AgentID)
}
// TestDispatcher_NoEventsWithoutEmitter ensures dispatch still works (and
// does not panic) when no event emitter has been configured.
func TestDispatcher_NoEventsWithoutEmitter(t *testing.T) {
	d, reg, store := setupDispatcher(t, nil)
	registerAgent(t, reg, store, "agent-1", []string{"go"}, 5)
	_, err := d.Dispatch(context.Background(), &Task{ID: "t4", Labels: []string{"go"}})
	require.NoError(t, err)
}
// --- AllowanceService event emission tests ---
// TestAllowanceService_EmitsQuotaExceeded checks that a token-exhausted agent
// fails the allowance check and a quota_exceeded event is published.
func TestAllowanceService_EmitsQuotaExceeded(t *testing.T) {
	emitter := NewChannelEmitter(10)
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	svc.SetEventEmitter(emitter)

	_ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100})
	_ = store.IncrementUsage("agent-1", 100, 0) // burn the entire budget

	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventQuotaExceeded, events[0].Type)
	assert.Equal(t, "agent-1", events[0].AgentID)
}
// TestAllowanceService_EmitsQuotaWarning checks that crossing the 80% usage
// threshold keeps the agent allowed but publishes a quota_warning event.
func TestAllowanceService_EmitsQuotaWarning(t *testing.T) {
	emitter := NewChannelEmitter(10)
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	svc.SetEventEmitter(emitter)

	_ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100})
	_ = store.IncrementUsage("agent-1", 85, 0) // 85% of the budget

	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.True(t, result.Allowed)
	assert.Equal(t, AllowanceWarning, result.Status)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventQuotaWarning, events[0].Type)
	assert.Equal(t, "agent-1", events[0].AgentID)
}
// TestAllowanceService_EmitsUsageRecorded checks that reporting a job start
// publishes a usage_recorded event for the agent.
func TestAllowanceService_EmitsUsageRecorded(t *testing.T) {
	emitter := NewChannelEmitter(10)
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	svc.SetEventEmitter(emitter)
	_ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1"})

	require.NoError(t, svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		JobID:     "job-1",
		Event:     QuotaEventJobStarted,
		Timestamp: time.Now().UTC(),
	}))

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventUsageRecorded, events[0].Type)
	assert.Equal(t, "agent-1", events[0].AgentID)
}
// TestAllowanceService_EmitsUsageRecordedOnCompletion checks that a job
// completion report (with token counts) also publishes usage_recorded.
func TestAllowanceService_EmitsUsageRecordedOnCompletion(t *testing.T) {
	emitter := NewChannelEmitter(10)
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	svc.SetEventEmitter(emitter)
	_ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1"})
	_ = store.IncrementUsage("agent-1", 0, 1) // a job must be active first

	require.NoError(t, svc.RecordUsage(UsageReport{
		AgentID:   "agent-1",
		JobID:     "job-1",
		Model:     "claude-sonnet",
		TokensIn:  500,
		TokensOut: 200,
		Event:     QuotaEventJobCompleted,
		Timestamp: time.Now().UTC(),
	}))

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventUsageRecorded, events[0].Type)
}
// TestAllowanceService_QuotaExceededOnJobLimit checks that the
// quota_exceeded payload names the daily job limit as the cause.
func TestAllowanceService_QuotaExceededOnJobLimit(t *testing.T) {
	emitter := NewChannelEmitter(10)
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	svc.SetEventEmitter(emitter)

	_ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyJobLimit: 2})
	_ = store.IncrementUsage("agent-1", 0, 2) // both daily jobs used

	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventQuotaExceeded, events[0].Type)
	assert.Contains(t, events[0].Payload, "daily job limit")
}
// TestAllowanceService_QuotaExceededOnConcurrent checks that the
// quota_exceeded payload names the concurrent-job cap as the cause.
func TestAllowanceService_QuotaExceededOnConcurrent(t *testing.T) {
	emitter := NewChannelEmitter(10)
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	svc.SetEventEmitter(emitter)

	_ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1", ConcurrentJobs: 1})
	_ = store.IncrementUsage("agent-1", 0, 1) // the single slot is occupied

	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventQuotaExceeded, events[0].Type)
	assert.Contains(t, events[0].Payload, "concurrent")
}
// TestAllowanceService_QuotaExceededOnModelAllowlist checks that requesting a
// model outside the allowlist is refused with an "allowlist" payload.
func TestAllowanceService_QuotaExceededOnModelAllowlist(t *testing.T) {
	emitter := NewChannelEmitter(10)
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	svc.SetEventEmitter(emitter)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:        "agent-1",
		ModelAllowlist: []string{"claude-sonnet"},
	})

	result, err := svc.Check("agent-1", "gpt-4")
	require.NoError(t, err)
	assert.False(t, result.Allowed)

	events := drainEvents(emitter, 1, time.Second)
	require.Len(t, events, 1)
	assert.Equal(t, EventQuotaExceeded, events[0].Type)
	assert.Contains(t, events[0].Payload, "allowlist")
}
// TestAllowanceService_NoEventsWithoutEmitter ensures quota checks work (and
// do not panic) when no event emitter has been configured.
func TestAllowanceService_NoEventsWithoutEmitter(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100})
	_ = store.IncrementUsage("agent-1", 100, 0)

	result, err := svc.Check("agent-1", "")
	require.NoError(t, err)
	assert.False(t, result.Allowed)
}
// --- Helpers ---
// drainEvents reads up to n events from the emitter within the timeout.
func drainEvents(em *ChannelEmitter, n int, timeout time.Duration) []Event {
var events []Event
deadline := time.After(timeout)
for range n {
select {
case e := <-em.Events():
events = append(events, e)
case <-deadline:
return events
}
}
return events
}

View file

@ -1,153 +0,0 @@
package lifecycle
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestChannelEmitter_EmitAndReceive round-trips a single event through the
// emitter and verifies every field arrives intact.
func TestChannelEmitter_EmitAndReceive(t *testing.T) {
	emitter := NewChannelEmitter(10)
	sent := Event{
		Type:      EventTaskDispatched,
		TaskID:    "task-1",
		AgentID:   "agent-1",
		Timestamp: time.Now().UTC(),
		Payload:   "test payload",
	}
	require.NoError(t, emitter.Emit(context.Background(), sent))

	select {
	case received := <-emitter.Events():
		assert.Equal(t, EventTaskDispatched, received.Type)
		assert.Equal(t, "task-1", received.TaskID)
		assert.Equal(t, "agent-1", received.AgentID)
		assert.Equal(t, "test payload", received.Payload)
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for event")
	}
}
// TestChannelEmitter_BufferOverflowDrops fills the buffer and verifies a
// further Emit drops the event silently instead of blocking or erroring.
func TestChannelEmitter_BufferOverflowDrops(t *testing.T) {
	emitter := NewChannelEmitter(2)
	ctx := context.Background()
	for _, id := range []string{"1", "2"} {
		require.NoError(t, emitter.Emit(ctx, Event{Type: EventTaskDispatched, TaskID: id}))
	}
	// Buffer is at capacity — this event must be dropped, not queued.
	require.NoError(t, emitter.Emit(ctx, Event{Type: EventTaskDispatched, TaskID: "3"}))
	assert.Len(t, emitter.ch, 2)
}
func TestChannelEmitter_DefaultBufferSize(t *testing.T) {
em := NewChannelEmitter(0)
assert.Equal(t, 64, cap(em.ch))
}
// TestMultiEmitter_FanOut verifies that a single Emit reaches every
// registered downstream emitter.
func TestMultiEmitter_FanOut(t *testing.T) {
	first := NewChannelEmitter(10)
	second := NewChannelEmitter(10)
	multi := NewMultiEmitter(first, second)

	require.NoError(t, multi.Emit(context.Background(), Event{
		Type:    EventQuotaWarning,
		AgentID: "agent-x",
	}))

	// Both downstream emitters must observe the same event.
	for i, sink := range []*ChannelEmitter{first, second} {
		select {
		case got := <-sink.Events():
			assert.Equal(t, EventQuotaWarning, got.Type)
		case <-time.After(time.Second):
			t.Fatalf("em%d: timed out", i+1)
		}
	}
}
// TestMultiEmitter_Add verifies that emitters registered after construction
// also receive fanned-out events.
func TestMultiEmitter_Add(t *testing.T) {
	first := NewChannelEmitter(10)
	multi := NewMultiEmitter(first)
	second := NewChannelEmitter(10)
	multi.Add(second)

	require.NoError(t, multi.Emit(context.Background(), Event{Type: EventUsageRecorded}))
	assert.Len(t, first.ch, 1)
	assert.Len(t, second.ch, 1)
}
// TestMultiEmitter_ContinuesOnFailure verifies that a failing emitter does
// not prevent later emitters from receiving the event, and that the error is
// swallowed by MultiEmitter.Emit.
func TestMultiEmitter_ContinuesOnFailure(t *testing.T) {
	broken := &failingEmitter{}
	healthy := NewChannelEmitter(10)
	multi := NewMultiEmitter(broken, healthy)

	require.NoError(t, multi.Emit(context.Background(), Event{Type: EventTaskClaimed}))
	// The healthy emitter still received the event despite the failure.
	assert.Len(t, healthy.ch, 1)
}
// TestChannelEmitter_ConcurrentEmit hammers Emit from 50 goroutines and
// verifies every event lands in the (large enough) buffer.
func TestChannelEmitter_ConcurrentEmit(t *testing.T) {
	emitter := NewChannelEmitter(100)
	ctx := context.Background()
	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = emitter.Emit(ctx, Event{Type: EventTaskDispatched})
		}()
	}
	wg.Wait()
	assert.Equal(t, 50, len(emitter.ch))
}
// TestEventTypes_AllDefined guards against an event type constant being
// accidentally defined as the empty string.
func TestEventTypes_AllDefined(t *testing.T) {
	for _, eventType := range []EventType{
		EventTaskDispatched,
		EventTaskClaimed,
		EventDispatchFailedNoAgent,
		EventDispatchFailedQuota,
		EventTaskDeadLettered,
		EventQuotaWarning,
		EventQuotaExceeded,
		EventUsageRecorded,
	} {
		assert.NotEmpty(t, string(eventType))
	}
}
// failingEmitter always returns an error.
// It exists solely to exercise MultiEmitter's continue-on-failure behavior.
type failingEmitter struct{}
// Emit unconditionally fails with a synthetic 500 APIError.
func (f *failingEmitter) Emit(_ context.Context, _ Event) error {
	return &APIError{Code: 500, Message: "emitter failed"}
}

View file

@ -1,302 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestTaskLifecycle_ClaimProcessComplete tests the full task lifecycle:
// claim a pending task, check allowance, record usage events, complete the task.
//
// It drives a fake HTTP task API for the claim/update/complete calls and an
// in-memory allowance store for quota bookkeeping, asserting the counters
// after each phase.
func TestTaskLifecycle_ClaimProcessComplete(t *testing.T) {
	// Set up allowance infrastructure.
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "lifecycle-agent",
		DailyTokenLimit: 100000,
		DailyJobLimit:   10,
		ConcurrentJobs:  3,
	})
	// Phase 1: Pre-dispatch allowance check should pass.
	check, err := svc.Check("lifecycle-agent", "")
	require.NoError(t, err)
	assert.True(t, check.Allowed)
	assert.Equal(t, AllowanceOK, check.Status)
	// Phase 2: Simulate claiming a task via the HTTP client.
	pendingTask := Task{
		ID:       "lifecycle-001",
		Title:    "Full lifecycle test",
		Priority: PriorityHigh,
		Status:   StatusPending,
		Project:  "core",
	}
	// claimedTask is what the fake server returns from the claim endpoint.
	claimedTask := pendingTask
	claimedTask.Status = StatusInProgress
	claimedTask.ClaimedBy = "lifecycle-agent"
	now := time.Now().UTC()
	claimedTask.ClaimedAt = &now
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == http.MethodPost && r.URL.Path == "/api/tasks/lifecycle-001/claim":
			w.Header().Set("Content-Type", "application/json")
			_ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimedTask})
		case r.Method == http.MethodPatch && r.URL.Path == "/api/tasks/lifecycle-001":
			w.WriteHeader(http.StatusOK)
		case r.Method == http.MethodPost && r.URL.Path == "/api/tasks/lifecycle-001/complete":
			w.WriteHeader(http.StatusOK)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	client.AgentID = "lifecycle-agent"
	// Claim the task.
	claimed, err := client.ClaimTask(context.Background(), "lifecycle-001")
	require.NoError(t, err)
	assert.Equal(t, StatusInProgress, claimed.Status)
	assert.Equal(t, "lifecycle-agent", claimed.ClaimedBy)
	// Phase 3: Record job start in the allowance system.
	err = svc.RecordUsage(UsageReport{
		AgentID: "lifecycle-agent",
		JobID:   "lifecycle-001",
		Event:   QuotaEventJobStarted,
	})
	require.NoError(t, err)
	usage, _ := store.GetUsage("lifecycle-agent")
	assert.Equal(t, 1, usage.ActiveJobs)
	assert.Equal(t, 1, usage.JobsStarted)
	// Phase 4: Update task progress.
	err = client.UpdateTask(context.Background(), "lifecycle-001", TaskUpdate{
		Status:   StatusInProgress,
		Progress: 50,
		Notes:    "Halfway through",
	})
	require.NoError(t, err)
	// Phase 5: Record job completion with token usage.
	err = svc.RecordUsage(UsageReport{
		AgentID:   "lifecycle-agent",
		JobID:     "lifecycle-001",
		Model:     "claude-sonnet",
		TokensIn:  5000,
		TokensOut: 3000,
		Event:     QuotaEventJobCompleted,
	})
	require.NoError(t, err)
	usage, _ = store.GetUsage("lifecycle-agent")
	assert.Equal(t, 0, usage.ActiveJobs)
	assert.Equal(t, int64(8000), usage.TokensUsed)
	// Phase 6: Complete the task via the API.
	err = client.CompleteTask(context.Background(), "lifecycle-001", TaskResult{
		Success:   true,
		Output:    "Task completed successfully",
		Artifacts: []string{"output.go"},
	})
	require.NoError(t, err)
	// Phase 7: Verify allowance is still within limits.
	// 100000 - 8000 tokens = 92000 remaining; 10 - 1 job = 9 remaining.
	check, err = svc.Check("lifecycle-agent", "")
	require.NoError(t, err)
	assert.True(t, check.Allowed)
	assert.Equal(t, AllowanceOK, check.Status)
	assert.Equal(t, int64(92000), check.RemainingTokens)
	assert.Equal(t, 9, check.RemainingJobs)
}
// TestTaskLifecycle_ClaimProcessFail exercises a failing job: on
// QuotaEventJobFailed, half of the consumed tokens are refunded to the agent.
func TestTaskLifecycle_ClaimProcessFail(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "fail-agent",
		DailyTokenLimit: 50000,
		DailyJobLimit:   5,
		ConcurrentJobs:  2,
	})

	require.NoError(t, svc.RecordUsage(UsageReport{
		AgentID: "fail-agent",
		JobID:   "fail-001",
		Event:   QuotaEventJobStarted,
	}))

	// The job fails after burning 10000 tokens (6000 in + 4000 out).
	require.NoError(t, svc.RecordUsage(UsageReport{
		AgentID:   "fail-agent",
		JobID:     "fail-001",
		Model:     "claude-sonnet",
		TokensIn:  6000,
		TokensOut: 4000,
		Event:     QuotaEventJobFailed,
	}))

	// Half the spend comes back: 10000 charged − 5000 refunded = 5000 net.
	usage, _ := store.GetUsage("fail-agent")
	assert.Equal(t, int64(5000), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)

	// Per-model accounting is net of the refund too: 10000 − 5000 = 5000.
	modelUsage, _ := store.GetModelUsage("claude-sonnet")
	assert.Equal(t, int64(5000), modelUsage)

	// Headroom remains: 50000 − 5000 = 45000 tokens.
	check, err := svc.Check("fail-agent", "")
	require.NoError(t, err)
	assert.True(t, check.Allowed)
	assert.Equal(t, int64(45000), check.RemainingTokens)
}
// TestTaskLifecycle_ClaimProcessCancel exercises a cancelled job: on
// QuotaEventJobCancelled, the full token spend is refunded.
func TestTaskLifecycle_ClaimProcessCancel(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "cancel-agent",
		DailyTokenLimit: 50000,
		DailyJobLimit:   5,
		ConcurrentJobs:  2,
	})

	require.NoError(t, svc.RecordUsage(UsageReport{
		AgentID: "cancel-agent",
		JobID:   "cancel-001",
		Event:   QuotaEventJobStarted,
	}))

	// Cancelled after 8000 tokens (5000 in + 3000 out) — all refunded.
	require.NoError(t, svc.RecordUsage(UsageReport{
		AgentID:   "cancel-agent",
		JobID:     "cancel-001",
		TokensIn:  5000,
		TokensOut: 3000,
		Event:     QuotaEventJobCancelled,
	}))

	// 100% returned: only the zero-token job start remains on the books.
	usage, _ := store.GetUsage("cancel-agent")
	assert.Equal(t, int64(0), usage.TokensUsed)
	assert.Equal(t, 0, usage.ActiveJobs)

	// No per-model usage should be attributed to a cancelled job.
	modelUsage, _ := store.GetModelUsage("claude-sonnet")
	assert.Equal(t, int64(0), modelUsage)
}
// TestTaskLifecycle_MultipleAgentsConcurrent verifies that multiple agents
// can operate on the same store concurrently without data races.
func TestTaskLifecycle_MultipleAgentsConcurrent(t *testing.T) {
	store := NewMemoryStore()
	svc := NewAllowanceService(store)
	ids := []string{"agent-a", "agent-b", "agent-c"}
	for _, id := range ids {
		_ = store.SetAllowance(&AgentAllowance{
			AgentID:         id,
			DailyTokenLimit: 100000,
			DailyJobLimit:   50,
			ConcurrentJobs:  5,
		})
	}
	var wg sync.WaitGroup
	for _, id := range ids {
		wg.Add(1)
		go func(aid string) {
			defer wg.Done()
			// Each iteration runs a full check/start/complete cycle.
			for i := 0; i < 10; i++ {
				result, err := svc.Check(aid, "")
				assert.NoError(t, err)
				assert.True(t, result.Allowed)
				_ = svc.RecordUsage(UsageReport{
					AgentID: aid,
					JobID:   aid + "-job",
					Event:   QuotaEventJobStarted,
				})
				_ = svc.RecordUsage(UsageReport{
					AgentID:   aid,
					JobID:     aid + "-job",
					Model:     "claude-sonnet",
					TokensIn:  100,
					TokensOut: 50,
					Event:     QuotaEventJobCompleted,
				})
			}
		}(id)
	}
	wg.Wait()
	// Once all goroutines are done, each agent's ledger must be consistent:
	// 10 completed jobs at 150 tokens each, with nothing left running.
	for _, id := range ids {
		usage, err := store.GetUsage(id)
		require.NoError(t, err)
		assert.Equal(t, int64(1500), usage.TokensUsed)
		assert.Equal(t, 10, usage.JobsStarted)
		assert.Equal(t, 0, usage.ActiveJobs)
	}
}
// TestTaskLifecycle_ClaimedByFilter verifies that ListTasks with ClaimedBy
// filter sends the correct query parameter.
func TestTaskLifecycle_ClaimedByFilter(t *testing.T) {
	want := Task{
		ID:        "claimed-task-1",
		Title:     "Agent's task",
		Status:    StatusInProgress,
		ClaimedBy: "agent-x",
	}
	handler := func(w http.ResponseWriter, r *http.Request) {
		// The filter must arrive as the claimed_by query parameter.
		assert.Equal(t, "agent-x", r.URL.Query().Get("claimed_by"))
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode([]Task{want})
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	tasks, err := client.ListTasks(context.Background(), ListOptions{ClaimedBy: "agent-x"})
	require.NoError(t, err)
	require.Len(t, tasks, 1)
	assert.Equal(t, "agent-x", tasks[0].ClaimedBy)
}

View file

@ -1,47 +0,0 @@
package lifecycle
import (
"context"
"fmt"
"io"
"time"
"forge.lthn.ai/core/go-log"
)
// StreamLogs polls a task's status and writes updates to writer. It polls at
// the given interval until the task reaches a terminal state (completed or
// blocked) or the context is cancelled. Returns ctx.Err() on cancellation.
func StreamLogs(ctx context.Context, client *Client, taskID string, interval time.Duration, writer io.Writer) error {
	const op = "agentic.StreamLogs"
	if taskID == "" {
		return log.E(op, "task ID is required", nil)
	}
	const stampLayout = "2006-01-02 15:04:05"
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		// Wait for either cancellation or the next poll tick.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
		task, err := client.GetTask(ctx, taskID)
		if err != nil {
			// Report the error and keep polling -- transient API failures
			// must not terminate the stream.
			_, _ = fmt.Fprintf(writer, "[%s] Error: %s\n", time.Now().UTC().Format(stampLayout), err)
			continue
		}
		_, _ = fmt.Fprintf(writer, "[%s] Status: %s\n", time.Now().UTC().Format(stampLayout), task.Status)
		// A completed or blocked task is terminal; stop streaming.
		if task.Status == StatusCompleted || task.Status == StatusBlocked {
			return nil
		}
	}
}

View file

@ -1,139 +0,0 @@
package lifecycle
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStreamLogs_Good_CompletedTask streams a task that finishes after a few
// in-progress polls and checks both statuses appear in the output.
func TestStreamLogs_Good_CompletedTask(t *testing.T) {
	var calls atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/api/tasks/task-1", r.URL.Path)
		// First two polls report in_progress, every later poll completed.
		status := StatusCompleted
		if calls.Add(1) <= 2 {
			status = StatusInProgress
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(Task{ID: "task-1", Status: status})
	}))
	defer server.Close()
	var buf bytes.Buffer
	client := NewClient(server.URL, "test-token")
	require.NoError(t, StreamLogs(context.Background(), client, "task-1", 10*time.Millisecond, &buf))
	out := buf.String()
	assert.Contains(t, out, "Status: in_progress")
	assert.Contains(t, out, "Status: completed")
	assert.GreaterOrEqual(t, int(calls.Load()), 3)
}
// TestStreamLogs_Good_BlockedTask verifies a blocked task terminates the
// stream cleanly and its status is written out.
func TestStreamLogs_Good_BlockedTask(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(Task{ID: "task-2", Status: StatusBlocked})
	}))
	defer server.Close()
	var buf bytes.Buffer
	client := NewClient(server.URL, "test-token")
	err := StreamLogs(context.Background(), client, "task-2", 10*time.Millisecond, &buf)
	require.NoError(t, err)
	assert.Contains(t, buf.String(), "Status: blocked")
}
// TestStreamLogs_Good_ContextCancellation verifies the stream stops with
// DeadlineExceeded when the context times out while the task is still running.
func TestStreamLogs_Good_ContextCancellation(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(Task{ID: "task-3", Status: StatusInProgress})
	}))
	defer server.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
	defer cancel()
	var buf bytes.Buffer
	client := NewClient(server.URL, "test-token")
	err := StreamLogs(ctx, client, "task-3", 20*time.Millisecond, &buf)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	assert.Contains(t, buf.String(), "Status: in_progress")
}
// TestStreamLogs_Good_TransientErrorContinues verifies that a server error on
// one poll is written to the output and polling resumes afterwards.
func TestStreamLogs_Good_TransientErrorContinues(t *testing.T) {
	var calls atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if calls.Add(1) == 1 {
			// Poll 1 fails with a 500; the stream must survive it.
			w.WriteHeader(http.StatusInternalServerError)
			_ = json.NewEncoder(w).Encode(APIError{Message: "transient"})
			return
		}
		// Subsequent polls report completion.
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(Task{ID: "task-4", Status: StatusCompleted})
	}))
	defer server.Close()
	var buf bytes.Buffer
	client := NewClient(server.URL, "test-token")
	require.NoError(t, StreamLogs(context.Background(), client, "task-4", 10*time.Millisecond, &buf))
	out := buf.String()
	assert.Contains(t, out, "Error:")
	assert.Contains(t, out, "Status: completed")
}
// TestStreamLogs_Bad_EmptyTaskID verifies the empty-ID guard rejects the call
// before any polling happens.
func TestStreamLogs_Bad_EmptyTaskID(t *testing.T) {
	var buf bytes.Buffer
	client := NewClient("https://api.example.com", "test-token")
	err := StreamLogs(context.Background(), client, "", 10*time.Millisecond, &buf)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "task ID is required")
}
// TestStreamLogs_Good_ImmediateCancel verifies that an already-cancelled
// context makes StreamLogs return context.Canceled.
func TestStreamLogs_Good_ImmediateCancel(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(Task{ID: "task-5", Status: StatusInProgress})
	}))
	defer server.Close()
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // already cancelled before the first poll
	var buf bytes.Buffer
	client := NewClient(server.URL, "test-token")
	err := StreamLogs(ctx, client, "task-5", 10*time.Millisecond, &buf)
	require.ErrorIs(t, err, context.Canceled)
}

View file

@ -1,197 +0,0 @@
package lifecycle
import (
"context"
"time"
"forge.lthn.ai/core/go-log"
)
// PlanDispatcher orchestrates plan-based work by polling active plans,
// starting sessions, and routing work to agents. It wraps the existing
// agent registry, router, and allowance service alongside the API client.
type PlanDispatcher struct {
	registry  AgentRegistry     // known agents and their health
	router    TaskRouter        // task-to-agent routing policy
	allowance *AllowanceService // optional quota enforcement; nil disables quota checks
	client    *Client           // API client used for plans, phases, and sessions
	events    EventEmitter      // optional lifecycle notifications; may be nil
	agentType string // e.g. "opus", "haiku", "codex"
}
// NewPlanDispatcher creates a PlanDispatcher for the given agent type.
// The event emitter is left unset; attach one with SetEventEmitter.
func NewPlanDispatcher(
	agentType string,
	registry AgentRegistry,
	router TaskRouter,
	allowance *AllowanceService,
	client *Client,
) *PlanDispatcher {
	pd := &PlanDispatcher{
		agentType: agentType,
		client:    client,
		registry:  registry,
		router:    router,
		allowance: allowance,
	}
	return pd
}
// SetEventEmitter attaches an event emitter for lifecycle notifications.
func (pd *PlanDispatcher) SetEventEmitter(em EventEmitter) {
	pd.events = em
}

// emit forwards an event to the attached emitter, if any, stamping the
// current UTC time when the event carries none. Emit errors are discarded:
// notifications are best-effort.
func (pd *PlanDispatcher) emit(ctx context.Context, event Event) {
	if pd.events == nil {
		return
	}
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now().UTC()
	}
	_ = pd.events.Emit(ctx, event)
}
// PlanDispatchLoop polls for active plans at the given interval and picks up
// the first plan with a pending or in-progress phase. It starts a session,
// marks the phase in-progress, and returns the plan + session for the caller
// to work on. Runs until context is cancelled.
func (pd *PlanDispatcher) PlanDispatchLoop(ctx context.Context, interval time.Duration) error {
	const op = "PlanDispatcher.PlanDispatchLoop"
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
		plan, session, err := pd.pickUpWork(ctx)
		switch {
		case err != nil:
			// Log and keep polling; a single failed sweep is not fatal.
			_ = log.E(op, "failed to pick up work", err)
		case plan == nil:
			// Nothing workable this tick.
		default:
			pd.emit(ctx, Event{
				Type:    EventTaskDispatched,
				TaskID:  plan.Slug,
				AgentID: session.SessionID,
				Payload: map[string]string{
					"plan":       plan.Slug,
					"agent_type": pd.agentType,
				},
			})
		}
	}
}
// pickUpWork finds the first active plan with workable phases, starts a session,
// and marks the next phase in-progress. Returns nil if no work is available.
// The (nil, nil, nil) return is the deliberate "no work" signal; callers must
// check plan != nil before using the results.
func (pd *PlanDispatcher) pickUpWork(ctx context.Context) (*Plan, *sessionStartResponse, error) {
	const op = "PlanDispatcher.pickUpWork"
	plans, err := pd.client.ListPlans(ctx, ListPlanOptions{Status: PlanActive})
	if err != nil {
		return nil, nil, log.E(op, "failed to list active plans", err)
	}
	for _, plan := range plans {
		// Check agent allowance before taking work. A failed check (error
		// or quota exhausted) skips this plan rather than aborting the sweep.
		if pd.allowance != nil {
			check, err := pd.allowance.Check(pd.agentType, "")
			if err != nil || !check.Allowed {
				continue
			}
		}
		// Get full plan with phases (the list endpoint omits phase detail).
		fullPlan, err := pd.client.GetPlan(ctx, plan.Slug)
		if err != nil {
			_ = log.E(op, "failed to get plan "+plan.Slug, err)
			continue
		}
		// Find the next workable phase.
		phase := nextWorkablePhase(fullPlan.Phases)
		if phase == nil {
			continue
		}
		// Start session for this plan.
		session, err := pd.client.StartSession(ctx, StartSessionRequest{
			AgentType: pd.agentType,
			PlanSlug:  plan.Slug,
		})
		if err != nil {
			_ = log.E(op, "failed to start session for "+plan.Slug, err)
			continue
		}
		// Mark phase as in-progress. Best-effort: a failed status update is
		// logged but does not abort the pickup.
		if phase.Status == PhasePending {
			if err := pd.client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, PhaseInProgress, ""); err != nil {
				_ = log.E(op, "failed to update phase status", err)
			}
		}
		// Record job start against the agent's quota, when quotas are on.
		if pd.allowance != nil {
			_ = pd.allowance.RecordUsage(UsageReport{
				AgentID:   pd.agentType,
				JobID:     plan.Slug,
				Event:     QuotaEventJobStarted,
				Timestamp: time.Now().UTC(),
			})
		}
		return fullPlan, session, nil
	}
	return nil, nil, nil
}
// CompleteWork ends a session and optionally marks the current phase as completed.
func (pd *PlanDispatcher) CompleteWork(ctx context.Context, planSlug, sessionID, phaseName string, summary string) error {
	const op = "PlanDispatcher.CompleteWork"
	// Phase completion is best-effort; a failure here is logged and must not
	// block ending the session.
	if phaseName != "" {
		if phaseErr := pd.client.UpdatePhaseStatus(ctx, planSlug, phaseName, PhaseCompleted, ""); phaseErr != nil {
			_ = log.E(op, "failed to complete phase", phaseErr)
		}
	}
	if err := pd.client.EndSession(ctx, sessionID, "completed", summary); err != nil {
		return log.E(op, "failed to end session", err)
	}
	// Report the finished job against the agent's quota, when quotas are on.
	if pd.allowance != nil {
		_ = pd.allowance.RecordUsage(UsageReport{
			AgentID:   pd.agentType,
			JobID:     planSlug,
			Event:     QuotaEventJobCompleted,
			Timestamp: time.Now().UTC(),
		})
	}
	return nil
}
// nextWorkablePhase returns the first phase that is pending or in-progress.
// Pending phases are only eligible when CanStart is set (dependencies met).
// Returns nil when no phase is workable.
func nextWorkablePhase(phases []Phase) *Phase {
	for i := range phases {
		p := &phases[i]
		if p.Status == PhaseInProgress {
			return p
		}
		if p.Status == PhasePending && p.CanStart {
			return p
		}
	}
	return nil
}

View file

@ -1,525 +0,0 @@
package lifecycle
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"

	"forge.lthn.ai/core/go-log"
)
// PlanStatus represents the state of a plan.
type PlanStatus string

// Plan lifecycle states as serialised on the wire.
const (
	PlanDraft     PlanStatus = "draft"
	PlanActive    PlanStatus = "active"
	PlanPaused    PlanStatus = "paused"
	PlanCompleted PlanStatus = "completed"
	PlanArchived  PlanStatus = "archived"
)

// PhaseStatus represents the state of a phase within a plan.
type PhaseStatus string

// Phase lifecycle states as serialised on the wire.
const (
	PhasePending    PhaseStatus = "pending"
	PhaseInProgress PhaseStatus = "in_progress"
	PhaseCompleted  PhaseStatus = "completed"
	PhaseBlocked    PhaseStatus = "blocked"
	PhaseSkipped    PhaseStatus = "skipped"
)
// Plan represents an agent plan from the PHP API.
type Plan struct {
	Slug         string     `json:"slug"` // unique identifier used in API paths
	Title        string     `json:"title"`
	Description  string     `json:"description,omitempty"`
	Status       PlanStatus `json:"status"`
	CurrentPhase int        `json:"current_phase,omitempty"`
	Progress     Progress   `json:"progress,omitempty"`
	Phases       []Phase    `json:"phases,omitempty"` // populated on the detail endpoint (GetPlan)
	Metadata     any        `json:"metadata,omitempty"`
	CreatedAt    string     `json:"created_at,omitempty"` // server-formatted timestamp string
	UpdatedAt    string     `json:"updated_at,omitempty"` // server-formatted timestamp string
}
// Phase represents a phase within a plan.
type Phase struct {
	ID                 int          `json:"id,omitempty"`
	Order              int          `json:"order"` // position of the phase within its plan
	Name               string       `json:"name"`
	Description        string       `json:"description,omitempty"`
	Status             PhaseStatus  `json:"status"`
	Tasks              []PhaseTask  `json:"tasks,omitempty"`
	TaskProgress       TaskProgress `json:"task_progress,omitempty"`
	RemainingTasks     []string     `json:"remaining_tasks,omitempty"`
	Dependencies       []int        `json:"dependencies,omitempty"`        // phase IDs this phase waits on
	DependencyBlockers []string     `json:"dependency_blockers,omitempty"` // human-readable blocker descriptions
	CanStart           bool         `json:"can_start,omitempty"` // true once dependencies are satisfied
	Checkpoints        []any        `json:"checkpoints,omitempty"`
	StartedAt          string       `json:"started_at,omitempty"`
	CompletedAt        string       `json:"completed_at,omitempty"`
	Metadata           any          `json:"metadata,omitempty"`
}
// PhaseTask represents a task within a phase. Tasks are stored as a JSON array
// in the phase and may be simple strings or objects with status/notes.
// See UnmarshalJSON for how the two forms are decoded.
type PhaseTask struct {
	Name   string `json:"name"`
	Status string `json:"status,omitempty"` // a bare-string task decodes with status "pending"
	Notes  string `json:"notes,omitempty"`
}
// UnmarshalJSON handles the fact that tasks can be either strings or objects.
// A bare string becomes a task with that name and status "pending"; anything
// else is decoded as a full PhaseTask object.
func (t *PhaseTask) UnmarshalJSON(data []byte) error {
	// A plain string is shorthand for a pending task of that name.
	var name string
	if json.Unmarshal(data, &name) == nil {
		t.Name = name
		t.Status = "pending"
		return nil
	}
	// Otherwise decode the full object via an alias type, which strips the
	// methods and avoids recursing into this UnmarshalJSON.
	type plain PhaseTask
	var decoded plain
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	*t = PhaseTask(decoded)
	return nil
}
// Progress represents plan progress metrics (phase counts as reported by the API).
type Progress struct {
	Total      int `json:"total"`
	Completed  int `json:"completed"`
	InProgress int `json:"in_progress"`
	Pending    int `json:"pending"`
	Percentage int `json:"percentage"` // completion percentage as computed server-side
}

// TaskProgress represents task-level progress within a phase.
type TaskProgress struct {
	Total      int `json:"total"`
	Completed  int `json:"completed"`
	Pending    int `json:"pending"`
	Percentage int `json:"percentage"` // completion percentage as computed server-side
}
// ListPlanOptions specifies filters for listing plans.
type ListPlanOptions struct {
	Status          PlanStatus `json:"status,omitempty"`           // filter by plan status; empty means all
	IncludeArchived bool       `json:"include_archived,omitempty"` // also return archived plans
}

// CreatePlanRequest is the payload for creating a new plan.
type CreatePlanRequest struct {
	Title       string             `json:"title"` // required; CreatePlan rejects an empty title
	Slug        string             `json:"slug,omitempty"` // optional; server assigns one when omitted
	Description string             `json:"description,omitempty"`
	Context     map[string]any     `json:"context,omitempty"`
	Phases      []CreatePhaseInput `json:"phases,omitempty"`
}

// CreatePhaseInput is a phase definition for plan creation.
type CreatePhaseInput struct {
	Name        string   `json:"name"`
	Description string   `json:"description,omitempty"`
	Tasks       []string `json:"tasks,omitempty"` // task names; decoded as pending PhaseTasks on read
}
// planListResponse wraps the list endpoint response.
type planListResponse struct {
	Plans []Plan `json:"plans"`
	Total int    `json:"total"`
}

// planCreateResponse wraps the create endpoint response.
type planCreateResponse struct {
	Slug   string `json:"slug"`
	Title  string `json:"title"`
	Status string `json:"status"`
	Phases int    `json:"phases"` // number of phases created with the plan
}

// planUpdateResponse wraps the update endpoint response.
// NOTE(review): not referenced by any client method visible in this file —
// possibly dead; confirm before relying on it.
type planUpdateResponse struct {
	Slug   string `json:"slug"`
	Status string `json:"status"`
}

// planArchiveResponse wraps the archive endpoint response.
// NOTE(review): not referenced by any client method visible in this file —
// possibly dead; confirm before relying on it.
type planArchiveResponse struct {
	Slug       string `json:"slug"`
	Status     string `json:"status"`
	ArchivedAt string `json:"archived_at,omitempty"`
}
// ListPlans retrieves plans matching the given options. An empty option set
// returns every non-archived plan the API exposes.
func (c *Client) ListPlans(ctx context.Context, opts ListPlanOptions) ([]Plan, error) {
	const op = "agentic.Client.ListPlans"
	// Translate the options into query parameters.
	query := url.Values{}
	if opts.Status != "" {
		query.Set("status", string(opts.Status))
	}
	if opts.IncludeArchived {
		query.Set("include_archived", "1")
	}
	endpoint := c.BaseURL + "/v1/plans"
	if encoded := query.Encode(); encoded != "" {
		endpoint += "?" + encoded
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var payload planListResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return payload.Plans, nil
}
// GetPlan retrieves a plan by slug (returns full detail with phases).
// Returns an error when the slug is empty, the request cannot be built or
// sent, the API responds with an error status, or the body fails to decode.
func (c *Client) GetPlan(ctx context.Context, slug string) (*Plan, error) {
	const op = "agentic.Client.GetPlan"
	if slug == "" {
		return nil, log.E(op, "plan slug is required", nil)
	}
	// Path-escape the slug so arbitrary identifiers cannot break the URL.
	endpoint := fmt.Sprintf("%s/v1/plans/%s", c.BaseURL, url.PathEscape(slug))
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var plan Plan
	if err := json.NewDecoder(resp.Body).Decode(&plan); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &plan, nil
}
// CreatePlan creates a new plan with optional phases.
// Title is required; the returned value carries the server-assigned slug,
// status, and phase count.
func (c *Client) CreatePlan(ctx context.Context, req CreatePlanRequest) (*planCreateResponse, error) {
	const op = "agentic.Client.CreatePlan"
	if req.Title == "" {
		return nil, log.E(op, "title is required", nil)
	}
	data, err := json.Marshal(req)
	if err != nil {
		return nil, log.E(op, "failed to marshal request", err)
	}
	endpoint := c.BaseURL + "/v1/plans"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(httpReq)
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var result planCreateResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &result, nil
}
// UpdatePlanStatus changes a plan's status via a PATCH request.
// The response body is discarded; only the status check is reported.
func (c *Client) UpdatePlanStatus(ctx context.Context, slug string, status PlanStatus) error {
	const op = "agentic.Client.UpdatePlanStatus"
	if slug == "" {
		return log.E(op, "plan slug is required", nil)
	}
	payload, err := json.Marshal(map[string]string{"status": string(status)})
	if err != nil {
		return log.E(op, "failed to marshal request", err)
	}
	endpoint := fmt.Sprintf("%s/v1/plans/%s", c.BaseURL, url.PathEscape(slug))
	req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(payload))
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	req.Header.Set("Content-Type", "application/json")
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	return c.checkResponse(resp)
}
// ArchivePlan archives a plan with an optional reason.
// When reason is non-empty it is sent as a JSON body on the DELETE request;
// otherwise the request carries no body. Returns an error when the slug is
// empty, the payload cannot be marshalled, the request fails, or the API
// responds with an error status.
func (c *Client) ArchivePlan(ctx context.Context, slug string, reason string) error {
	const op = "agentic.Client.ArchivePlan"
	if slug == "" {
		return log.E(op, "plan slug is required", nil)
	}
	endpoint := fmt.Sprintf("%s/v1/plans/%s", c.BaseURL, url.PathEscape(slug))
	// Build the optional body once. Declaring it as an io.Reader (not a
	// concrete *bytes.Reader) keeps a truly-nil body when no reason is
	// given, avoiding the typed-nil interface pitfall and the duplicated
	// body/reqBody variables the previous version carried.
	var body io.Reader
	if reason != "" {
		data, err := json.Marshal(map[string]string{"reason": reason})
		if err != nil {
			// Previously this error was silently discarded.
			return log.E(op, "failed to marshal request", err)
		}
		body = bytes.NewReader(data)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, body)
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	return c.checkResponse(resp)
}
// GetPhase retrieves a specific phase within a plan.
// phase may be any identifier the API accepts in the path (e.g. name or order).
// Returns an error when either identifier is empty, the request fails, the
// API responds with an error status, or the body fails to decode.
func (c *Client) GetPhase(ctx context.Context, planSlug string, phase string) (*Phase, error) {
	const op = "agentic.Client.GetPhase"
	if planSlug == "" || phase == "" {
		return nil, log.E(op, "plan slug and phase identifier are required", nil)
	}
	// Both path segments are escaped so arbitrary identifiers stay safe.
	endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s",
		c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase))
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var result Phase
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &result, nil
}
// UpdatePhaseStatus changes a phase's status, with optional notes attached to
// the update. Returns an error when either identifier is empty, the payload
// cannot be marshalled, the request fails, or the API responds with an error
// status. The response body is discarded.
func (c *Client) UpdatePhaseStatus(ctx context.Context, planSlug, phase string, status PhaseStatus, notes string) error {
	const op = "agentic.Client.UpdatePhaseStatus"
	if planSlug == "" || phase == "" {
		return log.E(op, "plan slug and phase identifier are required", nil)
	}
	payload := map[string]string{"status": string(status)}
	if notes != "" {
		payload["notes"] = notes
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return log.E(op, "failed to marshal request", err)
	}
	endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s",
		c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase))
	req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(data))
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	return c.checkResponse(resp)
}
// AddCheckpoint adds a checkpoint note to a phase, with optional structured
// context attached. Returns an error when plan slug, phase, or note is empty,
// the payload cannot be marshalled, the request fails, or the API responds
// with an error status.
func (c *Client) AddCheckpoint(ctx context.Context, planSlug, phase, note string, checkpointCtx map[string]any) error {
	const op = "agentic.Client.AddCheckpoint"
	if planSlug == "" || phase == "" || note == "" {
		return log.E(op, "plan slug, phase, and note are required", nil)
	}
	payload := map[string]any{"note": note}
	if len(checkpointCtx) > 0 {
		payload["context"] = checkpointCtx
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return log.E(op, "failed to marshal request", err)
	}
	endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s/checkpoint",
		c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	return c.checkResponse(resp)
}
// UpdateTaskStatus updates a task within a phase, addressed by its index in
// the phase's task array. Either status or notes (or both) may be set; empty
// values are omitted from the payload.
// NOTE(review): when both status and notes are empty this still PATCHes an
// empty `{}` payload — presumably a harmless no-op server-side; confirm.
func (c *Client) UpdateTaskStatus(ctx context.Context, planSlug, phase string, taskIdx int, status string, notes string) error {
	const op = "agentic.Client.UpdateTaskStatus"
	if planSlug == "" || phase == "" {
		return log.E(op, "plan slug and phase are required", nil)
	}
	payload := map[string]any{}
	if status != "" {
		payload["status"] = status
	}
	if notes != "" {
		payload["notes"] = notes
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return log.E(op, "failed to marshal request", err)
	}
	endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s/tasks/%d",
		c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase), taskIdx)
	req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(data))
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	return c.checkResponse(resp)
}
// ToggleTask toggles a task between pending and completed, addressed by its
// index in the phase's task array. The request has no body; the response body
// is discarded. Returns an error when plan slug or phase is empty, the
// request fails, or the API responds with an error status.
func (c *Client) ToggleTask(ctx context.Context, planSlug, phase string, taskIdx int) error {
	const op = "agentic.Client.ToggleTask"
	if planSlug == "" || phase == "" {
		return log.E(op, "plan slug and phase are required", nil)
	}
	endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s/tasks/%d/toggle",
		c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase), taskIdx)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	return c.checkResponse(resp)
}

View file

@ -1,44 +0,0 @@
# Commit Instructions
## Context
When asked to commit changes in these repositories:
1. **All changes are pre-approved** - Both tracked (modified) and untracked files have been reviewed and approved for commit
2. **Include everything** - Commit ALL modified files AND all untracked files (including new directories)
3. **Don't skip files** - Do not leave out untracked files assuming they need separate review
4. **Single commit preferred** - Combine all changes into one cohesive commit unless explicitly told otherwise
The user has already validated these changes. Proceed with confidence.
## Handling Ignored Files
If you see untracked directories that should typically be ignored (like `node_modules/`, `vendor/`, `.cache/`, `dist/`, `build/`):
1. **Fix the .gitignore** - Create or update `.gitignore` to exclude these directories
2. **Commit the .gitignore** - Include this fix in your commit
3. **Don't ask** - Just fix it and commit
Common patterns to add to .gitignore:
```
node_modules/
vendor/
.cache/
dist/
build/
*.log
.env
.DS_Store
```
## Commit Message Style
- Use conventional commit format: `type(scope): description`
- Common types: `refactor`, `feat`, `fix`, `docs`, `chore`
- Keep the first line under 72 characters
- Add body for complex changes explaining the "why"
- Include `Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>`
## Task
Review the uncommitted changes and create an appropriate commit. Be concise.

View file

@ -1,168 +0,0 @@
package lifecycle
import (
"iter"
"slices"
"sync"
"time"
)
// AgentStatus represents the availability state of an agent.
type AgentStatus string

const (
	// AgentAvailable indicates the agent is ready to accept tasks.
	AgentAvailable AgentStatus = "available"
	// AgentBusy indicates the agent is working but may accept more tasks.
	AgentBusy AgentStatus = "busy"
	// AgentOffline indicates the agent has not sent a heartbeat recently.
	AgentOffline AgentStatus = "offline"
)

// AgentInfo describes a registered agent and its current state.
type AgentInfo struct {
	// ID is the unique identifier for the agent.
	ID string `json:"id"`
	// Name is the human-readable name of the agent.
	Name string `json:"name"`
	// Capabilities lists what the agent can handle (e.g. "go", "testing", "frontend").
	Capabilities []string `json:"capabilities,omitempty"`
	// Status is the current availability state.
	Status AgentStatus `json:"status"`
	// LastHeartbeat is the last time the agent reported in.
	LastHeartbeat time.Time `json:"last_heartbeat"`
	// CurrentLoad is the number of active jobs the agent is running.
	CurrentLoad int `json:"current_load"`
	// MaxLoad is the maximum concurrent jobs the agent supports. 0 means unlimited.
	MaxLoad int `json:"max_load"`
}

// AgentRegistry manages the set of known agents and their health.
// Implementations are expected to be safe for concurrent use (see
// MemoryRegistry, which guards state with a read-write mutex).
type AgentRegistry interface {
	// Register adds or updates an agent in the registry.
	Register(agent AgentInfo) error
	// Deregister removes an agent from the registry.
	Deregister(id string) error
	// Get returns a copy of the agent info for the given ID.
	Get(id string) (AgentInfo, error)
	// List returns a copy of all registered agents.
	List() []AgentInfo
	// All returns an iterator over all registered agents.
	All() iter.Seq[AgentInfo]
	// Heartbeat updates the agent's LastHeartbeat timestamp and sets status
	// to Available if the agent was previously Offline.
	Heartbeat(id string) error
	// Reap marks agents as Offline if their last heartbeat is older than ttl.
	// Returns the IDs of agents that were reaped.
	Reap(ttl time.Duration) []string
	// Reaped returns an iterator over the IDs of agents that were reaped.
	Reaped(ttl time.Duration) iter.Seq[string]
}
// MemoryRegistry is an in-memory AgentRegistry implementation guarded by a
// read-write mutex. It uses copy-on-read semantics consistent with MemoryStore.
type MemoryRegistry struct {
	mu     sync.RWMutex
	agents map[string]*AgentInfo
}

// NewMemoryRegistry creates a new in-memory agent registry.
func NewMemoryRegistry() *MemoryRegistry {
	return &MemoryRegistry{agents: map[string]*AgentInfo{}}
}
// Register adds or updates an agent in the registry. Returns an error if the
// agent ID is empty.
func (r *MemoryRegistry) Register(agent AgentInfo) error {
	if agent.ID == "" {
		return &APIError{Code: 400, Message: "agent ID is required"}
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	// Store a private copy so later mutation of the caller's value cannot
	// alias registry state.
	stored := agent
	r.agents[agent.ID] = &stored
	return nil
}
// Deregister removes an agent from the registry. Returns an error if the agent
// is not found.
func (r *MemoryRegistry) Deregister(id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if _, found := r.agents[id]; !found {
		return &APIError{Code: 404, Message: "agent not found: " + id}
	}
	delete(r.agents, id)
	return nil
}
// Get returns a copy of the agent info for the given ID. Returns an error if
// the agent is not found.
func (r *MemoryRegistry) Get(id string) (AgentInfo, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	if a, ok := r.agents[id]; ok {
		// Dereference to hand back a copy, not the stored pointer.
		return *a, nil
	}
	return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + id}
}
// List materialises the All iterator into a slice of agent copies.
// Order is unspecified; an empty registry returns nil.
func (r *MemoryRegistry) List() []AgentInfo {
	var out []AgentInfo
	for a := range r.All() {
		out = append(out, a)
	}
	return out
}
// All returns an iterator over copies of all registered agents. The read lock
// is held for the duration of the iteration.
func (r *MemoryRegistry) All() iter.Seq[AgentInfo] {
	return func(yield func(AgentInfo) bool) {
		r.mu.RLock()
		defer r.mu.RUnlock()
		for _, info := range r.agents {
			cp := *info
			if !yield(cp) {
				return
			}
		}
	}
}
// Heartbeat stamps the agent's LastHeartbeat with the current UTC time.
// An Offline agent is revived to Available; other statuses are untouched.
// An unknown ID yields a 404 APIError.
func (r *MemoryRegistry) Heartbeat(id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	agent, found := r.agents[id]
	if !found {
		return &APIError{Code: 404, Message: "agent not found: " + id}
	}
	agent.LastHeartbeat = time.Now().UTC()
	if agent.Status == AgentOffline {
		agent.Status = AgentAvailable
	}
	return nil
}
// Reap materialises the Reaped iterator: agents whose last heartbeat is older
// than ttl are marked Offline and their IDs returned. Nil when nothing reaped.
func (r *MemoryRegistry) Reap(ttl time.Duration) []string {
	var ids []string
	for id := range r.Reaped(ttl) {
		ids = append(ids, id)
	}
	return ids
}
// Reaped returns an iterator over the IDs of agents it transitions to
// Offline: those not already Offline whose LastHeartbeat is more than ttl in
// the past. The write lock is held for the whole iteration, so status updates
// and yields are atomic with respect to other registry calls.
func (r *MemoryRegistry) Reaped(ttl time.Duration) iter.Seq[string] {
	return func(yield func(string) bool) {
		r.mu.Lock()
		defer r.mu.Unlock()
		now := time.Now().UTC()
		for id, agent := range r.agents {
			if agent.Status == AgentOffline {
				continue
			}
			if now.Sub(agent.LastHeartbeat) <= ttl {
				continue
			}
			agent.Status = AgentOffline
			if !yield(id) {
				return
			}
		}
	}
}

View file

@ -1,280 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"errors"
"iter"
"slices"
"time"
"github.com/redis/go-redis/v9"
)
// RedisRegistry implements AgentRegistry using Redis as the backing store.
// It provides persistent, network-accessible agent registration suitable for
// multi-node deployments. Heartbeat refreshes key TTL for natural reaping via
// expiry.
type RedisRegistry struct {
	client     *redis.Client // shared client; safe for concurrent use
	prefix     string        // key namespace: "<prefix>:agent:<id>"
	defaultTTL time.Duration // TTL applied on Register and refreshed by Heartbeat
}
// redisRegistryConfig holds the configuration for a RedisRegistry.
// Fields are populated by RedisRegistryOption functions before the client is
// constructed in NewRedisRegistry.
type redisRegistryConfig struct {
	password string        // optional AUTH password
	db       int           // Redis logical database number
	prefix   string        // key namespace (default "agentic")
	ttl      time.Duration // default key TTL (default 5 minutes)
}
// RedisRegistryOption is a functional option for configuring a RedisRegistry.
// Options are applied in order by NewRedisRegistry before connecting.
type RedisRegistryOption func(*redisRegistryConfig)
// WithRegistryRedisPassword sets the password for authenticating with Redis.
func WithRegistryRedisPassword(pw string) RedisRegistryOption {
	return func(cfg *redisRegistryConfig) { cfg.password = pw }
}
// WithRegistryRedisDB selects the Redis database number.
func WithRegistryRedisDB(db int) RedisRegistryOption {
	return func(cfg *redisRegistryConfig) { cfg.db = db }
}
// WithRegistryRedisPrefix sets the key prefix for all Redis keys.
// Default: "agentic".
func WithRegistryRedisPrefix(prefix string) RedisRegistryOption {
	return func(cfg *redisRegistryConfig) { cfg.prefix = prefix }
}
// WithRegistryTTL sets the default TTL for agent keys. Default: 5 minutes.
// Heartbeat refreshes this TTL. Agents whose keys expire are naturally reaped.
func WithRegistryTTL(ttl time.Duration) RedisRegistryOption {
	return func(cfg *redisRegistryConfig) { cfg.ttl = ttl }
}
// NewRedisRegistry creates a new Redis-backed agent registry connecting to the
// given address (host:port). Options may override the password, DB number,
// key prefix (default "agentic") and key TTL (default 5 minutes). The server
// is pinged with a 5-second deadline to verify connectivity; on failure the
// client is closed and a 500 APIError returned.
func NewRedisRegistry(addr string, opts ...RedisRegistryOption) (*RedisRegistry, error) {
	cfg := redisRegistryConfig{
		prefix: "agentic",
		ttl:    5 * time.Minute,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	client := redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: cfg.password,
		DB:       cfg.db,
	})
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := client.Ping(ctx).Err(); err != nil {
		_ = client.Close()
		return nil, &APIError{Code: 500, Message: "failed to connect to Redis: " + err.Error()}
	}
	reg := &RedisRegistry{
		client:     client,
		prefix:     cfg.prefix,
		defaultTTL: cfg.ttl,
	}
	return reg, nil
}
// Close releases the underlying Redis connection. The registry must not be
// used after Close returns.
func (r *RedisRegistry) Close() error {
	return r.client.Close()
}
// --- key helpers ---

// agentKey returns the Redis key that stores the record for the given agent ID.
func (r *RedisRegistry) agentKey(id string) string {
	return r.prefix + ":agent:" + id
}

// agentPattern returns the SCAN glob matching every agent key in this
// registry's namespace.
func (r *RedisRegistry) agentPattern() string {
	return r.prefix + ":agent:*"
}
// --- AgentRegistry interface ---
// Register adds or updates an agent in the registry.
func (r *RedisRegistry) Register(agent AgentInfo) error {
if agent.ID == "" {
return &APIError{Code: 400, Message: "agent ID is required"}
}
ctx := context.Background()
data, err := json.Marshal(agent)
if err != nil {
return &APIError{Code: 500, Message: "failed to marshal agent: " + err.Error()}
}
if err := r.client.Set(ctx, r.agentKey(agent.ID), data, r.defaultTTL).Err(); err != nil {
return &APIError{Code: 500, Message: "failed to register agent: " + err.Error()}
}
return nil
}
// Deregister deletes the agent's key. Returns a 404 APIError when no key was
// removed and a 500 on transport errors.
func (r *RedisRegistry) Deregister(id string) error {
	deleted, err := r.client.Del(context.Background(), r.agentKey(id)).Result()
	if err != nil {
		return &APIError{Code: 500, Message: "failed to deregister agent: " + err.Error()}
	}
	if deleted == 0 {
		return &APIError{Code: 404, Message: "agent not found: " + id}
	}
	return nil
}
// Get fetches and decodes the stored record for id. A missing key maps to a
// 404 APIError; transport or decode failures map to 500.
func (r *RedisRegistry) Get(id string) (AgentInfo, error) {
	raw, err := r.client.Get(context.Background(), r.agentKey(id)).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + id}
	case err != nil:
		return AgentInfo{}, &APIError{Code: 500, Message: "failed to get agent: " + err.Error()}
	}
	var agent AgentInfo
	if err := json.Unmarshal([]byte(raw), &agent); err != nil {
		return AgentInfo{}, &APIError{Code: 500, Message: "failed to unmarshal agent: " + err.Error()}
	}
	return agent, nil
}
// List returns a copy of all registered agents by scanning all agent keys.
// It materialises the All iterator; order is unspecified.
func (r *RedisRegistry) List() []AgentInfo {
	return slices.Collect(r.All())
}
// All returns an iterator over all registered agents. Keys are discovered via
// cursor-based SCAN (non-blocking on the server); entries that disappear
// mid-scan or fail to decode are skipped rather than aborting the iteration.
func (r *RedisRegistry) All() iter.Seq[AgentInfo] {
	return func(yield func(AgentInfo) bool) {
		ctx := context.Background()
		// Named scanIter so the imported iter package is not shadowed.
		scanIter := r.client.Scan(ctx, 0, r.agentPattern(), 100).Iterator()
		for scanIter.Next(ctx) {
			val, err := r.client.Get(ctx, scanIter.Val()).Result()
			if err != nil {
				continue
			}
			var a AgentInfo
			if err := json.Unmarshal([]byte(val), &a); err != nil {
				continue
			}
			if !yield(a) {
				return
			}
		}
	}
}
// Heartbeat loads the stored agent record, stamps LastHeartbeat with the
// current UTC time (reviving Offline agents to Available), and writes it back
// with a refreshed default TTL. A missing key maps to 404; all other failures
// map to 500.
func (r *RedisRegistry) Heartbeat(id string) error {
	ctx := context.Background()
	key := r.agentKey(id)
	raw, err := r.client.Get(ctx, key).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return &APIError{Code: 404, Message: "agent not found: " + id}
		}
		return &APIError{Code: 500, Message: "failed to get agent for heartbeat: " + err.Error()}
	}
	var agent AgentInfo
	if err := json.Unmarshal([]byte(raw), &agent); err != nil {
		return &APIError{Code: 500, Message: "failed to unmarshal agent: " + err.Error()}
	}
	agent.LastHeartbeat = time.Now().UTC()
	if agent.Status == AgentOffline {
		agent.Status = AgentAvailable
	}
	payload, err := json.Marshal(agent)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal agent: " + err.Error()}
	}
	if err := r.client.Set(ctx, key, payload, r.defaultTTL).Err(); err != nil {
		return &APIError{Code: 500, Message: "failed to update agent heartbeat: " + err.Error()}
	}
	return nil
}
// Reap scans all agent keys and marks agents as Offline if their last heartbeat
// is older than ttl. This is a backup to natural TTL expiry. Returns the IDs
// of agents that were reaped. It materialises the Reaped iterator; nil when
// nothing was reaped.
func (r *RedisRegistry) Reap(ttl time.Duration) []string {
	return slices.Collect(r.Reaped(ttl))
}
// Reaped returns an iterator over the IDs of agents that were reaped: any
// agent not already Offline whose stored LastHeartbeat is older than ttl is
// rewritten as Offline (preserving the key's remaining TTL, or the default
// when none) and its ID yielded. Load/decode/write failures for individual
// keys are skipped, keeping the sweep best-effort.
func (r *RedisRegistry) Reaped(ttl time.Duration) iter.Seq[string] {
	return func(yield func(string) bool) {
		ctx := context.Background()
		cutoff := time.Now().UTC().Add(-ttl)
		// Named scanIter so the imported iter package is not shadowed.
		scanIter := r.client.Scan(ctx, 0, r.agentPattern(), 100).Iterator()
		for scanIter.Next(ctx) {
			key := scanIter.Val()
			val, err := r.client.Get(ctx, key).Result()
			if err != nil {
				continue
			}
			var a AgentInfo
			if err := json.Unmarshal([]byte(val), &a); err != nil {
				continue
			}
			if a.Status == AgentOffline || !a.LastHeartbeat.Before(cutoff) {
				continue
			}
			a.Status = AgentOffline
			data, err := json.Marshal(a)
			if err != nil {
				continue
			}
			// Preserve remaining TTL (or use default if none).
			remainingTTL, err := r.client.TTL(ctx, key).Result()
			if err != nil || remainingTTL <= 0 {
				remainingTTL = r.defaultTTL
			}
			if err := r.client.Set(ctx, key, data, remainingTTL).Err(); err != nil {
				continue
			}
			if !yield(a.ID) {
				return
			}
		}
	}
}
// FlushPrefix deletes all keys matching the registry's prefix. Useful for
// testing cleanup. The first deletion error aborts the sweep; otherwise any
// scan error is returned.
func (r *RedisRegistry) FlushPrefix(ctx context.Context) error {
	// Named scanIter so the imported iter package is not shadowed.
	scanIter := r.client.Scan(ctx, 0, r.prefix+":*", 100).Iterator()
	for scanIter.Next(ctx) {
		if err := r.client.Del(ctx, scanIter.Val()).Err(); err != nil {
			return err
		}
	}
	return scanIter.Err()
}

View file

@ -1,327 +0,0 @@
package lifecycle
import (
"context"
"fmt"
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestRedisRegistry creates a RedisRegistry with a unique prefix for test isolation.
// Skips the test if Redis is unreachable.
func newTestRedisRegistry(t *testing.T) *RedisRegistry {
	t.Helper()
	// Nanosecond-timestamped prefix keeps concurrent test runs from colliding
	// on shared Redis keys.
	prefix := fmt.Sprintf("test_reg_%d", time.Now().UnixNano())
	reg, err := NewRedisRegistry(testRedisAddr,
		WithRegistryRedisPrefix(prefix),
		WithRegistryTTL(5*time.Minute),
	)
	if err != nil {
		t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err)
	}
	t.Cleanup(func() {
		// Best-effort cleanup: remove this test's keys, then close the client.
		ctx := context.Background()
		_ = reg.FlushPrefix(ctx)
		_ = reg.Close()
	})
	return reg
}
// --- Register tests ---

// Register round-trips all fields through Redis and back via Get.
func TestRedisRegistry_Register_Good(t *testing.T) {
	reg := newTestRedisRegistry(t)
	err := reg.Register(AgentInfo{
		ID:           "agent-1",
		Name:         "Test Agent",
		Capabilities: []string{"go", "testing"},
		Status:       AgentAvailable,
		MaxLoad:      5,
	})
	require.NoError(t, err)
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "agent-1", got.ID)
	assert.Equal(t, "Test Agent", got.Name)
	assert.Equal(t, []string{"go", "testing"}, got.Capabilities)
	assert.Equal(t, AgentAvailable, got.Status)
	assert.Equal(t, 5, got.MaxLoad)
}

// Registering an existing ID overwrites the stored record.
func TestRedisRegistry_Register_Good_Overwrite(t *testing.T) {
	reg := newTestRedisRegistry(t)
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3})
	err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10})
	require.NoError(t, err)
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "Updated", got.Name)
	assert.Equal(t, 10, got.MaxLoad)
}

// An empty agent ID is rejected with a validation error.
func TestRedisRegistry_Register_Bad_EmptyID(t *testing.T) {
	reg := newTestRedisRegistry(t)
	err := reg.Register(AgentInfo{ID: "", Name: "No ID"})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent ID is required")
}

// --- Deregister tests ---

// Deregister removes the record; a subsequent Get fails.
func TestRedisRegistry_Deregister_Good(t *testing.T) {
	reg := newTestRedisRegistry(t)
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"})
	err := reg.Deregister("agent-1")
	require.NoError(t, err)
	_, err = reg.Get("agent-1")
	require.Error(t, err)
}

// Deregistering an unknown ID reports not-found.
func TestRedisRegistry_Deregister_Bad_NotFound(t *testing.T) {
	reg := newTestRedisRegistry(t)
	err := reg.Deregister("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}

// --- Get tests ---

// Get returns the stored status, load and heartbeat (within serialisation tolerance).
func TestRedisRegistry_Get_Good(t *testing.T) {
	reg := newTestRedisRegistry(t)
	now := time.Now().UTC().Truncate(time.Millisecond)
	_ = reg.Register(AgentInfo{
		ID:            "agent-1",
		Name:          "Getter",
		Status:        AgentBusy,
		CurrentLoad:   2,
		MaxLoad:       5,
		LastHeartbeat: now,
	})
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, AgentBusy, got.Status)
	assert.Equal(t, 2, got.CurrentLoad)
	assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond)
}

// Get on an unknown ID reports not-found.
func TestRedisRegistry_Get_Bad_NotFound(t *testing.T) {
	reg := newTestRedisRegistry(t)
	_, err := reg.Get("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}

// Mutating the struct returned by Get must not affect the stored record.
func TestRedisRegistry_Get_Good_ReturnsCopy(t *testing.T) {
	reg := newTestRedisRegistry(t)
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1})
	got, _ := reg.Get("agent-1")
	got.CurrentLoad = 99
	got.Name = "Tampered"
	// Re-read — should be unchanged (deserialized from Redis).
	again, _ := reg.Get("agent-1")
	assert.Equal(t, "Original", again.Name)
	assert.Equal(t, 1, again.CurrentLoad)
}
// --- List tests ---

// List on an empty registry is empty.
func TestRedisRegistry_List_Good_Empty(t *testing.T) {
	reg := newTestRedisRegistry(t)
	agents := reg.List()
	assert.Empty(t, agents)
}

// List returns every registered agent, order unspecified.
func TestRedisRegistry_List_Good_Multiple(t *testing.T) {
	reg := newTestRedisRegistry(t)
	_ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"})
	_ = reg.Register(AgentInfo{ID: "b", Name: "Beta"})
	_ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"})
	agents := reg.List()
	assert.Len(t, agents, 3)
	// Sort by ID for deterministic assertion.
	sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID })
	assert.Equal(t, "a", agents[0].ID)
	assert.Equal(t, "b", agents[1].ID)
	assert.Equal(t, "c", agents[2].ID)
}

// --- Heartbeat tests ---

// Heartbeat advances LastHeartbeat and leaves an Available agent Available.
func TestRedisRegistry_Heartbeat_Good(t *testing.T) {
	reg := newTestRedisRegistry(t)
	past := time.Now().UTC().Add(-5 * time.Minute)
	_ = reg.Register(AgentInfo{
		ID:            "agent-1",
		Status:        AgentAvailable,
		LastHeartbeat: past,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.True(t, got.LastHeartbeat.After(past))
	assert.Equal(t, AgentAvailable, got.Status)
}

// Heartbeat revives an Offline agent to Available.
func TestRedisRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) {
	reg := newTestRedisRegistry(t)
	_ = reg.Register(AgentInfo{
		ID:     "agent-1",
		Status: AgentOffline,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.Equal(t, AgentAvailable, got.Status)
}

// Heartbeat does not change a Busy agent's status.
func TestRedisRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) {
	reg := newTestRedisRegistry(t)
	_ = reg.Register(AgentInfo{
		ID:     "agent-1",
		Status: AgentBusy,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.Equal(t, AgentBusy, got.Status)
}

// Heartbeat on an unknown ID reports not-found.
func TestRedisRegistry_Heartbeat_Bad_NotFound(t *testing.T) {
	reg := newTestRedisRegistry(t)
	err := reg.Heartbeat("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}
// --- Reap tests ---

// A stale agent is reaped to Offline; a fresh one is untouched.
func TestRedisRegistry_Reap_Good_StaleAgent(t *testing.T) {
	reg := newTestRedisRegistry(t)
	stale := time.Now().UTC().Add(-10 * time.Minute)
	fresh := time.Now().UTC()
	_ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale})
	_ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh})
	reaped := reg.Reap(5 * time.Minute)
	assert.Len(t, reaped, 1)
	assert.Contains(t, reaped, "stale-1")
	got, _ := reg.Get("stale-1")
	assert.Equal(t, AgentOffline, got.Status)
	got, _ = reg.Get("fresh-1")
	assert.Equal(t, AgentAvailable, got.Status)
}

// Agents already Offline are not reported again.
func TestRedisRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) {
	reg := newTestRedisRegistry(t)
	stale := time.Now().UTC().Add(-10 * time.Minute)
	_ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale})
	reaped := reg.Reap(5 * time.Minute)
	assert.Empty(t, reaped)
}

// Nothing is reaped when every heartbeat is fresh.
func TestRedisRegistry_Reap_Good_NoStaleAgents(t *testing.T) {
	reg := newTestRedisRegistry(t)
	now := time.Now().UTC()
	_ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now})
	_ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now})
	reaped := reg.Reap(5 * time.Minute)
	assert.Empty(t, reaped)
}

// A stale Busy agent is also reaped — staleness trumps status.
func TestRedisRegistry_Reap_Good_BusyAgentReaped(t *testing.T) {
	reg := newTestRedisRegistry(t)
	stale := time.Now().UTC().Add(-10 * time.Minute)
	_ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale})
	reaped := reg.Reap(5 * time.Minute)
	assert.Len(t, reaped, 1)
	assert.Contains(t, reaped, "busy-stale")
	got, _ := reg.Get("busy-stale")
	assert.Equal(t, AgentOffline, got.Status)
}

// --- Concurrent access ---

// Hammers the registry from 20 goroutines; meaningful under -race.
func TestRedisRegistry_Concurrent_Good(t *testing.T) {
	reg := newTestRedisRegistry(t)
	var wg sync.WaitGroup
	for i := range 20 {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			id := "agent-" + string(rune('a'+n%5))
			_ = reg.Register(AgentInfo{
				ID:            id,
				Name:          "Concurrent",
				Status:        AgentAvailable,
				LastHeartbeat: time.Now().UTC(),
			})
			_, _ = reg.Get(id)
			_ = reg.Heartbeat(id)
			_ = reg.List()
			_ = reg.Reap(1 * time.Minute)
		}(i)
	}
	wg.Wait()
	// No race conditions — test passes under -race.
	agents := reg.List()
	assert.True(t, len(agents) > 0)
}

// --- Constructor error case ---

// An unreachable address must surface a 500 APIError, not a panic or hang.
func TestNewRedisRegistry_Bad_Unreachable(t *testing.T) {
	_, err := NewRedisRegistry("127.0.0.1:1") // almost certainly unreachable
	require.Error(t, err)
	apiErr, ok := err.(*APIError)
	require.True(t, ok, "expected *APIError")
	assert.Equal(t, 500, apiErr.Code)
	assert.Contains(t, err.Error(), "failed to connect to Redis")
}

// --- Config-based factory with redis backend ---

// The factory selects the Redis implementation for backend "redis".
func TestNewAgentRegistryFromConfig_Good_Redis(t *testing.T) {
	cfg := RegistryConfig{
		RegistryBackend:   "redis",
		RegistryRedisAddr: testRedisAddr,
	}
	reg, err := NewAgentRegistryFromConfig(cfg)
	if err != nil {
		t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err)
	}
	rr, ok := reg.(*RedisRegistry)
	assert.True(t, ok, "expected RedisRegistry")
	_ = rr.Close()
}

View file

@ -1,267 +0,0 @@
package lifecycle
import (
	"database/sql"
	"encoding/json"
	"errors"
	"iter"
	"slices"
	"strings"
	"sync"
	"time"

	_ "modernc.org/sqlite"
)
// SQLiteRegistry implements AgentRegistry using a SQLite database.
// It provides persistent storage that survives process restarts.
type SQLiteRegistry struct {
	db *sql.DB
	mu sync.Mutex // serialises read-modify-write operations
}
// NewSQLiteRegistry creates a new SQLite-backed agent registry at the given path.
// Use ":memory:" for tests that do not need persistence. The agents table is
// created on first open; subsequent opens reuse the existing schema.
func NewSQLiteRegistry(dbPath string) (*SQLiteRegistry, error) {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, &APIError{Code: 500, Message: "failed to open SQLite registry: " + err.Error()}
	}
	// A single connection serialises access at the pool level and keeps a
	// ":memory:" database coherent (each pooled conn would otherwise see its
	// own independent in-memory DB).
	db.SetMaxOpenConns(1)
	if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
		db.Close()
		return nil, &APIError{Code: 500, Message: "failed to set WAL mode: " + err.Error()}
	}
	// Wait up to 5s for a competing writer before surfacing SQLITE_BUSY.
	if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil {
		db.Close()
		return nil, &APIError{Code: 500, Message: "failed to set busy_timeout: " + err.Error()}
	}
	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS agents (
		id TEXT PRIMARY KEY,
		name TEXT NOT NULL DEFAULT '',
		capabilities TEXT NOT NULL DEFAULT '[]',
		status TEXT NOT NULL DEFAULT 'available',
		last_heartbeat DATETIME NOT NULL DEFAULT (datetime('now')),
		current_load INTEGER NOT NULL DEFAULT 0,
		max_load INTEGER NOT NULL DEFAULT 0,
		registered_at DATETIME NOT NULL DEFAULT (datetime('now'))
	)`); err != nil {
		db.Close()
		return nil, &APIError{Code: 500, Message: "failed to create agents table: " + err.Error()}
	}
	return &SQLiteRegistry{db: db}, nil
}
// Close releases the underlying SQLite database. The registry must not be
// used after Close returns.
func (r *SQLiteRegistry) Close() error {
	return r.db.Close()
}
// Register adds or updates an agent in the registry. Returns an error if the
// agent ID is empty. Capabilities are stored as JSON; a zero LastHeartbeat is
// replaced with the current UTC time.
func (r *SQLiteRegistry) Register(agent AgentInfo) error {
	if agent.ID == "" {
		return &APIError{Code: 400, Message: "agent ID is required"}
	}
	caps, err := json.Marshal(agent.Capabilities)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to marshal capabilities: " + err.Error()}
	}
	hb := agent.LastHeartbeat
	if hb.IsZero() {
		hb = time.Now().UTC()
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	// Upsert: the insert path seeds registered_at from the heartbeat
	// timestamp; the update path leaves registered_at untouched.
	_, err = r.db.Exec(`INSERT INTO agents (id, name, capabilities, status, last_heartbeat, current_load, max_load, registered_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT(id) DO UPDATE SET
		name = excluded.name,
		capabilities = excluded.capabilities,
		status = excluded.status,
		last_heartbeat = excluded.last_heartbeat,
		current_load = excluded.current_load,
		max_load = excluded.max_load`,
		agent.ID, agent.Name, string(caps), string(agent.Status), hb.Format(time.RFC3339Nano),
		agent.CurrentLoad, agent.MaxLoad, hb.Format(time.RFC3339Nano))
	if err != nil {
		return &APIError{Code: 500, Message: "failed to register agent: " + err.Error()}
	}
	return nil
}
// Deregister removes an agent row by ID. Returns a 404 APIError when nothing
// was deleted and a 500 on database failures.
func (r *SQLiteRegistry) Deregister(id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	result, err := r.db.Exec("DELETE FROM agents WHERE id = ?", id)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to deregister agent: " + err.Error()}
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return &APIError{Code: 500, Message: "failed to check delete result: " + err.Error()}
	}
	if affected == 0 {
		return &APIError{Code: 404, Message: "agent not found: " + id}
	}
	return nil
}
// Get returns a copy of the agent info for the given ID. Returns an error if
// the agent is not found. Row mapping is delegated to scanAgent.
func (r *SQLiteRegistry) Get(id string) (AgentInfo, error) {
	return r.scanAgent("SELECT id, name, capabilities, status, last_heartbeat, current_load, max_load FROM agents WHERE id = ?", id)
}
// List returns a copy of all registered agents. It materialises the All
// iterator; order is unspecified.
func (r *SQLiteRegistry) List() []AgentInfo {
	return slices.Collect(r.All())
}
// All returns an iterator over all registered agents. Rows that fail to scan
// or decode are skipped; a failed query yields an empty sequence.
func (r *SQLiteRegistry) All() iter.Seq[AgentInfo] {
	return func(yield func(AgentInfo) bool) {
		const q = "SELECT id, name, capabilities, status, last_heartbeat, current_load, max_load FROM agents"
		rows, err := r.db.Query(q)
		if err != nil {
			return
		}
		defer rows.Close()
		for rows.Next() {
			agent, scanErr := r.scanAgentRow(rows)
			if scanErr != nil {
				continue
			}
			if !yield(agent) {
				return
			}
		}
	}
}
// Heartbeat refreshes the agent's last_heartbeat and, when the agent is
// currently Offline, flips it back to Available — all in a single UPDATE.
// An unknown ID yields a 404 APIError.
func (r *SQLiteRegistry) Heartbeat(id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	stamp := time.Now().UTC().Format(time.RFC3339Nano)
	// Update heartbeat for all agents, and transition offline agents to available.
	result, err := r.db.Exec(`UPDATE agents SET
		last_heartbeat = ?,
		status = CASE WHEN status = ? THEN ? ELSE status END
		WHERE id = ?`,
		stamp, string(AgentOffline), string(AgentAvailable), id)
	if err != nil {
		return &APIError{Code: 500, Message: "failed to heartbeat agent: " + err.Error()}
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return &APIError{Code: 500, Message: "failed to check heartbeat result: " + err.Error()}
	}
	if affected == 0 {
		return &APIError{Code: 404, Message: "agent not found: " + id}
	}
	return nil
}
// Reap marks agents as Offline if their last heartbeat is older than ttl.
// Returns the IDs of agents that were reaped. It materialises the Reaped
// iterator; nil when nothing was reaped.
func (r *SQLiteRegistry) Reap(ttl time.Duration) []string {
	return slices.Collect(r.Reaped(ttl))
}
// Reaped returns an iterator over the IDs of agents that were reaped: those
// not already Offline whose last heartbeat is older than ttl. Matching agents
// are flipped to Offline in one UPDATE before their IDs are yielded.
func (r *SQLiteRegistry) Reaped(ttl time.Duration) iter.Seq[string] {
	return func(yield func(string) bool) {
		r.mu.Lock()
		defer r.mu.Unlock()
		cutoff := time.Now().UTC().Add(-ttl).Format(time.RFC3339Nano)
		// Collect stale IDs first so the SELECT's rows are fully closed
		// before the UPDATE runs (the pool is capped at one connection).
		reaped, err := r.staleAgentIDs(cutoff)
		if err != nil || len(reaped) == 0 {
			return
		}
		// Build placeholders for the IN clause.
		placeholders := make([]string, len(reaped))
		args := make([]any, 0, len(reaped)+1)
		args = append(args, string(AgentOffline))
		for i, id := range reaped {
			placeholders[i] = "?"
			args = append(args, id)
		}
		query := "UPDATE agents SET status = ? WHERE id IN (" + strings.Join(placeholders, ",") + ")"
		// Best-effort status write, matching previous behaviour: reaped IDs
		// are still yielded even if the UPDATE fails.
		_, _ = r.db.Exec(query, args...)
		for _, id := range reaped {
			if !yield(id) {
				return
			}
		}
	}
}

// staleAgentIDs returns the IDs of non-offline agents whose last_heartbeat
// predates cutoff (an RFC3339Nano timestamp). Caller must hold r.mu.
func (r *SQLiteRegistry) staleAgentIDs(cutoff string) ([]string, error) {
	rows, err := r.db.Query(
		"SELECT id FROM agents WHERE status != ? AND last_heartbeat < ?",
		string(AgentOffline), cutoff)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var ids []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			continue
		}
		ids = append(ids, id)
	}
	return ids, rows.Err()
}
// --- internal helpers ---

// scanAgent executes a query expected to return at most one agent row and maps
// it onto an AgentInfo. The first query argument is assumed to be the agent ID
// and is echoed in the not-found message when it is a string.
func (r *SQLiteRegistry) scanAgent(query string, args ...any) (AgentInfo, error) {
	row := r.db.QueryRow(query, args...)
	var a AgentInfo
	var capsJSON string
	var statusStr string
	var hbStr string
	err := row.Scan(&a.ID, &a.Name, &capsJSON, &statusStr, &hbStr, &a.CurrentLoad, &a.MaxLoad)
	if errors.Is(err, sql.ErrNoRows) {
		// Resolve the ID defensively: avoid panicking if the helper is ever
		// called without a leading string argument.
		id := ""
		if len(args) > 0 {
			if s, ok := args[0].(string); ok {
				id = s
			}
		}
		return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + id}
	}
	if err != nil {
		return AgentInfo{}, &APIError{Code: 500, Message: "failed to scan agent: " + err.Error()}
	}
	if err := json.Unmarshal([]byte(capsJSON), &a.Capabilities); err != nil {
		return AgentInfo{}, &APIError{Code: 500, Message: "failed to unmarshal capabilities: " + err.Error()}
	}
	a.Status = AgentStatus(statusStr)
	a.LastHeartbeat, _ = time.Parse(time.RFC3339Nano, hbStr)
	return a, nil
}
// scanAgentRow maps the current row of a rows iterator onto an AgentInfo,
// decoding the capabilities JSON and the RFC3339Nano heartbeat column.
func (r *SQLiteRegistry) scanAgentRow(rows *sql.Rows) (AgentInfo, error) {
	var (
		agent     AgentInfo
		capsJSON  string
		statusStr string
		hbStr     string
	)
	if err := rows.Scan(&agent.ID, &agent.Name, &capsJSON, &statusStr, &hbStr, &agent.CurrentLoad, &agent.MaxLoad); err != nil {
		return AgentInfo{}, err
	}
	if err := json.Unmarshal([]byte(capsJSON), &agent.Capabilities); err != nil {
		return AgentInfo{}, err
	}
	agent.Status = AgentStatus(statusStr)
	// A parse failure leaves the zero time, matching prior behaviour.
	agent.LastHeartbeat, _ = time.Parse(time.RFC3339Nano, hbStr)
	return agent, nil
}

View file

@ -1,386 +0,0 @@
package lifecycle
import (
"path/filepath"
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestSQLiteRegistry creates a SQLiteRegistry backed by :memory: for testing.
// The database is closed automatically when the test ends.
func newTestSQLiteRegistry(t *testing.T) *SQLiteRegistry {
	t.Helper()
	reg, err := NewSQLiteRegistry(":memory:")
	require.NoError(t, err)
	t.Cleanup(func() { _ = reg.Close() })
	return reg
}
// --- Register tests ---

// Register round-trips all fields through SQLite and back via Get.
func TestSQLiteRegistry_Register_Good(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	err := reg.Register(AgentInfo{
		ID:           "agent-1",
		Name:         "Test Agent",
		Capabilities: []string{"go", "testing"},
		Status:       AgentAvailable,
		MaxLoad:      5,
	})
	require.NoError(t, err)
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "agent-1", got.ID)
	assert.Equal(t, "Test Agent", got.Name)
	assert.Equal(t, []string{"go", "testing"}, got.Capabilities)
	assert.Equal(t, AgentAvailable, got.Status)
	assert.Equal(t, 5, got.MaxLoad)
}

// Registering an existing ID overwrites the stored row.
func TestSQLiteRegistry_Register_Good_Overwrite(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3})
	err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10})
	require.NoError(t, err)
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "Updated", got.Name)
	assert.Equal(t, 10, got.MaxLoad)
}

// An empty agent ID is rejected with a validation error.
func TestSQLiteRegistry_Register_Bad_EmptyID(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	err := reg.Register(AgentInfo{ID: "", Name: "No ID"})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent ID is required")
}

// Nil capabilities survive the JSON round-trip without error.
func TestSQLiteRegistry_Register_Good_NilCapabilities(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	err := reg.Register(AgentInfo{
		ID:           "agent-1",
		Name:         "No Caps",
		Capabilities: nil,
		Status:       AgentAvailable,
	})
	require.NoError(t, err)
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "No Caps", got.Name)
	// nil capabilities serialised as JSON null, deserialised back to nil.
}

// --- Deregister tests ---

// Deregister removes the row; a subsequent Get fails.
func TestSQLiteRegistry_Deregister_Good(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"})
	err := reg.Deregister("agent-1")
	require.NoError(t, err)
	_, err = reg.Get("agent-1")
	require.Error(t, err)
}

// Deregistering an unknown ID reports not-found.
func TestSQLiteRegistry_Deregister_Bad_NotFound(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	err := reg.Deregister("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}

// --- Get tests ---

// Get returns the stored status, load and heartbeat (within serialisation tolerance).
func TestSQLiteRegistry_Get_Good(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	now := time.Now().UTC().Truncate(time.Microsecond)
	_ = reg.Register(AgentInfo{
		ID:            "agent-1",
		Name:          "Getter",
		Status:        AgentBusy,
		CurrentLoad:   2,
		MaxLoad:       5,
		LastHeartbeat: now,
	})
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, AgentBusy, got.Status)
	assert.Equal(t, 2, got.CurrentLoad)
	// Heartbeat stored via RFC3339Nano — allow small time difference from serialisation.
	assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond)
}

// Get on an unknown ID reports not-found.
func TestSQLiteRegistry_Get_Bad_NotFound(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	_, err := reg.Get("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}

// Mutating the struct returned by Get must not affect the stored row.
func TestSQLiteRegistry_Get_Good_ReturnsCopy(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1})
	got, _ := reg.Get("agent-1")
	got.CurrentLoad = 99
	got.Name = "Tampered"
	// Re-read — should be unchanged.
	again, _ := reg.Get("agent-1")
	assert.Equal(t, "Original", again.Name)
	assert.Equal(t, 1, again.CurrentLoad)
}
// --- List tests ---

// List on an empty registry is empty.
func TestSQLiteRegistry_List_Good_Empty(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	agents := reg.List()
	assert.Empty(t, agents)
}

// List returns every registered agent, order unspecified.
func TestSQLiteRegistry_List_Good_Multiple(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	_ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"})
	_ = reg.Register(AgentInfo{ID: "b", Name: "Beta"})
	_ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"})
	agents := reg.List()
	assert.Len(t, agents, 3)
	// Sort by ID for deterministic assertion.
	sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID })
	assert.Equal(t, "a", agents[0].ID)
	assert.Equal(t, "b", agents[1].ID)
	assert.Equal(t, "c", agents[2].ID)
}

// --- Heartbeat tests ---

// Heartbeat advances LastHeartbeat and leaves an Available agent Available.
func TestSQLiteRegistry_Heartbeat_Good(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	past := time.Now().UTC().Add(-5 * time.Minute)
	_ = reg.Register(AgentInfo{
		ID:            "agent-1",
		Status:        AgentAvailable,
		LastHeartbeat: past,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.True(t, got.LastHeartbeat.After(past))
	assert.Equal(t, AgentAvailable, got.Status)
}

// Heartbeat revives an Offline agent to Available.
func TestSQLiteRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	_ = reg.Register(AgentInfo{
		ID:     "agent-1",
		Status: AgentOffline,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.Equal(t, AgentAvailable, got.Status)
}

// Heartbeat does not change a Busy agent's status.
func TestSQLiteRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	_ = reg.Register(AgentInfo{
		ID:     "agent-1",
		Status: AgentBusy,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.Equal(t, AgentBusy, got.Status)
}

// Heartbeat on an unknown ID reports not-found.
func TestSQLiteRegistry_Heartbeat_Bad_NotFound(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	err := reg.Heartbeat("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}
// --- Reap tests ---

// A stale heartbeat gets the agent marked offline; a fresh one is untouched.
func TestSQLiteRegistry_Reap_Good_StaleAgent(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	stale := time.Now().UTC().Add(-10 * time.Minute)
	fresh := time.Now().UTC()
	_ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale})
	_ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh})
	reaped := reg.Reap(5 * time.Minute)
	assert.Len(t, reaped, 1)
	assert.Contains(t, reaped, "stale-1")
	got, _ := reg.Get("stale-1")
	assert.Equal(t, AgentOffline, got.Status)
	got, _ = reg.Get("fresh-1")
	assert.Equal(t, AgentAvailable, got.Status)
}

// Agents that are already offline are not reported again by Reap.
func TestSQLiteRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	stale := time.Now().UTC().Add(-10 * time.Minute)
	_ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale})
	reaped := reg.Reap(5 * time.Minute)
	assert.Empty(t, reaped)
}

// Nothing is reaped while every heartbeat is within the threshold.
func TestSQLiteRegistry_Reap_Good_NoStaleAgents(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	now := time.Now().UTC()
	_ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now})
	_ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now})
	reaped := reg.Reap(5 * time.Minute)
	assert.Empty(t, reaped)
}

// Busy agents are reaped too once their heartbeat goes stale.
func TestSQLiteRegistry_Reap_Good_BusyAgentReaped(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	stale := time.Now().UTC().Add(-10 * time.Minute)
	_ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale})
	reaped := reg.Reap(5 * time.Minute)
	assert.Len(t, reaped, 1)
	assert.Contains(t, reaped, "busy-stale")
	got, _ := reg.Get("busy-stale")
	assert.Equal(t, AgentOffline, got.Status)
}
// --- Concurrent access ---

// Hammers the registry from 20 goroutines sharing 5 agent IDs, exercising
// every operation; intended to be run under -race to catch data races.
func TestSQLiteRegistry_Concurrent_Good(t *testing.T) {
	reg := newTestSQLiteRegistry(t)
	var wg sync.WaitGroup
	for i := range 20 {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			// Only 5 distinct IDs (a–e), so goroutines deliberately collide.
			id := "agent-" + string(rune('a'+n%5))
			_ = reg.Register(AgentInfo{
				ID:            id,
				Name:          "Concurrent",
				Status:        AgentAvailable,
				LastHeartbeat: time.Now().UTC(),
			})
			_, _ = reg.Get(id)
			_ = reg.Heartbeat(id)
			_ = reg.List()
			_ = reg.Reap(1 * time.Minute)
		}(i)
	}
	wg.Wait()
	// No race conditions — test passes under -race.
	agents := reg.List()
	assert.True(t, len(agents) > 0)
}
// --- Persistence: close and reopen ---

// Writes an agent, closes the database, reopens it at the same path, and
// verifies every field round-trips.
func TestSQLiteRegistry_Persistence_Good(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "registry.db")
	// Phase 1: write data
	r1, err := NewSQLiteRegistry(dbPath)
	require.NoError(t, err)
	// Truncated to microseconds — presumably the SQLite timestamp column does
	// not keep full nanosecond precision (note the WithinDuration tolerance
	// below); TODO confirm against the storage schema.
	now := time.Now().UTC().Truncate(time.Microsecond)
	_ = r1.Register(AgentInfo{
		ID:            "agent-1",
		Name:          "Persistent",
		Capabilities:  []string{"go", "rust"},
		Status:        AgentBusy,
		LastHeartbeat: now,
		CurrentLoad:   3,
		MaxLoad:       10,
	})
	require.NoError(t, r1.Close())
	// Phase 2: reopen and verify
	r2, err := NewSQLiteRegistry(dbPath)
	require.NoError(t, err)
	defer func() { _ = r2.Close() }()
	got, err := r2.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "Persistent", got.Name)
	assert.Equal(t, []string{"go", "rust"}, got.Capabilities)
	assert.Equal(t, AgentBusy, got.Status)
	assert.Equal(t, 3, got.CurrentLoad)
	assert.Equal(t, 10, got.MaxLoad)
	assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond)
}
// --- Constructor error case ---

// The constructor must report an error when the database path sits inside a
// directory chain that does not exist.
func TestNewSQLiteRegistry_Bad_InvalidPath(t *testing.T) {
	const badPath = "/nonexistent/deeply/nested/dir/registry.db"
	if _, err := NewSQLiteRegistry(badPath); err == nil {
		t.Fatal("expected error for invalid registry path, got nil")
	}
}
// --- Config-based factory ---

// Backend "memory" yields a MemoryRegistry.
func TestNewAgentRegistryFromConfig_Good_Memory(t *testing.T) {
	cfg := RegistryConfig{RegistryBackend: "memory"}
	reg, err := NewAgentRegistryFromConfig(cfg)
	require.NoError(t, err)
	_, ok := reg.(*MemoryRegistry)
	assert.True(t, ok, "expected MemoryRegistry")
}

// An empty backend string defaults to the in-memory registry.
func TestNewAgentRegistryFromConfig_Good_Default(t *testing.T) {
	cfg := RegistryConfig{} // empty defaults to memory
	reg, err := NewAgentRegistryFromConfig(cfg)
	require.NoError(t, err)
	_, ok := reg.(*MemoryRegistry)
	assert.True(t, ok, "expected MemoryRegistry for empty config")
}

// Backend "sqlite" plus a path yields a SQLiteRegistry.
func TestNewAgentRegistryFromConfig_Good_SQLite(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "factory-registry.db")
	cfg := RegistryConfig{
		RegistryBackend: "sqlite",
		RegistryPath:    dbPath,
	}
	reg, err := NewAgentRegistryFromConfig(cfg)
	require.NoError(t, err)
	sr, ok := reg.(*SQLiteRegistry)
	assert.True(t, ok, "expected SQLiteRegistry")
	_ = sr.Close()
}

// An unrecognised backend name is rejected with a descriptive error.
func TestNewAgentRegistryFromConfig_Bad_UnknownBackend(t *testing.T) {
	cfg := RegistryConfig{RegistryBackend: "cassandra"}
	_, err := NewAgentRegistryFromConfig(cfg)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unsupported registry backend")
}

View file

@ -1,298 +0,0 @@
package lifecycle
import (
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Register tests ---

// A registered agent is retrievable with all fields intact.
func TestMemoryRegistry_Register_Good(t *testing.T) {
	reg := NewMemoryRegistry()
	err := reg.Register(AgentInfo{
		ID:           "agent-1",
		Name:         "Test Agent",
		Capabilities: []string{"go", "testing"},
		Status:       AgentAvailable,
		MaxLoad:      5,
	})
	require.NoError(t, err)
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "agent-1", got.ID)
	assert.Equal(t, "Test Agent", got.Name)
	assert.Equal(t, []string{"go", "testing"}, got.Capabilities)
	assert.Equal(t, AgentAvailable, got.Status)
	assert.Equal(t, 5, got.MaxLoad)
}

// Re-registering the same ID overwrites the previous record.
func TestMemoryRegistry_Register_Good_Overwrite(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3})
	err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10})
	require.NoError(t, err)
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, "Updated", got.Name)
	assert.Equal(t, 10, got.MaxLoad)
}

// An empty agent ID is rejected.
func TestMemoryRegistry_Register_Bad_EmptyID(t *testing.T) {
	reg := NewMemoryRegistry()
	err := reg.Register(AgentInfo{ID: "", Name: "No ID"})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent ID is required")
}

// Mutating the caller's struct after Register must not change the stored copy.
func TestMemoryRegistry_Register_Good_CopySemantics(t *testing.T) {
	reg := NewMemoryRegistry()
	agent := AgentInfo{
		ID:           "agent-1",
		Name:         "Copy Test",
		Capabilities: []string{"go"},
		Status:       AgentAvailable,
	}
	_ = reg.Register(agent)
	// Mutate the original — should not affect the stored copy.
	agent.Name = "Mutated"
	agent.Capabilities[0] = "rust"
	got, _ := reg.Get("agent-1")
	assert.Equal(t, "Copy Test", got.Name)
	// Note: slice header is copied, but underlying array is shared.
	// This is consistent with the MemoryStore pattern in allowance.go.
}
// --- Deregister tests ---

// Deregistering removes the agent; a later Get fails.
func TestMemoryRegistry_Deregister_Good(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"})
	err := reg.Deregister("agent-1")
	require.NoError(t, err)
	_, err = reg.Get("agent-1")
	require.Error(t, err)
}

// Deregistering an unknown ID surfaces a "not found" error.
func TestMemoryRegistry_Deregister_Bad_NotFound(t *testing.T) {
	reg := NewMemoryRegistry()
	err := reg.Deregister("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}

// --- Get tests ---

// Get returns the full stored record, including load and heartbeat fields.
func TestMemoryRegistry_Get_Good(t *testing.T) {
	reg := NewMemoryRegistry()
	now := time.Now().UTC()
	_ = reg.Register(AgentInfo{
		ID:            "agent-1",
		Name:          "Getter",
		Status:        AgentBusy,
		CurrentLoad:   2,
		MaxLoad:       5,
		LastHeartbeat: now,
	})
	got, err := reg.Get("agent-1")
	require.NoError(t, err)
	assert.Equal(t, AgentBusy, got.Status)
	assert.Equal(t, 2, got.CurrentLoad)
	assert.Equal(t, now, got.LastHeartbeat)
}

// Get on an unknown ID surfaces a "not found" error.
func TestMemoryRegistry_Get_Bad_NotFound(t *testing.T) {
	reg := NewMemoryRegistry()
	_, err := reg.Get("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}

// Mutating the struct returned by Get must not affect the stored record.
func TestMemoryRegistry_Get_Good_ReturnsCopy(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1})
	got, _ := reg.Get("agent-1")
	got.CurrentLoad = 99
	got.Name = "Tampered"
	// Re-read — should be unchanged.
	again, _ := reg.Get("agent-1")
	assert.Equal(t, "Original", again.Name)
	assert.Equal(t, 1, again.CurrentLoad)
}
// --- List tests ---

// A fresh registry lists no agents.
func TestMemoryRegistry_List_Good_Empty(t *testing.T) {
	reg := NewMemoryRegistry()
	agents := reg.List()
	assert.Empty(t, agents)
}

// List returns every registered agent (order is not guaranteed).
func TestMemoryRegistry_List_Good_Multiple(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"})
	_ = reg.Register(AgentInfo{ID: "b", Name: "Beta"})
	_ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"})
	agents := reg.List()
	assert.Len(t, agents, 3)
	// Sort by ID for deterministic assertion.
	sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID })
	assert.Equal(t, "a", agents[0].ID)
	assert.Equal(t, "b", agents[1].ID)
	assert.Equal(t, "c", agents[2].ID)
}
// --- Heartbeat tests ---

// Heartbeat advances LastHeartbeat and leaves an available agent available.
func TestMemoryRegistry_Heartbeat_Good(t *testing.T) {
	reg := NewMemoryRegistry()
	past := time.Now().UTC().Add(-5 * time.Minute)
	_ = reg.Register(AgentInfo{
		ID:            "agent-1",
		Status:        AgentAvailable,
		LastHeartbeat: past,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.True(t, got.LastHeartbeat.After(past))
	assert.Equal(t, AgentAvailable, got.Status)
}

// A heartbeat flips an offline agent back to available.
func TestMemoryRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{
		ID:     "agent-1",
		Status: AgentOffline,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.Equal(t, AgentAvailable, got.Status)
}

// A heartbeat must not demote a busy agent to available.
func TestMemoryRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{
		ID:     "agent-1",
		Status: AgentBusy,
	})
	err := reg.Heartbeat("agent-1")
	require.NoError(t, err)
	got, _ := reg.Get("agent-1")
	assert.Equal(t, AgentBusy, got.Status)
}

// Heartbeating an unknown agent ID surfaces a "not found" error.
func TestMemoryRegistry_Heartbeat_Bad_NotFound(t *testing.T) {
	reg := NewMemoryRegistry()
	err := reg.Heartbeat("nonexistent")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "agent not found")
}
// --- Reap tests ---

// A stale heartbeat gets the agent marked offline; a fresh one is untouched.
func TestMemoryRegistry_Reap_Good_StaleAgent(t *testing.T) {
	reg := NewMemoryRegistry()
	stale := time.Now().UTC().Add(-10 * time.Minute)
	fresh := time.Now().UTC()
	_ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale})
	_ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh})
	reaped := reg.Reap(5 * time.Minute)
	assert.Len(t, reaped, 1)
	assert.Contains(t, reaped, "stale-1")
	got, _ := reg.Get("stale-1")
	assert.Equal(t, AgentOffline, got.Status)
	got, _ = reg.Get("fresh-1")
	assert.Equal(t, AgentAvailable, got.Status)
}

// Agents that are already offline are not reported again by Reap.
func TestMemoryRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) {
	reg := NewMemoryRegistry()
	stale := time.Now().UTC().Add(-10 * time.Minute)
	_ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale})
	reaped := reg.Reap(5 * time.Minute)
	assert.Empty(t, reaped)
}

// Nothing is reaped while every heartbeat is within the threshold.
func TestMemoryRegistry_Reap_Good_NoStaleAgents(t *testing.T) {
	reg := NewMemoryRegistry()
	now := time.Now().UTC()
	_ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now})
	_ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now})
	reaped := reg.Reap(5 * time.Minute)
	assert.Empty(t, reaped)
}

// Busy agents are reaped too once their heartbeat goes stale.
func TestMemoryRegistry_Reap_Good_BusyAgentReaped(t *testing.T) {
	reg := NewMemoryRegistry()
	stale := time.Now().UTC().Add(-10 * time.Minute)
	_ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale})
	reaped := reg.Reap(5 * time.Minute)
	assert.Len(t, reaped, 1)
	assert.Contains(t, reaped, "busy-stale")
	got, _ := reg.Get("busy-stale")
	assert.Equal(t, AgentOffline, got.Status)
}
// --- Concurrent access ---

// Hammers the registry from 20 goroutines sharing 5 agent IDs, exercising
// every operation; intended to be run under -race to catch data races.
func TestMemoryRegistry_Concurrent_Good(t *testing.T) {
	reg := NewMemoryRegistry()
	var wg sync.WaitGroup
	for i := range 20 {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			// Only 5 distinct IDs (a–e), so goroutines deliberately collide.
			id := "agent-" + string(rune('a'+n%5))
			_ = reg.Register(AgentInfo{
				ID:            id,
				Name:          "Concurrent",
				Status:        AgentAvailable,
				LastHeartbeat: time.Now().UTC(),
			})
			_, _ = reg.Get(id)
			_ = reg.Heartbeat(id)
			_ = reg.List()
			_ = reg.Reap(1 * time.Minute)
		}(i)
	}
	wg.Wait()
	// No race conditions — test passes under -race.
	agents := reg.List()
	assert.True(t, len(agents) > 0)
}

View file

@ -1,131 +0,0 @@
package lifecycle
import (
"cmp"
"slices"
coreerr "forge.lthn.ai/core/go-log"
)
// ErrNoEligibleAgent is returned when no agent matches the task requirements.
var ErrNoEligibleAgent = coreerr.E("TaskRouter", "no eligible agent for task", nil)

// TaskRouter selects an agent for a given task from a list of candidates.
type TaskRouter interface {
	// Route picks the best agent for the task and returns its ID.
	// Returns ErrNoEligibleAgent if no agent qualifies.
	Route(task *Task, agents []AgentInfo) (string, error)
}

// DefaultRouter implements TaskRouter with capability matching and load-based
// scoring. For critical priority tasks it picks the least-loaded agent directly.
// The zero value is usable; it holds no state.
type DefaultRouter struct{}

// NewDefaultRouter creates a new DefaultRouter.
func NewDefaultRouter() *DefaultRouter {
	return &DefaultRouter{}
}
// Route selects the best agent for the task:
//  1. Filter by availability (Available, or Busy with spare capacity).
//  2. Filter by capabilities (task.Labels must be a subset of agent.Capabilities).
//  3. For critical tasks, pick the least-loaded agent.
//  4. For other tasks, score by load ratio and pick the highest-scored agent.
//  5. Ties are broken by agent ID (alphabetical) for determinism.
func (r *DefaultRouter) Route(task *Task, agents []AgentInfo) (string, error) {
	candidates := r.filterEligible(task, agents)
	switch {
	case len(candidates) == 0:
		return "", ErrNoEligibleAgent
	case task.Priority == PriorityCritical:
		return r.leastLoaded(candidates), nil
	default:
		return r.highestScored(candidates), nil
	}
}
// filterEligible returns the subset of agents that can accept work and hold
// every capability the task requires.
func (r *DefaultRouter) filterEligible(task *Task, agents []AgentInfo) []AgentInfo {
	var eligible []AgentInfo
	for _, candidate := range agents {
		if r.isAvailable(candidate) && r.hasCapabilities(candidate, task.Labels) {
			eligible = append(eligible, candidate)
		}
	}
	return eligible
}
// isAvailable reports whether the agent can accept new work.
func (r *DefaultRouter) isAvailable(a AgentInfo) bool {
	if a.Status == AgentAvailable {
		return true
	}
	if a.Status != AgentBusy {
		return false
	}
	// Busy agents may still take work while under capacity;
	// MaxLoad == 0 means unlimited.
	return a.MaxLoad == 0 || a.CurrentLoad < a.MaxLoad
}
// hasCapabilities reports whether the agent carries every required label.
// A task with no labels matches any agent.
func (r *DefaultRouter) hasCapabilities(a AgentInfo, labels []string) bool {
	missing := func(label string) bool {
		return !slices.Contains(a.Capabilities, label)
	}
	// Eligible iff no required label is missing (vacuously true for no labels).
	return !slices.ContainsFunc(labels, missing)
}
// leastLoaded returns the ID of the agent with the lowest CurrentLoad.
// Ties are broken by agent ID (alphabetical) for determinism.
func (r *DefaultRouter) leastLoaded(agents []AgentInfo) string {
	best := slices.MinFunc(agents, func(a, b AgentInfo) int {
		if c := cmp.Compare(a.CurrentLoad, b.CurrentLoad); c != 0 {
			return c
		}
		return cmp.Compare(a.ID, b.ID)
	})
	return best.ID
}
// highestScored returns the ID of the agent with the highest availability
// score, where score = 1.0 - CurrentLoad/MaxLoad (1.0 when MaxLoad is 0,
// i.e. unlimited). Ties are broken by agent ID (alphabetical).
func (r *DefaultRouter) highestScored(agents []AgentInfo) string {
	score := func(a AgentInfo) float64 {
		if a.MaxLoad == 0 {
			return 1.0
		}
		return 1.0 - float64(a.CurrentLoad)/float64(a.MaxLoad)
	}
	// Single pass: keep the best (highest score, then lowest ID) seen so far.
	bestID := agents[0].ID
	bestScore := score(agents[0])
	for _, a := range agents[1:] {
		if s := score(a); s > bestScore || (s == bestScore && a.ID < bestID) {
			bestID, bestScore = a.ID, s
		}
	}
	return bestID
}

View file

@ -1,239 +0,0 @@
package lifecycle
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// makeAgent builds an AgentInfo test fixture whose Name mirrors its ID and
// whose heartbeat is fresh.
func makeAgent(id string, status AgentStatus, caps []string, load, maxLoad int) AgentInfo {
	a := AgentInfo{ID: id, Name: id, Capabilities: caps, Status: status}
	a.LastHeartbeat = time.Now().UTC()
	a.CurrentLoad = load
	a.MaxLoad = maxLoad
	return a
}
// --- Capability matching ---

// Route picks the only agent whose capabilities cover every task label.
func TestDefaultRouter_Route_Good_MatchesCapabilities(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Labels: []string{"go", "testing"}}
	agents := []AgentInfo{
		makeAgent("agent-a", AgentAvailable, []string{"go", "testing", "frontend"}, 0, 5),
		makeAgent("agent-b", AgentAvailable, []string{"python"}, 0, 5),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "agent-a", id)
}

// A task with nil labels matches any agent.
func TestDefaultRouter_Route_Good_NoLabelsMatchesAll(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Labels: nil}
	agents := []AgentInfo{
		makeAgent("agent-a", AgentAvailable, []string{"go"}, 0, 5),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "agent-a", id)
}

// An empty (non-nil) label slice also matches any agent.
func TestDefaultRouter_Route_Good_EmptyLabelsMatchesAll(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Labels: []string{}}
	agents := []AgentInfo{
		makeAgent("agent-a", AgentAvailable, nil, 0, 5),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "agent-a", id)
}
// --- Availability filtering ---

// Offline agents are skipped in favour of available ones.
func TestDefaultRouter_Route_Good_SkipsOfflineAgents(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1"}
	agents := []AgentInfo{
		makeAgent("offline-1", AgentOffline, nil, 0, 5),
		makeAgent("online-1", AgentAvailable, nil, 0, 5),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "online-1", id)
}

// A busy agent below its MaxLoad can still be routed to.
func TestDefaultRouter_Route_Good_BusyWithCapacity(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1"}
	agents := []AgentInfo{
		makeAgent("busy-1", AgentBusy, nil, 2, 5), // has capacity
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "busy-1", id)
}

// MaxLoad 0 means unlimited capacity, so a busy agent always qualifies.
func TestDefaultRouter_Route_Good_BusyUnlimited(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1"}
	agents := []AgentInfo{
		makeAgent("busy-unlimited", AgentBusy, nil, 10, 0), // MaxLoad 0 = unlimited
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "busy-unlimited", id)
}

// A busy agent at full capacity is not eligible.
func TestDefaultRouter_Route_Bad_BusyAtCapacity(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1"}
	agents := []AgentInfo{
		makeAgent("full-1", AgentBusy, nil, 5, 5), // at capacity
	}
	_, err := router.Route(task, agents)
	require.ErrorIs(t, err, ErrNoEligibleAgent)
}

// An empty candidate list yields ErrNoEligibleAgent.
func TestDefaultRouter_Route_Bad_NoAgents(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1"}
	_, err := router.Route(task, nil)
	require.ErrorIs(t, err, ErrNoEligibleAgent)
}

// No agent carrying the required label yields ErrNoEligibleAgent.
func TestDefaultRouter_Route_Bad_NoCapableAgent(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Labels: []string{"rust"}}
	agents := []AgentInfo{
		makeAgent("go-agent", AgentAvailable, []string{"go"}, 0, 5),
		makeAgent("py-agent", AgentAvailable, []string{"python"}, 0, 5),
	}
	_, err := router.Route(task, agents)
	require.ErrorIs(t, err, ErrNoEligibleAgent)
}

// All agents offline yields ErrNoEligibleAgent.
func TestDefaultRouter_Route_Bad_AllOffline(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1"}
	agents := []AgentInfo{
		makeAgent("off-1", AgentOffline, nil, 0, 5),
		makeAgent("off-2", AgentOffline, nil, 0, 5),
	}
	_, err := router.Route(task, agents)
	require.ErrorIs(t, err, ErrNoEligibleAgent)
}
// --- Load balancing ---

// Non-critical routing picks the agent with the best load ratio.
func TestDefaultRouter_Route_Good_LeastLoaded(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Priority: PriorityMedium}
	agents := []AgentInfo{
		makeAgent("agent-a", AgentAvailable, nil, 3, 10),
		makeAgent("agent-b", AgentAvailable, nil, 1, 10),
		makeAgent("agent-c", AgentAvailable, nil, 5, 10),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	// agent-b has score 0.9, agent-a has 0.7, agent-c has 0.5
	assert.Equal(t, "agent-b", id)
}

// MaxLoad 0 (unlimited) scores a perfect 1.0, beating any limited agent.
func TestDefaultRouter_Route_Good_UnlimitedGetsMaxScore(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Priority: PriorityLow}
	agents := []AgentInfo{
		makeAgent("limited", AgentAvailable, nil, 1, 10), // score 0.9
		makeAgent("unlimited", AgentAvailable, nil, 5, 0), // score 1.0
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "unlimited", id)
}
// --- Critical priority ---

// Critical tasks route by absolute CurrentLoad, not by load ratio.
func TestDefaultRouter_Route_Good_CriticalPicksLeastLoaded(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Priority: PriorityCritical}
	agents := []AgentInfo{
		makeAgent("agent-a", AgentAvailable, nil, 4, 10),
		makeAgent("agent-b", AgentAvailable, nil, 1, 5), // lowest absolute load
		makeAgent("agent-c", AgentAvailable, nil, 2, 10),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	// Critical: picks least loaded by CurrentLoad, not by ratio.
	assert.Equal(t, "agent-b", id)
}

// --- Tie-breaking ---

// Equal scores fall back to alphabetical agent ID.
func TestDefaultRouter_Route_Good_TieBreakByID(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Priority: PriorityMedium}
	agents := []AgentInfo{
		makeAgent("charlie", AgentAvailable, nil, 0, 5),
		makeAgent("alpha", AgentAvailable, nil, 0, 5),
		makeAgent("bravo", AgentAvailable, nil, 0, 5),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "alpha", id)
}

// The critical path tie-breaks by ID as well.
func TestDefaultRouter_Route_Good_CriticalTieBreakByID(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Priority: PriorityCritical}
	agents := []AgentInfo{
		makeAgent("charlie", AgentAvailable, nil, 0, 5),
		makeAgent("alpha", AgentAvailable, nil, 0, 5),
		makeAgent("bravo", AgentAvailable, nil, 0, 5),
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	assert.Equal(t, "alpha", id)
}
// --- Mixed scenarios ---

// Filtering and scoring combined: offline, wrong-capability and at-capacity
// agents are excluded; the best-scored remaining agent wins.
func TestDefaultRouter_Route_Good_MixedStatusAndCapabilities(t *testing.T) {
	router := NewDefaultRouter()
	task := &Task{ID: "t1", Labels: []string{"go"}, Priority: PriorityHigh}
	agents := []AgentInfo{
		makeAgent("offline-go", AgentOffline, []string{"go"}, 0, 5),
		makeAgent("busy-py", AgentBusy, []string{"python"}, 1, 5),
		makeAgent("busy-go-full", AgentBusy, []string{"go"}, 5, 5), // at capacity
		makeAgent("busy-go-room", AgentBusy, []string{"go"}, 2, 5), // has room
		makeAgent("avail-go", AgentAvailable, []string{"go"}, 1, 5), // available
	}
	id, err := router.Route(task, agents)
	require.NoError(t, err)
	// avail-go: score = 1.0 - 1/5 = 0.8
	// busy-go-room: score = 1.0 - 2/5 = 0.6
	assert.Equal(t, "avail-go", id)
}

View file

@ -1,147 +0,0 @@
package lifecycle
import (
"bytes"
"context"
"encoding/json"
"net/http"
"forge.lthn.ai/core/go-log"
)
// ScoreContentRequest is the payload for content scoring.
type ScoreContentRequest struct {
	// Text is the content to score. Required by ScoreContent.
	Text string `json:"text"`
	// Prompt optionally steers the scoring; forwarded verbatim when set.
	Prompt string `json:"prompt,omitempty"`
}

// ScoreImprintRequest is the payload for linguistic imprint analysis.
type ScoreImprintRequest struct {
	// Text is the content to analyse. Required by ScoreImprint.
	Text string `json:"text"`
}

// ScoreResult holds the response from the scoring engine.
// The shape is proxied from the EaaS Go binary, so fields are dynamic.
type ScoreResult map[string]any

// ScoreHealthResponse holds the health check result.
type ScoreHealthResponse struct {
	// Status is the engine's health verdict (e.g. "healthy").
	Status string `json:"status"`
	// UpstreamStatus carries the upstream HTTP status when present.
	UpstreamStatus int `json:"upstream_status,omitempty"`
}
// ScoreContent scores text for AI patterns via POST /v1/score/content.
//
// req.Text is required; req.Prompt is optional. Returns the decoded engine
// response, whose shape is proxied from the EaaS binary and therefore dynamic.
func (c *Client) ScoreContent(ctx context.Context, req ScoreContentRequest) (ScoreResult, error) {
	const op = "agentic.Client.ScoreContent"
	if req.Text == "" {
		return nil, log.E(op, "text is required", nil)
	}
	return c.postScore(ctx, op, "/v1/score/content", req)
}

// ScoreImprint performs linguistic imprint analysis via POST /v1/score/imprint.
//
// req.Text is required. Returns the decoded engine response (dynamic shape).
func (c *Client) ScoreImprint(ctx context.Context, req ScoreImprintRequest) (ScoreResult, error) {
	const op = "agentic.Client.ScoreImprint"
	if req.Text == "" {
		return nil, log.E(op, "text is required", nil)
	}
	return c.postScore(ctx, op, "/v1/score/imprint", req)
}

// postScore is the shared POST helper for the scoring endpoints: it marshals
// payload, sends it to c.BaseURL+path with the client's auth headers and a
// JSON content type, and decodes the response into a ScoreResult. op is the
// caller's operation name so error attribution matches the public method.
func (c *Client) postScore(ctx context.Context, op, path string, payload any) (ScoreResult, error) {
	data, err := json.Marshal(payload)
	if err != nil {
		return nil, log.E(op, "failed to marshal request", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+path, bytes.NewReader(data))
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(httpReq)
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var result ScoreResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return result, nil
}
// ScoreHealth checks the scoring engine health via GET /v1/score/health.
// This endpoint does not require authentication, so only Accept and
// User-Agent headers are sent.
func (c *Client) ScoreHealth(ctx context.Context) (*ScoreHealthResponse, error) {
	const op = "agentic.Client.ScoreHealth"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseURL+"/v1/score/health", nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	// Health endpoint is unauthenticated but we still set headers for consistency.
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("User-Agent", "core-agentic-client/1.0")
	resp, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	result := new(ScoreHealthResponse)
	if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return result, nil
}

View file

@ -1,166 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Happy path: correct method/path/auth header, payload echoed, result decoded.
func TestClient_ScoreContent_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/v1/score/content", r.URL.Path)
		assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
		var req ScoreContentRequest
		err := json.NewDecoder(r.Body).Decode(&req)
		require.NoError(t, err)
		assert.Contains(t, req.Text, "sample text for scoring")
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"score":      0.23,
			"confidence": 0.91,
			"label":      "human",
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.ScoreContent(context.Background(), ScoreContentRequest{
		Text: "This is some sample text for scoring that is at least twenty characters",
	})
	require.NoError(t, err)
	assert.InDelta(t, 0.23, result["score"], 0.001)
	assert.Equal(t, "human", result["label"])
}

// The optional Prompt field is forwarded verbatim in the request body.
func TestClient_ScoreContent_Good_WithPrompt(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req ScoreContentRequest
		err := json.NewDecoder(r.Body).Decode(&req)
		require.NoError(t, err)
		assert.Equal(t, "Check for formality", req.Prompt)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{"score": 0.5})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.ScoreContent(context.Background(), ScoreContentRequest{
		Text:   "This text should be checked for formality and style patterns",
		Prompt: "Check for formality",
	})
	require.NoError(t, err)
	assert.InDelta(t, 0.5, result["score"], 0.001)
}

// Empty text is rejected client-side, before any request is made.
func TestClient_ScoreContent_Bad_EmptyText(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	result, err := client.ScoreContent(context.Background(), ScoreContentRequest{})
	assert.Error(t, err)
	assert.Nil(t, result)
	assert.Contains(t, err.Error(), "text is required")
}

// A 502 from the upstream surfaces as an error with a nil result.
func TestClient_ScoreContent_Bad_ServiceDown(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadGateway)
		_ = json.NewEncoder(w).Encode(map[string]any{
			"error":   "scoring_unavailable",
			"message": "Could not reach the scoring service.",
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.ScoreContent(context.Background(), ScoreContentRequest{
		Text: "This text needs at least twenty characters to validate",
	})
	assert.Error(t, err)
	assert.Nil(t, result)
}
// Happy path for imprint analysis: correct method/path, result decoded.
func TestClient_ScoreImprint_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/v1/score/imprint", r.URL.Path)
		var req ScoreImprintRequest
		err := json.NewDecoder(r.Body).Decode(&req)
		require.NoError(t, err)
		assert.NotEmpty(t, req.Text)
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"imprint":    "abc123def456",
			"confidence": 0.88,
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.ScoreImprint(context.Background(), ScoreImprintRequest{
		Text: "This text has a distinct linguistic imprint pattern to analyse",
	})
	require.NoError(t, err)
	assert.Equal(t, "abc123def456", result["imprint"])
}

// Empty text is rejected client-side, before any request is made.
func TestClient_ScoreImprint_Bad_EmptyText(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	result, err := client.ScoreImprint(context.Background(), ScoreImprintRequest{})
	assert.Error(t, err)
	assert.Nil(t, result)
	assert.Contains(t, err.Error(), "text is required")
}

// Health check: GET without an Authorization header, status decoded.
func TestClient_ScoreHealth_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodGet, r.Method)
		assert.Equal(t, "/v1/score/health", r.URL.Path)
		// Health endpoint should not require auth token
		assert.Empty(t, r.Header.Get("Authorization"))
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(ScoreHealthResponse{
			Status: "healthy",
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.ScoreHealth(context.Background())
	require.NoError(t, err)
	assert.Equal(t, "healthy", result.Status)
}

// An unhealthy upstream (502) surfaces as an error with a nil result.
func TestClient_ScoreHealth_Bad_Unhealthy(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadGateway)
		_ = json.NewEncoder(w).Encode(ScoreHealthResponse{
			Status:         "unhealthy",
			UpstreamStatus: 503,
		})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	result, err := client.ScoreHealth(context.Background())
	assert.Error(t, err)
	assert.Nil(t, result)
}

View file

@ -1,142 +0,0 @@
package lifecycle
import (
"context"
"os"
"os/exec"
"strings"
"forge.lthn.ai/core/go/pkg/core"
"forge.lthn.ai/core/go-log"
)
// Tasks for AI service

// TaskCommit requests Claude to create a commit.
type TaskCommit struct {
	// Path is the working directory the claude CLI runs in.
	Path string
	// Name is not referenced by the handlers in this file.
	Name string
	// CanEdit additionally allows the Write/Edit tools for this run.
	CanEdit bool // allow Write/Edit tools
}

// TaskPrompt sends a custom prompt to Claude.
type TaskPrompt struct {
	// Prompt is the text passed to `claude -p`.
	Prompt string
	// WorkDir, when non-empty, is the directory the CLI runs in.
	WorkDir string
	// AllowedTools overrides the service's default tool list when non-empty.
	AllowedTools []string
	// taskID links progress reports to a Core task; set via SetTaskID.
	taskID string
}

// SetTaskID attaches a Core task ID used for progress reporting.
func (t *TaskPrompt) SetTaskID(id string) { t.taskID = id }

// GetTaskID returns the attached Core task ID (empty if none was set).
func (t *TaskPrompt) GetTaskID() string { return t.taskID }
// ServiceOptions for configuring the AI service.
type ServiceOptions struct {
	// DefaultTools is the tool allow-list used when a prompt task does not
	// supply its own AllowedTools.
	DefaultTools []string
	// AllowEdit is a global permission flag for the Write/Edit tools.
	AllowEdit bool // global permission for Write/Edit tools
}

// DefaultServiceOptions returns sensible defaults: read-only tools
// (Bash, Read, Glob, Grep) with editing disabled.
func DefaultServiceOptions() ServiceOptions {
	return ServiceOptions{
		DefaultTools: []string{"Bash", "Read", "Glob", "Grep"},
		AllowEdit:    false,
	}
}
// Service provides AI/Claude operations as a Core service.
type Service struct {
	*core.ServiceRuntime[ServiceOptions]
}

// NewService returns a factory that builds the AI service for a Core
// instance, carrying the supplied options.
func NewService(opts ServiceOptions) func(*core.Core) (any, error) {
	return func(c *core.Core) (any, error) {
		svc := &Service{ServiceRuntime: core.NewServiceRuntime(c, opts)}
		return svc, nil
	}
}
// OnStartup registers this service's task handler with the Core dispatcher.
func (s *Service) OnStartup(ctx context.Context) error {
	s.Core().RegisterTask(s.handleTask)
	return nil
}
// handleTask dispatches Core tasks to this service. The boolean result is
// true when the task type was handled here, false to let other handlers try.
func (s *Service) handleTask(c *core.Core, t core.Task) (any, bool, error) {
	switch task := t.(type) {
	case TaskCommit:
		if err := s.doCommit(task); err != nil {
			log.Error("agentic: commit task failed", "err", err, "path", task.Path)
			return nil, true, err
		}
		return nil, true, nil
	case TaskPrompt:
		if err := s.doPrompt(task); err != nil {
			log.Error("agentic: prompt task failed", "err", err)
			return nil, true, err
		}
		return nil, true, nil
	default:
		return nil, false, nil
	}
}
// doCommit runs the Claude CLI in the task's working directory with the
// canned "commit" prompt, asking it to create a commit.
//
// The tool allowlist is read-only by default; CanEdit widens it to include
// Write and Edit. Stdout/stderr/stdin are wired to the current process so
// the Claude session is interactive. Returns the CLI's exit error, if any.
//
// NOTE(review): TaskCommit.Name is not used here — confirm it is needed.
func (s *Service) doCommit(task TaskCommit) error {
	prompt := Prompt("commit")
	tools := []string{"Bash", "Read", "Glob", "Grep"}
	if task.CanEdit {
		tools = []string{"Bash", "Read", "Write", "Edit", "Glob", "Grep"}
	}
	// exec.Command rather than CommandContext: no caller-supplied context
	// reaches this method, and context.Background() can never cancel, so
	// the previous CommandContext(context.Background(), ...) added nothing.
	cmd := exec.Command("claude", "-p", prompt, "--allowedTools", strings.Join(tools, ","))
	cmd.Dir = task.Path
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin
	return cmd.Run()
}
// doPrompt runs the Claude CLI with the task's prompt, reporting coarse
// progress through Core when the task carries a task ID.
//
// Tool selection precedence: task.AllowedTools, then the service's
// DefaultTools, then a built-in read-only set. Stdout/stderr/stdin are
// wired to the current process so the Claude session is interactive.
func (s *Service) doPrompt(task TaskPrompt) error {
	if task.taskID != "" {
		s.Core().Progress(task.taskID, 0.1, "Starting Claude...", &task)
	}
	opts := s.Opts()
	tools := opts.DefaultTools
	if len(tools) == 0 {
		tools = []string{"Bash", "Read", "Glob", "Grep"}
	}
	if len(task.AllowedTools) > 0 {
		tools = task.AllowedTools
	}
	// exec.Command rather than CommandContext: no caller context is
	// available here and context.Background() can never cancel, so
	// CommandContext added nothing.
	cmd := exec.Command("claude", "-p", task.Prompt, "--allowedTools", strings.Join(tools, ","))
	if task.WorkDir != "" {
		cmd.Dir = task.WorkDir
	}
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin
	if task.taskID != "" {
		s.Core().Progress(task.taskID, 0.5, "Running Claude prompt...", &task)
	}
	err := cmd.Run()
	if task.taskID != "" {
		if err != nil {
			s.Core().Progress(task.taskID, 1.0, "Failed: "+err.Error(), &task)
		} else {
			s.Core().Progress(task.taskID, 1.0, "Completed", &task)
		}
	}
	return err
}

View file

@ -1,79 +0,0 @@
package lifecycle
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestDefaultServiceOptions_Good verifies the read-only default tool set.
func TestDefaultServiceOptions_Good(t *testing.T) {
	opts := DefaultServiceOptions()
	assert.Equal(t, []string{"Bash", "Read", "Glob", "Grep"}, opts.DefaultTools)
	assert.False(t, opts.AllowEdit, "default should not allow edit")
}

// TestTaskPrompt_SetGetTaskID verifies the task ID round-trips and can be
// overwritten.
func TestTaskPrompt_SetGetTaskID(t *testing.T) {
	tp := &TaskPrompt{
		Prompt:  "test prompt",
		WorkDir: "/tmp",
	}
	assert.Empty(t, tp.GetTaskID(), "should start empty")
	tp.SetTaskID("task-abc-123")
	assert.Equal(t, "task-abc-123", tp.GetTaskID())
	tp.SetTaskID("task-def-456")
	assert.Equal(t, "task-def-456", tp.GetTaskID(), "should allow overwriting")
}

// TestTaskPrompt_SetGetTaskID_Empty verifies setting an empty ID leaves the
// getter reporting empty.
func TestTaskPrompt_SetGetTaskID_Empty(t *testing.T) {
	tp := &TaskPrompt{}
	tp.SetTaskID("")
	assert.Empty(t, tp.GetTaskID())
}

// TestTaskCommit_Fields verifies plain field storage on TaskCommit.
func TestTaskCommit_Fields(t *testing.T) {
	tc := TaskCommit{
		Path:    "/home/user/project",
		Name:    "test-commit",
		CanEdit: true,
	}
	assert.Equal(t, "/home/user/project", tc.Path)
	assert.Equal(t, "test-commit", tc.Name)
	assert.True(t, tc.CanEdit)
}

// TestTaskCommit_DefaultCanEdit verifies the zero value denies editing.
func TestTaskCommit_DefaultCanEdit(t *testing.T) {
	tc := TaskCommit{
		Path: "/tmp",
		Name: "no-edit",
	}
	assert.False(t, tc.CanEdit, "default should be false")
}

// TestServiceOptions_CustomTools verifies custom options are stored as given.
func TestServiceOptions_CustomTools(t *testing.T) {
	opts := ServiceOptions{
		DefaultTools: []string{"Bash", "Read", "Write", "Edit"},
		AllowEdit:    true,
	}
	assert.Len(t, opts.DefaultTools, 4)
	assert.True(t, opts.AllowEdit)
}

// TestTaskPrompt_AllFields verifies all exported TaskPrompt fields round-trip.
func TestTaskPrompt_AllFields(t *testing.T) {
	tp := TaskPrompt{
		Prompt:       "Refactor the authentication module",
		WorkDir:      "/home/user/project",
		AllowedTools: []string{"Bash", "Read", "Edit"},
	}
	assert.Equal(t, "Refactor the authentication module", tp.Prompt)
	assert.Equal(t, "/home/user/project", tp.WorkDir)
	assert.Equal(t, []string{"Bash", "Read", "Edit"}, tp.AllowedTools)
}

View file

@ -1,287 +0,0 @@
package lifecycle
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"forge.lthn.ai/core/go-log"
)
// SessionStatus represents the state of a session.
type SessionStatus string

// Session lifecycle states as reported by the API.
const (
	SessionActive    SessionStatus = "active"
	SessionPaused    SessionStatus = "paused"
	SessionCompleted SessionStatus = "completed"
	SessionFailed    SessionStatus = "failed"
)
// Session represents an agent session from the PHP API. Timestamp and
// duration fields are transported as strings in the API's own format; the
// empty value means "not set".
type Session struct {
	SessionID      string         `json:"session_id"`
	AgentType      string         `json:"agent_type"`
	Status         SessionStatus  `json:"status"`
	PlanSlug       string         `json:"plan_slug,omitempty"`
	Plan           string         `json:"plan,omitempty"`
	Duration       string         `json:"duration,omitempty"`
	StartedAt      string         `json:"started_at,omitempty"`
	LastActiveAt   string         `json:"last_active_at,omitempty"`
	EndedAt        string         `json:"ended_at,omitempty"`
	ActionCount    int            `json:"action_count,omitempty"`
	ArtifactCount  int            `json:"artifact_count,omitempty"`
	ContextSummary map[string]any `json:"context_summary,omitempty"`
	HandoffNotes   string         `json:"handoff_notes,omitempty"`
	// ContinuedFrom is the ID of the session this one continued, if any.
	ContinuedFrom string `json:"continued_from,omitempty"`
}
// StartSessionRequest is the payload for starting a new session.
// AgentType is required; PlanSlug and Context are optional.
type StartSessionRequest struct {
	AgentType string         `json:"agent_type"`
	PlanSlug  string         `json:"plan_slug,omitempty"`
	Context   map[string]any `json:"context,omitempty"`
}

// EndSessionRequest is the payload for ending a session.
type EndSessionRequest struct {
	Status  string `json:"status"`
	Summary string `json:"summary,omitempty"`
}

// ListSessionOptions specifies filters for listing sessions. Zero-valued
// fields are omitted from the query.
type ListSessionOptions struct {
	Status   SessionStatus `json:"status,omitempty"`
	PlanSlug string        `json:"plan_slug,omitempty"`
	Limit    int           `json:"limit,omitempty"`
}

// sessionListResponse wraps the list endpoint response.
type sessionListResponse struct {
	Sessions []Session `json:"sessions"`
	Total    int       `json:"total"`
}

// sessionStartResponse wraps the session create endpoint response.
type sessionStartResponse struct {
	SessionID string `json:"session_id"`
	AgentType string `json:"agent_type"`
	Plan      string `json:"plan,omitempty"`
	Status    string `json:"status"`
}

// sessionEndResponse wraps the session end endpoint response.
// NOTE(review): nothing in this file decodes into sessionEndResponse —
// verify it is still needed.
type sessionEndResponse struct {
	SessionID string `json:"session_id"`
	Status    string `json:"status"`
	Duration  string `json:"duration,omitempty"`
}

// sessionContinueResponse wraps the session continue endpoint response.
type sessionContinueResponse struct {
	SessionID     string `json:"session_id"`
	AgentType     string `json:"agent_type"`
	Plan          string `json:"plan,omitempty"`
	Status        string `json:"status"`
	ContinuedFrom string `json:"continued_from,omitempty"`
}
// ListSessions retrieves sessions matching the given options from
// GET /v1/sessions. Zero-valued filter fields are omitted from the query
// string, so the zero ListSessionOptions lists everything (subject to
// server-side defaults).
func (c *Client) ListSessions(ctx context.Context, opts ListSessionOptions) ([]Session, error) {
	const op = "agentic.Client.ListSessions"
	// Build the query string from only the filters that are set.
	params := url.Values{}
	if opts.Status != "" {
		params.Set("status", string(opts.Status))
	}
	if opts.PlanSlug != "" {
		params.Set("plan_slug", opts.PlanSlug)
	}
	if opts.Limit > 0 {
		params.Set("limit", strconv.Itoa(opts.Limit))
	}
	endpoint := c.BaseURL + "/v1/sessions"
	if len(params) > 0 {
		endpoint += "?" + params.Encode()
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var result sessionListResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return result.Sessions, nil
}
// GetSession retrieves a single session by ID from GET /v1/sessions/{id}.
// The sessionID is required and is path-escaped before use.
func (c *Client) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	const op = "agentic.Client.GetSession"
	if sessionID == "" {
		return nil, log.E(op, "session ID is required", nil)
	}
	endpoint := fmt.Sprintf("%s/v1/sessions/%s", c.BaseURL, url.PathEscape(sessionID))
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var session Session
	if err := json.NewDecoder(resp.Body).Decode(&session); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &session, nil
}
// StartSession starts a new agent session via POST /v1/sessions.
// req.AgentType is required; PlanSlug and Context are optional.
//
// NOTE(review): this exported method returns the unexported type
// *sessionStartResponse, which callers outside the package cannot name
// (staticcheck/golint flag this pattern). Consider exporting the response
// type if external callers need it.
func (c *Client) StartSession(ctx context.Context, req StartSessionRequest) (*sessionStartResponse, error) {
	const op = "agentic.Client.StartSession"
	if req.AgentType == "" {
		return nil, log.E(op, "agent_type is required", nil)
	}
	data, err := json.Marshal(req)
	if err != nil {
		return nil, log.E(op, "failed to marshal request", err)
	}
	endpoint := c.BaseURL + "/v1/sessions"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(httpReq)
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(httpReq)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var result sessionStartResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &result, nil
}
// EndSession ends a session with a final status and optional summary via
// POST /v1/sessions/{id}/end. Both sessionID and status are required;
// summary may be empty.
func (c *Client) EndSession(ctx context.Context, sessionID string, status string, summary string) error {
	const op = "agentic.Client.EndSession"
	if sessionID == "" {
		return log.E(op, "session ID is required", nil)
	}
	if status == "" {
		return log.E(op, "status is required", nil)
	}
	payload := EndSessionRequest{Status: status, Summary: summary}
	data, err := json.Marshal(payload)
	if err != nil {
		return log.E(op, "failed to marshal request", err)
	}
	endpoint := fmt.Sprintf("%s/v1/sessions/%s/end", c.BaseURL, url.PathEscape(sessionID))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
	if err != nil {
		return log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	// Wrap API errors with op context, consistent with the other session
	// methods; previously the raw checkResponse error was returned.
	if err := c.checkResponse(resp); err != nil {
		return log.E(op, "API error", err)
	}
	return nil
}
// ContinueSession creates a new session continuing from a previous one
// (multi-agent handoff) via POST /v1/sessions/{id}/continue. Both the
// previous session ID and the new agent type are required.
//
// NOTE(review): this exported method returns the unexported type
// *sessionContinueResponse — same concern as StartSession.
func (c *Client) ContinueSession(ctx context.Context, previousSessionID, agentType string) (*sessionContinueResponse, error) {
	const op = "agentic.Client.ContinueSession"
	if previousSessionID == "" {
		return nil, log.E(op, "previous session ID is required", nil)
	}
	if agentType == "" {
		return nil, log.E(op, "agent_type is required", nil)
	}
	data, err := json.Marshal(map[string]string{"agent_type": agentType})
	if err != nil {
		return nil, log.E(op, "failed to marshal request", err)
	}
	endpoint := fmt.Sprintf("%s/v1/sessions/%s/continue", c.BaseURL, url.PathEscape(previousSessionID))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
	if err != nil {
		return nil, log.E(op, "failed to create request", err)
	}
	c.setHeaders(req)
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return nil, log.E(op, "request failed", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if err := c.checkResponse(resp); err != nil {
		return nil, log.E(op, "API error", err)
	}
	var result sessionContinueResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, log.E(op, "failed to decode response", err)
	}
	return &result, nil
}

View file

@ -1,137 +0,0 @@
package lifecycle
import (
"cmp"
"context"
"fmt"
"slices"
"strings"
"forge.lthn.ai/core/go-log"
)
// StatusSummary aggregates status from the agent registry, task client, and
// allowance service for CLI display. It is produced by GetStatus and
// rendered by FormatStatus.
type StatusSummary struct {
	// Agents is the list of registered agents.
	Agents []AgentInfo
	// PendingTasks is the count of tasks with StatusPending.
	PendingTasks int
	// InProgressTasks is the count of tasks with StatusInProgress.
	InProgressTasks int
	// AllowanceRemaining maps agent ID to remaining daily tokens. -1 means unlimited.
	AllowanceRemaining map[string]int64
}
// GetStatus aggregates status from the registry, client, and allowance service.
// Any of registry, client, or allowanceSvc can be nil -- those sections are
// simply skipped. Returns what we can collect without failing on nil
// components; a non-nil client whose task listing fails does return an error.
func GetStatus(ctx context.Context, registry AgentRegistry, client *Client, allowanceSvc *AllowanceService) (*StatusSummary, error) {
	const op = "agentic.GetStatus"
	summary := &StatusSummary{
		AllowanceRemaining: make(map[string]int64),
	}
	// Collect agents from registry.
	if registry != nil {
		summary.Agents = registry.List()
	}
	// Count tasks by status via client (two filtered list calls).
	if client != nil {
		pending, err := client.ListTasks(ctx, ListOptions{Status: StatusPending})
		if err != nil {
			return nil, log.E(op, "failed to list pending tasks", err)
		}
		summary.PendingTasks = len(pending)
		inProgress, err := client.ListTasks(ctx, ListOptions{Status: StatusInProgress})
		if err != nil {
			return nil, log.E(op, "failed to list in-progress tasks", err)
		}
		summary.InProgressTasks = len(inProgress)
	}
	// Collect allowance remaining per agent. Check reports -1 for unlimited.
	if allowanceSvc != nil {
		for _, agent := range summary.Agents {
			check, err := allowanceSvc.Check(agent.ID, "")
			if err != nil {
				// Skip agents whose allowance cannot be resolved rather
				// than failing the whole summary.
				continue
			}
			summary.AllowanceRemaining[agent.ID] = check.RemainingTokens
		}
	}
	return summary, nil
}
// FormatStatus renders the summary as a human-readable table string suitable
// for CLI output. The exact layout (column widths, wording) is asserted by
// tests, so treat the format as a contract.
func FormatStatus(s *StatusSummary) string {
	var b strings.Builder
	// Count agents by status; anything neither available nor busy is
	// treated as offline in the header line.
	available := 0
	busy := 0
	for _, a := range s.Agents {
		switch a.Status {
		case AgentAvailable:
			available++
		case AgentBusy:
			busy++
		}
	}
	total := len(s.Agents)
	statusParts := make([]string, 0, 2)
	if available > 0 {
		statusParts = append(statusParts, fmt.Sprintf("%d available", available))
	}
	if busy > 0 {
		statusParts = append(statusParts, fmt.Sprintf("%d busy", busy))
	}
	offline := total - available - busy
	if offline > 0 {
		statusParts = append(statusParts, fmt.Sprintf("%d offline", offline))
	}
	if len(statusParts) > 0 {
		fmt.Fprintf(&b, "Agents: %d (%s)\n", total, strings.Join(statusParts, ", "))
	} else {
		fmt.Fprintf(&b, "Agents: %d\n", total)
	}
	fmt.Fprintf(&b, "Tasks: %d pending, %d in progress\n", s.PendingTasks, s.InProgressTasks)
	if len(s.Agents) > 0 {
		// Sort agents by ID for deterministic output.
		agents := slices.Clone(s.Agents)
		slices.SortFunc(agents, func(a, b AgentInfo) int {
			return cmp.Compare(a.ID, b.ID)
		})
		fmt.Fprintf(&b, "%-16s%-12s%-8s%s\n", "Agent", "Status", "Load", "Remaining")
		for _, a := range agents {
			// MaxLoad 0 means "no limit"; render as N/-.
			load := fmt.Sprintf("%d/%d", a.CurrentLoad, a.MaxLoad)
			if a.MaxLoad == 0 {
				load = fmt.Sprintf("%d/-", a.CurrentLoad)
			}
			// Allowance: -1 renders "unlimited"; missing entries "unknown".
			remaining := "unknown"
			if tokens, ok := s.AllowanceRemaining[a.ID]; ok {
				if tokens < 0 {
					remaining = "unlimited"
				} else {
					remaining = fmt.Sprintf("%d tokens", tokens)
				}
			}
			fmt.Fprintf(&b, "%-16s%-12s%-8s%s\n", a.ID, string(a.Status), load, remaining)
		}
	}
	return b.String()
}

View file

@ -1,270 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- GetStatus tests ---

// TestGetStatus_Good_AllNil verifies GetStatus tolerates all-nil components.
func TestGetStatus_Good_AllNil(t *testing.T) {
	summary, err := GetStatus(context.Background(), nil, nil, nil)
	require.NoError(t, err)
	assert.Empty(t, summary.Agents)
	assert.Equal(t, 0, summary.PendingTasks)
	assert.Equal(t, 0, summary.InProgressTasks)
	assert.Empty(t, summary.AllowanceRemaining)
}

// TestGetStatus_Good_RegistryOnly verifies agents are collected when only a
// registry is supplied.
func TestGetStatus_Good_RegistryOnly(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{
		ID:            "virgil",
		Name:          "Virgil",
		Status:        AgentAvailable,
		LastHeartbeat: time.Now().UTC(),
		MaxLoad:       5,
	})
	_ = reg.Register(AgentInfo{
		ID:            "charon",
		Name:          "Charon",
		Status:        AgentBusy,
		CurrentLoad:   3,
		MaxLoad:       5,
		LastHeartbeat: time.Now().UTC(),
	})
	summary, err := GetStatus(context.Background(), reg, nil, nil)
	require.NoError(t, err)
	assert.Len(t, summary.Agents, 2)
	assert.Equal(t, 0, summary.PendingTasks)
	assert.Equal(t, 0, summary.InProgressTasks)
}

// TestGetStatus_Good_FullSummary exercises registry, task client, and
// allowance service together against a mock HTTP server.
func TestGetStatus_Good_FullSummary(t *testing.T) {
	// Set up mock server returning task counts.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		status := r.URL.Query().Get("status")
		w.Header().Set("Content-Type", "application/json")
		switch status {
		case "pending":
			tasks := []Task{
				{ID: "t1", Status: StatusPending},
				{ID: "t2", Status: StatusPending},
				{ID: "t3", Status: StatusPending},
			}
			_ = json.NewEncoder(w).Encode(tasks)
		case "in_progress":
			tasks := []Task{
				{ID: "t4", Status: StatusInProgress},
			}
			_ = json.NewEncoder(w).Encode(tasks)
		default:
			_ = json.NewEncoder(w).Encode([]Task{})
		}
	}))
	defer server.Close()
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{
		ID:            "virgil",
		Name:          "Virgil",
		Status:        AgentAvailable,
		LastHeartbeat: time.Now().UTC(),
		MaxLoad:       5,
	})
	_ = reg.Register(AgentInfo{
		ID:            "charon",
		Name:          "Charon",
		Status:        AgentBusy,
		CurrentLoad:   3,
		MaxLoad:       5,
		LastHeartbeat: time.Now().UTC(),
	})
	store := NewMemoryStore()
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "virgil",
		DailyTokenLimit: 50000,
	})
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "charon",
		DailyTokenLimit: 50000,
	})
	// Simulate charon has used 38000 tokens.
	_ = store.IncrementUsage("charon", 38000, 0)
	svc := NewAllowanceService(store)
	client := NewClient(server.URL, "test-token")
	summary, err := GetStatus(context.Background(), reg, client, svc)
	require.NoError(t, err)
	assert.Len(t, summary.Agents, 2)
	assert.Equal(t, 3, summary.PendingTasks)
	assert.Equal(t, 1, summary.InProgressTasks)
	assert.Equal(t, int64(50000), summary.AllowanceRemaining["virgil"])
	assert.Equal(t, int64(12000), summary.AllowanceRemaining["charon"])
}

// TestGetStatus_Good_UnlimitedAllowance verifies a zero daily limit is
// reported as -1 (unlimited) by the allowance check.
func TestGetStatus_Good_UnlimitedAllowance(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{
		ID:            "darbs",
		Name:          "Darbs",
		Status:        AgentAvailable,
		LastHeartbeat: time.Now().UTC(),
		MaxLoad:       3,
	})
	store := NewMemoryStore()
	// DailyTokenLimit 0 means unlimited.
	_ = store.SetAllowance(&AgentAllowance{
		AgentID:         "darbs",
		DailyTokenLimit: 0,
	})
	svc := NewAllowanceService(store)
	summary, err := GetStatus(context.Background(), reg, nil, svc)
	require.NoError(t, err)
	// Unlimited: Check returns RemainingTokens = -1.
	assert.Equal(t, int64(-1), summary.AllowanceRemaining["darbs"])
}

// TestGetStatus_Good_AllowanceSkipsUnknownAgents verifies agents without a
// stored allowance are skipped rather than failing the summary.
func TestGetStatus_Good_AllowanceSkipsUnknownAgents(t *testing.T) {
	reg := NewMemoryRegistry()
	_ = reg.Register(AgentInfo{
		ID:            "unknown-agent",
		Name:          "Unknown",
		Status:        AgentAvailable,
		LastHeartbeat: time.Now().UTC(),
	})
	store := NewMemoryStore()
	// No allowance set for "unknown-agent" -- GetAllowance will error.
	svc := NewAllowanceService(store)
	summary, err := GetStatus(context.Background(), reg, nil, svc)
	require.NoError(t, err)
	// AllowanceRemaining should not have an entry for unknown-agent.
	_, exists := summary.AllowanceRemaining["unknown-agent"]
	assert.False(t, exists)
}

// TestGetStatus_Bad_ClientError verifies a failing task client aborts
// GetStatus with a wrapped error.
func TestGetStatus_Bad_ClientError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_ = json.NewEncoder(w).Encode(APIError{Message: "server error"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	summary, err := GetStatus(context.Background(), nil, client, nil)
	assert.Error(t, err)
	assert.Nil(t, summary)
	assert.Contains(t, err.Error(), "pending tasks")
}

// --- FormatStatus tests ---

// TestFormatStatus_Good_Empty verifies the header-only output for an empty
// summary (no agent table rows).
func TestFormatStatus_Good_Empty(t *testing.T) {
	s := &StatusSummary{
		AllowanceRemaining: make(map[string]int64),
	}
	output := FormatStatus(s)
	assert.Contains(t, output, "Agents: 0")
	assert.Contains(t, output, "Tasks: 0 pending, 0 in progress")
	// No agent table rows when there are no agents — only the summary lines.
	assert.NotContains(t, output, "Status")
}

// TestFormatStatus_Good_FullTable verifies the full table output, token
// rendering, and deterministic ID-sorted row order.
func TestFormatStatus_Good_FullTable(t *testing.T) {
	s := &StatusSummary{
		Agents: []AgentInfo{
			{ID: "virgil", Status: AgentAvailable, CurrentLoad: 0, MaxLoad: 5},
			{ID: "charon", Status: AgentBusy, CurrentLoad: 3, MaxLoad: 5},
			{ID: "darbs", Status: AgentAvailable, CurrentLoad: 0, MaxLoad: 3},
		},
		PendingTasks:    5,
		InProgressTasks: 2,
		AllowanceRemaining: map[string]int64{
			"virgil": 45000,
			"charon": 12000,
			"darbs":  -1,
		},
	}
	output := FormatStatus(s)
	assert.Contains(t, output, "Agents: 3 (2 available, 1 busy)")
	assert.Contains(t, output, "Tasks: 5 pending, 2 in progress")
	assert.Contains(t, output, "virgil")
	assert.Contains(t, output, "available")
	assert.Contains(t, output, "45000 tokens")
	assert.Contains(t, output, "charon")
	assert.Contains(t, output, "busy")
	assert.Contains(t, output, "12000 tokens")
	assert.Contains(t, output, "darbs")
	assert.Contains(t, output, "unlimited")
	// Verify deterministic sort order (agents sorted by ID).
	lines := strings.Split(output, "\n")
	var agentLines []string
	for _, line := range lines {
		if strings.HasPrefix(line, "charon") || strings.HasPrefix(line, "darbs") || strings.HasPrefix(line, "virgil") {
			agentLines = append(agentLines, line)
		}
	}
	require.Len(t, agentLines, 3)
	assert.True(t, strings.HasPrefix(agentLines[0], "charon"))
	assert.True(t, strings.HasPrefix(agentLines[1], "darbs"))
	assert.True(t, strings.HasPrefix(agentLines[2], "virgil"))
}

// TestFormatStatus_Good_OfflineAgent verifies the offline count in the header.
func TestFormatStatus_Good_OfflineAgent(t *testing.T) {
	s := &StatusSummary{
		Agents: []AgentInfo{
			{ID: "offline-bot", Status: AgentOffline, CurrentLoad: 0, MaxLoad: 5},
		},
		AllowanceRemaining: map[string]int64{
			"offline-bot": 30000,
		},
	}
	output := FormatStatus(s)
	assert.Contains(t, output, "1 offline")
	assert.Contains(t, output, "offline-bot")
}

// TestFormatStatus_Good_UnlimitedMaxLoad verifies MaxLoad 0 renders as "N/-".
func TestFormatStatus_Good_UnlimitedMaxLoad(t *testing.T) {
	s := &StatusSummary{
		Agents: []AgentInfo{
			{ID: "unlimited", Status: AgentAvailable, CurrentLoad: 2, MaxLoad: 0},
		},
		AllowanceRemaining: map[string]int64{
			"unlimited": -1,
		},
	}
	output := FormatStatus(s)
	assert.Contains(t, output, "2/-")
	assert.Contains(t, output, "unlimited")
}

// TestFormatStatus_Good_UnknownAllowance verifies a missing allowance entry
// renders "unknown".
func TestFormatStatus_Good_UnknownAllowance(t *testing.T) {
	s := &StatusSummary{
		Agents: []AgentInfo{
			{ID: "mystery", Status: AgentAvailable, MaxLoad: 5},
		},
		AllowanceRemaining: make(map[string]int64),
	}
	output := FormatStatus(s)
	assert.Contains(t, output, "unknown")
}

View file

@ -1,35 +0,0 @@
package lifecycle
import (
"context"
"time"
"forge.lthn.ai/core/go-log"
)
// SubmitTask creates a new task via the API client. The title must be
// non-empty; description, labels, and priority are passed through as given.
// The task is stamped with the current UTC time and starts in StatusPending.
// On success the server-created task (including its assigned ID) is returned.
func SubmitTask(ctx context.Context, client *Client, title, description string, labels []string, priority TaskPriority) (*Task, error) {
	const op = "agentic.SubmitTask"
	if title == "" {
		return nil, log.E(op, "title is required", nil)
	}
	created, err := client.CreateTask(ctx, Task{
		Title:       title,
		Description: description,
		Labels:      labels,
		Priority:    priority,
		Status:      StatusPending,
		CreatedAt:   time.Now().UTC(),
	})
	if err != nil {
		return nil, log.E(op, "failed to create task", err)
	}
	return created, nil
}

View file

@ -1,134 +0,0 @@
package lifecycle
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Client.CreateTask tests ---

// TestClient_CreateTask_Good verifies the request shape (method, path,
// headers, body) and that the server-assigned ID is returned.
func TestClient_CreateTask_Good(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/api/tasks", r.URL.Path)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
		var task Task
		err := json.NewDecoder(r.Body).Decode(&task)
		require.NoError(t, err)
		assert.Equal(t, "New feature", task.Title)
		assert.Equal(t, PriorityHigh, task.Priority)
		// Return the task with an assigned ID.
		task.ID = "task-new-1"
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		_ = json.NewEncoder(w).Encode(task)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task := Task{
		Title:       "New feature",
		Description: "Build something great",
		Priority:    PriorityHigh,
		Labels:      []string{"feature"},
		Status:      StatusPending,
	}
	created, err := client.CreateTask(context.Background(), task)
	require.NoError(t, err)
	assert.Equal(t, "task-new-1", created.ID)
	assert.Equal(t, "New feature", created.Title)
	assert.Equal(t, PriorityHigh, created.Priority)
}

// TestClient_CreateTask_Bad_ServerError verifies an API error body surfaces
// in the returned error.
func TestClient_CreateTask_Bad_ServerError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		_ = json.NewEncoder(w).Encode(APIError{Message: "validation failed"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	task := Task{Title: "Bad task"}
	created, err := client.CreateTask(context.Background(), task)
	assert.Error(t, err)
	assert.Nil(t, created)
	assert.Contains(t, err.Error(), "validation failed")
}

// --- SubmitTask tests ---

// TestSubmitTask_Good_AllFields verifies all fields are forwarded and
// CreatedAt is stamped.
func TestSubmitTask_Good_AllFields(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var task Task
		err := json.NewDecoder(r.Body).Decode(&task)
		require.NoError(t, err)
		assert.Equal(t, "Implement login", task.Title)
		assert.Equal(t, "OAuth2 login flow", task.Description)
		assert.Equal(t, []string{"auth", "frontend"}, task.Labels)
		assert.Equal(t, PriorityHigh, task.Priority)
		assert.Equal(t, StatusPending, task.Status)
		assert.False(t, task.CreatedAt.IsZero())
		task.ID = "task-submit-1"
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		_ = json.NewEncoder(w).Encode(task)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	created, err := SubmitTask(context.Background(), client, "Implement login", "OAuth2 login flow", []string{"auth", "frontend"}, PriorityHigh)
	require.NoError(t, err)
	assert.Equal(t, "task-submit-1", created.ID)
	assert.Equal(t, "Implement login", created.Title)
}

// TestSubmitTask_Good_MinimalFields verifies title-only submission works.
func TestSubmitTask_Good_MinimalFields(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var task Task
		_ = json.NewDecoder(r.Body).Decode(&task)
		task.ID = "task-minimal"
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		_ = json.NewEncoder(w).Encode(task)
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	created, err := SubmitTask(context.Background(), client, "Simple task", "", nil, PriorityLow)
	require.NoError(t, err)
	assert.Equal(t, "task-minimal", created.ID)
}

// TestSubmitTask_Bad_EmptyTitle verifies validation rejects an empty title
// before any request is made.
func TestSubmitTask_Bad_EmptyTitle(t *testing.T) {
	client := NewClient("https://api.example.com", "test-token")
	created, err := SubmitTask(context.Background(), client, "", "description", nil, PriorityMedium)
	assert.Error(t, err)
	assert.Nil(t, created)
	assert.Contains(t, err.Error(), "title is required")
}

// TestSubmitTask_Bad_ClientError verifies a server failure is wrapped with
// the "create task" context.
func TestSubmitTask_Bad_ClientError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_ = json.NewEncoder(w).Encode(APIError{Message: "internal error"})
	}))
	defer server.Close()
	client := NewClient(server.URL, "test-token")
	created, err := SubmitTask(context.Background(), client, "Good title", "", nil, PriorityMedium)
	assert.Error(t, err)
	assert.Nil(t, created)
	assert.Contains(t, err.Error(), "create task")
}

View file

@ -1,150 +0,0 @@
// Package lifecycle provides an API client for core-agentic, an AI-assisted
// task management service. It enables developers and AI agents to discover,
// claim, and complete development tasks.
package lifecycle
import (
"time"
)
// TaskStatus represents the state of a task in the system.
type TaskStatus string

// Task lifecycle states.
const (
	// StatusPending indicates the task is available to be claimed.
	StatusPending TaskStatus = "pending"
	// StatusInProgress indicates the task has been claimed and is being worked on.
	StatusInProgress TaskStatus = "in_progress"
	// StatusCompleted indicates the task has been successfully completed.
	StatusCompleted TaskStatus = "completed"
	// StatusBlocked indicates the task cannot proceed due to dependencies.
	StatusBlocked TaskStatus = "blocked"
	// StatusFailed indicates the task has exceeded its retry limit and been dead-lettered.
	StatusFailed TaskStatus = "failed"
)

// TaskPriority represents the urgency level of a task.
type TaskPriority string

// Task urgency levels, from most to least urgent.
const (
	// PriorityCritical indicates the task requires immediate attention.
	PriorityCritical TaskPriority = "critical"
	// PriorityHigh indicates the task is important and should be addressed soon.
	PriorityHigh TaskPriority = "high"
	// PriorityMedium indicates the task has normal priority.
	PriorityMedium TaskPriority = "medium"
	// PriorityLow indicates the task can be addressed when time permits.
	PriorityLow TaskPriority = "low"
)
// Task represents a development task in the core-agentic system.
// The retry-related fields (MaxRetries, RetryCount, LastAttempt, FailReason)
// presumably are maintained by the dispatcher rather than by clients —
// verify against the dispatch code before writing them.
type Task struct {
	// ID is the unique identifier for the task.
	ID string `json:"id"`
	// Title is the short description of the task.
	Title string `json:"title"`
	// Description provides detailed information about what needs to be done.
	Description string `json:"description"`
	// Priority indicates the urgency of the task.
	Priority TaskPriority `json:"priority"`
	// Status indicates the current state of the task.
	Status TaskStatus `json:"status"`
	// Labels are tags used to categorize the task.
	Labels []string `json:"labels,omitempty"`
	// Files lists the files that are relevant to this task.
	Files []string `json:"files,omitempty"`
	// CreatedAt is when the task was created.
	CreatedAt time.Time `json:"created_at"`
	// UpdatedAt is when the task was last modified.
	UpdatedAt time.Time `json:"updated_at"`
	// ClaimedBy is the identifier of the agent or developer who claimed the task.
	ClaimedBy string `json:"claimed_by,omitempty"`
	// ClaimedAt is when the task was claimed; nil if unclaimed.
	ClaimedAt *time.Time `json:"claimed_at,omitempty"`
	// Project is the project this task belongs to.
	Project string `json:"project,omitempty"`
	// Dependencies lists task IDs that must be completed before this task.
	Dependencies []string `json:"dependencies,omitempty"`
	// Blockers lists task IDs that this task is blocking.
	Blockers []string `json:"blockers,omitempty"`
	// MaxRetries is the maximum dispatch attempts before dead-lettering. 0 uses DefaultMaxRetries.
	MaxRetries int `json:"max_retries,omitempty"`
	// RetryCount is the number of failed dispatch attempts so far.
	RetryCount int `json:"retry_count,omitempty"`
	// LastAttempt is when the last dispatch attempt occurred; nil if never attempted.
	LastAttempt *time.Time `json:"last_attempt,omitempty"`
	// FailReason explains why the task was moved to failed status.
	FailReason string `json:"fail_reason,omitempty"`
}
// TaskUpdate contains fields that can be updated on a task. Zero-valued
// fields are omitted from the JSON payload.
type TaskUpdate struct {
	// Status is the new status for the task.
	Status TaskStatus `json:"status,omitempty"`
	// Progress is a percentage (0-100) indicating completion.
	Progress int `json:"progress,omitempty"`
	// Notes are additional comments about the update.
	Notes string `json:"notes,omitempty"`
}

// TaskResult contains the outcome of a completed task.
type TaskResult struct {
	// Success indicates whether the task was completed successfully.
	Success bool `json:"success"`
	// Output is the result or summary of the completed work.
	Output string `json:"output,omitempty"`
	// Artifacts are files or resources produced by the task.
	Artifacts []string `json:"artifacts,omitempty"`
	// ErrorMessage contains details if the task failed.
	ErrorMessage string `json:"error_message,omitempty"`
}

// ListOptions specifies filters for listing tasks. Zero-valued fields apply
// no filter.
type ListOptions struct {
	// Status filters tasks by their current status.
	Status TaskStatus `json:"status,omitempty"`
	// Labels filters tasks that have all specified labels.
	Labels []string `json:"labels,omitempty"`
	// Priority filters tasks by priority level.
	Priority TaskPriority `json:"priority,omitempty"`
	// Limit is the maximum number of tasks to return.
	Limit int `json:"limit,omitempty"`
	// Project filters tasks by project.
	Project string `json:"project,omitempty"`
	// ClaimedBy filters tasks claimed by a specific agent.
	ClaimedBy string `json:"claimed_by,omitempty"`
}
// APIError represents an error response from the API.
type APIError struct {
	// Code is the HTTP status code.
	Code int `json:"code"`
	// Message is the error description.
	Message string `json:"message"`
	// Details provides additional context about the error.
	Details string `json:"details,omitempty"`
}

// Error implements the error interface. The result is the message alone,
// or "message: details" when details are present.
func (e *APIError) Error() string {
	if e.Details == "" {
		return e.Message
	}
	return e.Message + ": " + e.Details
}
// ClaimResponse is returned when a task is successfully claimed.
type ClaimResponse struct {
	// Task is the claimed task with updated fields.
	Task *Task `json:"task"`
	// Message provides additional context about the claim.
	Message string `json:"message,omitempty"`
}

// CompleteResponse is returned when a task is completed.
type CompleteResponse struct {
	// Task is the completed task with final status.
	Task *Task `json:"task"`
	// Message provides additional context about the completion.
	Message string `json:"message,omitempty"`
}

Some files were not shown because too many files have changed in this diff Show more